From 1e90a1bf36d57ef9b9919a61d198c2d0a25d512e Mon Sep 17 00:00:00 2001
From: Jin Hai
Date: Fri, 15 Nov 2024 17:30:56 +0800
Subject: [PATCH] Move settings initialization after module init phase (#3438)

### What problem does this PR solve?

1. Module init no longer connects to the database.
2. Config values in settings must now be accessed as settings.CONFIG_NAME.

### Type of change

- [x] Refactoring

Signed-off-by: jinhai
---
 agent/component/generate.py         |  22 +--
 agent/component/retrieval.py        |   4 +-
 api/apps/__init__.py                |  10 +-
 api/apps/api_app.py                 |  64 ++++----
 api/apps/canvas_app.py              |  22 +--
 api/apps/chunk_app.py               |  34 ++--
 api/apps/conversation_app.py        |  13 +-
 api/apps/dialog_app.py              |   4 +-
 api/apps/document_app.py            |  61 +++----
 api/apps/file2document_app.py       |   4 +-
 api/apps/file_app.py                |  10 +-
 api/apps/kb_app.py                  |  15 +-
 api/apps/llm_app.py                 |   4 +-
 api/apps/sdk/chat.py                |   6 +-
 api/apps/sdk/dataset.py             |   6 +-
 api/apps/sdk/dify_retrieval.py      |  12 +-
 api/apps/sdk/doc.py                 |  50 +++---
 api/apps/system_app.py              |   9 +-
 api/apps/user_app.py                |  71 ++++----
 api/db/db_models.py                 |  14 +-
 api/db/init_data.py                 |  23 +--
 api/db/services/dialog_service.py   |   8 +-
 api/db/services/document_service.py |  10 +-
 api/ragflow_server.py               |  11 +-
 api/settings.py                     | 241 ++++++++++++++++------------
 api/utils/api_utils.py              |  50 +++---
 deepdoc/parser/pdf_parser.py        |   4 +-
 graphrag/claim_extractor.py         |   4 +-
 graphrag/smoke.py                   |   4 +-
 rag/benchmark.py                    |  22 +--
 rag/llm/embedding_model.py          |   8 +-
 rag/llm/rerank_model.py             |   6 +-
 rag/svr/task_executor.py            |  37 +++--
 33 files changed, 452 insertions(+), 411 deletions(-)

diff --git a/agent/component/generate.py b/agent/component/generate.py index d4bd67fa0..b84d2fef0 100644 --- a/agent/component/generate.py +++ b/agent/component/generate.py @@ -19,7 +19,7 @@ import pandas as pd from api.db import LLMType from api.db.services.dialog_service import message_fit_in from api.db.services.llm_service import LLMBundle -from api.settings import retrievaler +from api import settings from agent.component.base import ComponentBase, ComponentParamBase @@ -63,18 +63,20 @@ class Generate(ComponentBase): component_name = "Generate" def get_dependent_components(self): - cpnts = [para["component_id"] for para in self._param.parameters if para.get("component_id") and para["component_id"].lower().find("answer") < 0] + cpnts = [para["component_id"] for para in self._param.parameters if + para.get("component_id") and para["component_id"].lower().find("answer") < 0] return cpnts def set_cite(self, retrieval_res, answer): retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True) if "empty_response" in retrieval_res.columns: retrieval_res["empty_response"].fillna("", inplace=True) - answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()], - [ck["vector"] for _, ck in retrieval_res.iterrows()], - LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, - self._canvas.get_embedding_model()), tkweight=0.7, - vtweight=0.3) + answer, idx = settings.retrievaler.insert_citations(answer, + [ck["content_ltks"] for _, ck in retrieval_res.iterrows()], + [ck["vector"] for _, ck in retrieval_res.iterrows()], + LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, + self._canvas.get_embedding_model()), tkweight=0.7, + vtweight=0.3) doc_ids = set([]) recall_docs = [] for i in idx: @@ -127,12 +129,14 @@ class Generate(ComponentBase): else: if cpn.component_name.lower() == "retrieval": retrieval_res.append(out) - kwargs[para["key"]] = " - "+"\n - ".join([o if 
isinstance(o, str) else str(o) for o in out["content"]]) + kwargs[para["key"]] = " - " + "\n - ".join( + [o if isinstance(o, str) else str(o) for o in out["content"]]) self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]}) if retrieval_res: retrieval_res = pd.concat(retrieval_res, ignore_index=True) - else: retrieval_res = pd.DataFrame([]) + else: + retrieval_res = pd.DataFrame([]) for n, v in kwargs.items(): prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt) diff --git a/agent/component/retrieval.py b/agent/component/retrieval.py index ab48e04a7..f9f3e9fa2 100644 --- a/agent/component/retrieval.py +++ b/agent/component/retrieval.py @@ -21,7 +21,7 @@ import pandas as pd from api.db import LLMType from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle -from api.settings import retrievaler +from api import settings from agent.component.base import ComponentBase, ComponentParamBase @@ -67,7 +67,7 @@ class Retrieval(ComponentBase, ABC): if self._param.rerank_id: rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id) - kbinfos = retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids, + kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids, 1, self._param.top_n, self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight, aggs=False, rerank_mdl=rerank_mdl) diff --git a/api/apps/__init__.py b/api/apps/__init__.py index 73f323a14..5bc8d327f 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -30,8 +30,7 @@ from api.utils import CustomJSONEncoder, commands from flask_session import Session from flask_login import LoginManager -from api.settings import SECRET_KEY -from api.settings import API_VERSION +from api import settings from api.utils.api_utils import server_error_response from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer @@ -78,7 +77,6 @@ app.url_map.strict_slashes = False app.json_encoder = CustomJSONEncoder app.errorhandler(Exception)(server_error_response) - ## convince for dev and debug # app.config["LOGIN_DISABLED"] = True app.config["SESSION_PERMANENT"] = False @@ -110,7 +108,7 @@ def register_page(page_path): page_name = page_path.stem.rstrip("_app") module_name = ".".join( - page_path.parts[page_path.parts.index("api") : -1] + (page_name,) + page_path.parts[page_path.parts.index("api"): -1] + (page_name,) ) spec = spec_from_file_location(module_name, page_path) @@ -121,7 +119,7 @@ def register_page(page_path): spec.loader.exec_module(page) page_name = getattr(page, "page_name", page_name) url_prefix = ( - f"/api/{API_VERSION}" if "/sdk/" in path else f"/{API_VERSION}/{page_name}" + f"/api/{settings.API_VERSION}" if "/sdk/" in path else f"/{settings.API_VERSION}/{page_name}" ) app.register_blueprint(page.manager, url_prefix=url_prefix) @@ -141,7 +139,7 @@ client_urls_prefix = [ @login_manager.request_loader def load_user(web_request): - jwt = Serializer(secret_key=SECRET_KEY) + jwt = Serializer(secret_key=settings.SECRET_KEY) authorization = web_request.headers.get("Authorization") if authorization: try: diff --git a/api/apps/api_app.py b/api/apps/api_app.py index b21f129d2..d77a437aa 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -32,7 +32,7 @@ from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.task_service import 
queue_tasks, TaskService from api.db.services.user_service import UserTenantService -from api.settings import RetCode, retrievaler +from api import settings from api.utils import get_uuid, current_timestamp, datetime_format from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \ generate_confirmation_token @@ -141,7 +141,7 @@ def set_conversation(): objs = APIToken.query(token=token) if not objs: return get_json_result( - data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) + data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR) req = request.json try: if objs[0].source == "agent": @@ -183,7 +183,7 @@ def completion(): objs = APIToken.query(token=token) if not objs: return get_json_result( - data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) + data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR) req = request.json e, conv = API4ConversationService.get_by_id(req["conversation_id"]) if not e: @@ -290,8 +290,8 @@ def completion(): API4ConversationService.append_message(conv.id, conv.to_dict()) rename_field(result) return get_json_result(data=result) - - #******************For dialog****************** + + # ******************For dialog****************** conv.message.append(msg[-1]) e, dia = DialogService.get_by_id(conv.dialog_id) if not e: @@ -326,7 +326,7 @@ def completion(): resp.headers.add_header("X-Accel-Buffering", "no") resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp - + answer = None for ans in chat(dia, msg, **req): answer = ans @@ -347,8 +347,8 @@ def get(conversation_id): objs = APIToken.query(token=token) if not objs: return get_json_result( - data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) - + data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR) + try: e, conv = API4ConversationService.get_by_id(conversation_id) if not e: @@ -357,8 +357,8 @@ def get(conversation_id): conv = conv.to_dict() if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token: return get_json_result(data=False, message='Token is not valid for this conversation_id!"', - code=RetCode.AUTHENTICATION_ERROR) - + code=settings.RetCode.AUTHENTICATION_ERROR) + for referenct_i in conv['reference']: if referenct_i is None or len(referenct_i) == 0: continue @@ -378,7 +378,7 @@ def upload(): objs = APIToken.query(token=token) if not objs: return get_json_result( - data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) + data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR) kb_name = request.form.get("kb_name").strip() tenant_id = objs[0].tenant_id @@ -394,12 +394,12 @@ def upload(): if 'file' not in request.files: return get_json_result( - data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) + data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR) file = request.files['file'] if file.filename == '': return get_json_result( - data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) + data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR) root_folder = FileService.get_root_folder(tenant_id) pf_id = root_folder["id"] @@ -490,17 +490,17 @@ def upload_parse(): objs = APIToken.query(token=token) if not objs: return get_json_result( - data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) + 
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR) if 'file' not in request.files: return get_json_result( - data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) + data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR) file_objs = request.files.getlist('file') for file_obj in file_objs: if file_obj.filename == '': return get_json_result( - data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) + data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR) doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id) return get_json_result(data=doc_ids) @@ -513,7 +513,7 @@ def list_chunks(): objs = APIToken.query(token=token) if not objs: return get_json_result( - data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) + data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR) req = request.json @@ -531,7 +531,7 @@ def list_chunks(): ) kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) - res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids) + res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids) res = [ { "content": res_item["content_with_weight"], @@ -553,7 +553,7 @@ def list_kb_docs(): objs = APIToken.query(token=token) if not objs: return get_json_result( - data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) + data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR) req = request.json tenant_id = objs[0].tenant_id @@ -585,6 +585,7 @@ def list_kb_docs(): except Exception as e: return server_error_response(e) + @manager.route('/document/infos', methods=['POST']) @validate_request("doc_ids") def docinfos(): @@ -592,7 +593,7 @@ def docinfos(): objs = APIToken.query(token=token) if not objs: return get_json_result( - data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) + data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR) req = request.json doc_ids = req["doc_ids"] docs = DocumentService.get_by_ids(doc_ids) @@ -606,7 +607,7 @@ def document_rm(): objs = APIToken.query(token=token) if not objs: return get_json_result( - data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) + data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR) tenant_id = objs[0].tenant_id req = request.json @@ -653,7 +654,7 @@ def document_rm(): errors += str(e) if errors: - return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR) + return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR) return get_json_result(data=True) @@ -668,7 +669,7 @@ def completion_faq(): objs = APIToken.query(token=token) if not objs: return get_json_result( - data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) + data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR) e, conv = API4ConversationService.get_by_id(req["conversation_id"]) if not e: @@ -805,10 +806,10 @@ def retrieval(): objs = APIToken.query(token=token) if not objs: return get_json_result( - data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) + data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR) req = request.json - kb_ids = req.get("kb_id",[]) + kb_ids = req.get("kb_id", []) doc_ids = req.get("doc_ids", []) question = req.get("question") 
page = int(req.get("page", 1)) @@ -822,20 +823,21 @@ def retrieval(): embd_nms = list(set([kb.embd_id for kb in kbs])) if len(embd_nms) != 1: return get_json_result( - data=False, message='Knowledge bases use different embedding models or does not exist."', code=RetCode.AUTHENTICATION_ERROR) + data=False, message='Knowledge bases use different embedding models or does not exist."', + code=settings.RetCode.AUTHENTICATION_ERROR) embd_mdl = TenantLLMService.model_instance( kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id) rerank_mdl = None if req.get("rerank_id"): rerank_mdl = TenantLLMService.model_instance( - kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) + kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) if req.get("keyword", False): chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT) question += keyword_extraction(chat_mdl, question) - ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size, - similarity_threshold, vector_similarity_weight, top, - doc_ids, rerank_mdl=rerank_mdl) + ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size, + similarity_threshold, vector_similarity_weight, top, + doc_ids, rerank_mdl=rerank_mdl) for c in ranks["chunks"]: if "vector" in c: del c["vector"] @@ -843,5 +845,5 @@ def retrieval(): except Exception as e: if str(e).find("not_found") > 0: return get_json_result(data=False, message='No chunk found! Check the chunk status please!', - code=RetCode.DATA_ERROR) + code=settings.RetCode.DATA_ERROR) return server_error_response(e) diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index 751202f8a..a9d588432 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -19,7 +19,7 @@ from functools import partial from flask import request, Response from flask_login import login_required, current_user from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService -from api.settings import RetCode +from api import settings from api.utils import get_uuid from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result from agent.canvas import Canvas @@ -36,7 +36,8 @@ def templates(): @login_required def canvas_list(): return get_json_result(data=sorted([c.to_dict() for c in \ - UserCanvasService.query(user_id=current_user.id)], key=lambda x: x["update_time"]*-1) + UserCanvasService.query(user_id=current_user.id)], + key=lambda x: x["update_time"] * -1) ) @@ -45,10 +46,10 @@ def canvas_list(): @login_required def rm(): for i in request.json["canvas_ids"]: - if not UserCanvasService.query(user_id=current_user.id,id=i): + if not UserCanvasService.query(user_id=current_user.id, id=i): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) + code=settings.RetCode.OPERATING_ERROR) UserCanvasService.delete_by_id(i) return get_json_result(data=True) @@ -72,7 +73,7 @@ def save(): if not UserCanvasService.query(user_id=current_user.id, id=req["id"]): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) + code=settings.RetCode.OPERATING_ERROR) UserCanvasService.update_by_id(req["id"], req) return get_json_result(data=req) @@ -98,7 +99,7 @@ def run(): if not UserCanvasService.query(user_id=current_user.id, id=req["id"]): return get_json_result( data=False, message='Only owner of canvas authorized 
for this operation.', - code=RetCode.OPERATING_ERROR) + code=settings.RetCode.OPERATING_ERROR) if not isinstance(cvs.dsl, str): cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) @@ -110,8 +111,8 @@ def run(): if "message" in req: canvas.messages.append({"role": "user", "content": req["message"], "id": message_id}) if len([m for m in canvas.messages if m["role"] == "user"]) > 1: - #ten = TenantService.get_info_by(current_user.id)[0] - #req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages) + # ten = TenantService.get_info_by(current_user.id)[0] + # req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages) pass canvas.add_user_input(req["message"]) answer = canvas.run(stream=stream) @@ -122,7 +123,8 @@ def run(): assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow." if stream: - assert isinstance(answer, partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow." + assert isinstance(answer, + partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow." def sse(): nonlocal answer, cvs @@ -173,7 +175,7 @@ def reset(): if not UserCanvasService.query(user_id=current_user.id, id=req["id"]): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) + code=settings.RetCode.OPERATING_ERROR) canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id) canvas.reset() diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 61a9622f9..252f0fa40 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -29,11 +29,12 @@ from api.db.services.llm_service import LLMBundle from api.db.services.user_service import UserTenantService from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.db.services.document_service import DocumentService -from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn +from api import settings from api.utils.api_utils import get_json_result import hashlib import re + @manager.route('/list', methods=['POST']) @login_required @validate_request("doc_id") @@ -56,7 +57,7 @@ def list_chunk(): } if "available_int" in req: query["available_int"] = int(req["available_int"]) - sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True) + sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True) res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()} for id in sres.ids: d = { @@ -72,13 +73,13 @@ def list_chunk(): "positions": json.loads(sres.field[id].get("position_list", "[]")), } assert isinstance(d["positions"], list) - assert len(d["positions"])==0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5) + assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5) res["chunks"].append(d) return get_json_result(data=res) except Exception as e: if str(e).find("not_found") > 0: return get_json_result(data=False, message='No chunk found!', - code=RetCode.DATA_ERROR) + code=settings.RetCode.DATA_ERROR) return server_error_response(e) @@ -93,7 +94,7 @@ def get(): tenant_id = tenants[0].tenant_id kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) - chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids) + chunk = 
settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids) if chunk is None: return server_error_response("Chunk not found") k = [] @@ -107,7 +108,7 @@ def get(): except Exception as e: if str(e).find("NotFoundError") >= 0: return get_json_result(data=False, message='Chunk not found!', - code=RetCode.DATA_ERROR) + code=settings.RetCode.DATA_ERROR) return server_error_response(e) @@ -154,7 +155,7 @@ def set(): v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] d["q_%d_vec" % len(v)] = v.tolist() - docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id) + settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id) return get_json_result(data=True) except Exception as e: return server_error_response(e) @@ -169,8 +170,8 @@ def switch(): e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: return get_data_error_result(message="Document not found!") - if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])}, - search.index_name(doc.tenant_id), doc.kb_id): + if not settings.docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])}, + search.index_name(doc.tenant_id), doc.kb_id): return get_data_error_result(message="Index updating failure") return get_json_result(data=True) except Exception as e: @@ -186,7 +187,7 @@ def rm(): e, doc = DocumentService.get_by_id(req["doc_id"]) if not e: return get_data_error_result(message="Document not found!") - if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id): + if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id): return get_data_error_result(message="Index updating failure") deleted_chunk_ids = req["chunk_ids"] chunk_number = len(deleted_chunk_ids) @@ -230,7 +231,7 @@ def create(): v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) v = 0.1 * v[0] + 0.9 * v[1] d["q_%d_vec" % len(v)] = v.tolist() - docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id) + settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id) DocumentService.increment_chunk_num( doc.id, doc.kb_id, c, 1, 0) @@ -265,7 +266,7 @@ def retrieval_test(): else: return get_json_result( data=False, message='Only owner of knowledgebase authorized for this operation.', - code=RetCode.OPERATING_ERROR) + code=settings.RetCode.OPERATING_ERROR) e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) if not e: @@ -281,7 +282,7 @@ def retrieval_test(): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) question += keyword_extraction(chat_mdl, question) - retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler + retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight")) @@ -293,7 +294,7 @@ def retrieval_test(): except Exception as e: if str(e).find("not_found") > 0: return get_json_result(data=False, message='No chunk found! 
Check the chunk status please!', - code=RetCode.DATA_ERROR) + code=settings.RetCode.DATA_ERROR) return server_error_response(e) @@ -304,10 +305,10 @@ def knowledge_graph(): tenant_id = DocumentService.get_tenant_id(doc_id) kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) req = { - "doc_ids":[doc_id], + "doc_ids": [doc_id], "knowledge_graph_kwd": ["graph", "mind_map"] } - sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids) + sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids) obj = {"graph": {}, "mind_map": {}} for id in sres.ids[:2]: ty = sres.field[id]["knowledge_graph_kwd"] @@ -336,4 +337,3 @@ def knowledge_graph(): obj[ty] = content_json return get_json_result(data=obj) - diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index a86a0dbae..5da16dc03 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -25,7 +25,7 @@ from api.db import LLMType from api.db.services.dialog_service import DialogService, ConversationService, chat, ask from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService -from api.settings import RetCode, retrievaler +from api import settings from api.utils.api_utils import get_json_result from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from graphrag.mind_map_extractor import MindMapExtractor @@ -87,7 +87,7 @@ def get(): else: return get_json_result( data=False, message='Only owner of conversation authorized for this operation.', - code=RetCode.OPERATING_ERROR) + code=settings.RetCode.OPERATING_ERROR) conv = conv.to_dict() return get_json_result(data=conv) except Exception as e: @@ -110,7 +110,7 @@ def rm(): else: return get_json_result( data=False, message='Only owner of conversation authorized for this operation.', - code=RetCode.OPERATING_ERROR) + code=settings.RetCode.OPERATING_ERROR) ConversationService.delete_by_id(cid) return get_json_result(data=True) except Exception as e: @@ -125,7 +125,7 @@ def list_convsersation(): if not DialogService.query(tenant_id=current_user.id, id=dialog_id): return get_json_result( data=False, message='Only owner of dialog authorized for this operation.', - code=RetCode.OPERATING_ERROR) + code=settings.RetCode.OPERATING_ERROR) convs = ConversationService.query( dialog_id=dialog_id, order_by=ConversationService.model.create_time, @@ -297,6 +297,7 @@ def thumbup(): def ask_about(): req = request.json uid = current_user.id + def stream(): nonlocal req, uid try: @@ -329,8 +330,8 @@ def mindmap(): embd_mdl = TenantLLMService.model_instance( kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) chat_mdl = LLMBundle(current_user.id, LLMType.CHAT) - ranks = retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12, - 0.3, 0.3, aggs=False) + ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12, + 0.3, 0.3, aggs=False) mindmap = MindMapExtractor(chat_mdl) mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output if "error" in mind_map: diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index e3c87a1fc..ebab025a1 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -20,7 +20,7 @@ from api.db.services.dialog_service import DialogService from api.db import StatusEnum from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.user_service import TenantService, 
UserTenantService -from api.settings import RetCode +from api import settings from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils import get_uuid from api.utils.api_utils import get_json_result @@ -175,7 +175,7 @@ def rm(): else: return get_json_result( data=False, message='Only owner of dialog authorized for this operation.', - code=RetCode.OPERATING_ERROR) + code=settings.RetCode.OPERATING_ERROR) dialog_list.append({"id": id,"status":StatusEnum.INVALID.value}) DialogService.update_many_by_id(dialog_list) return get_json_result(data=True) diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 6a2362a90..f705ac8a8 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -34,7 +34,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va from api.utils import get_uuid from api.db import FileType, TaskStatus, ParserType, FileSource from api.db.services.document_service import DocumentService, doc_upload_and_parse -from api.settings import RetCode, docStoreConn +from api import settings from api.utils.api_utils import get_json_result from rag.utils.storage_factory import STORAGE_IMPL from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory @@ -49,16 +49,16 @@ def upload(): kb_id = request.form.get("kb_id") if not kb_id: return get_json_result( - data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) + data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR) if 'file' not in request.files: return get_json_result( - data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) + data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR) file_objs = request.files.getlist('file') for file_obj in file_objs: if file_obj.filename == '': return get_json_result( - data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) + data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR) e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: @@ -67,7 +67,7 @@ def upload(): err, _ = FileService.upload_document(kb, file_objs, current_user.id) if err: return get_json_result( - data=False, message="\n".join(err), code=RetCode.SERVER_ERROR) + data=False, message="\n".join(err), code=settings.RetCode.SERVER_ERROR) return get_json_result(data=True) @@ -78,12 +78,12 @@ def web_crawl(): kb_id = request.form.get("kb_id") if not kb_id: return get_json_result( - data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) + data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR) name = request.form.get("name") url = request.form.get("url") if not is_valid_url(url): return get_json_result( - data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR) + data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR) e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: raise LookupError("Can't find this knowledgebase!") @@ -145,7 +145,7 @@ def create(): kb_id = req["kb_id"] if not kb_id: return get_json_result( - data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) + data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR) try: e, kb = KnowledgebaseService.get_by_id(kb_id) @@ -179,7 +179,7 @@ def list_docs(): kb_id = request.args.get("kb_id") if not kb_id: return get_json_result( - data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) + data=False, message='Lack of "KB ID"', 
code=settings.RetCode.ARGUMENT_ERROR) tenants = UserTenantService.query(user_id=current_user.id) for tenant in tenants: if KnowledgebaseService.query( @@ -188,7 +188,7 @@ def list_docs(): else: return get_json_result( data=False, message='Only owner of knowledgebase authorized for this operation.', - code=RetCode.OPERATING_ERROR) + code=settings.RetCode.OPERATING_ERROR) keywords = request.args.get("keywords", "") page_number = int(request.args.get("page", 1)) @@ -218,19 +218,19 @@ def docinfos(): return get_json_result( data=False, message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR + code=settings.RetCode.AUTHENTICATION_ERROR ) docs = DocumentService.get_by_ids(doc_ids) return get_json_result(data=list(docs.dicts())) @manager.route('/thumbnails', methods=['GET']) -#@login_required +# @login_required def thumbnails(): doc_ids = request.args.get("doc_ids").split(",") if not doc_ids: return get_json_result( - data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR) + data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR) try: docs = DocumentService.get_thumbnails(doc_ids) @@ -253,13 +253,13 @@ def change_status(): return get_json_result( data=False, message='"Status" must be either 0 or 1!', - code=RetCode.ARGUMENT_ERROR) + code=settings.RetCode.ARGUMENT_ERROR) if not DocumentService.accessible(req["doc_id"], current_user.id): return get_json_result( data=False, message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR) + code=settings.RetCode.AUTHENTICATION_ERROR) try: e, doc = DocumentService.get_by_id(req["doc_id"]) @@ -276,7 +276,8 @@ def change_status(): message="Database error (Document update)!") status = int(req["status"]) - docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id) + settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, + search.index_name(kb.tenant_id), doc.kb_id) return get_json_result(data=True) except Exception as e: return server_error_response(e) @@ -295,7 +296,7 @@ def rm(): return get_json_result( data=False, message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR + code=settings.RetCode.AUTHENTICATION_ERROR ) root_folder = FileService.get_root_folder(current_user.id) @@ -326,7 +327,7 @@ def rm(): errors += str(e) if errors: - return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR) + return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR) return get_json_result(data=True) @@ -341,7 +342,7 @@ def run(): return get_json_result( data=False, message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR + code=settings.RetCode.AUTHENTICATION_ERROR ) try: for id in req["doc_ids"]: @@ -358,8 +359,8 @@ def run(): e, doc = DocumentService.get_by_id(id) if not e: return get_data_error_result(message="Document not found!") - if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): - docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id) + if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): + settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id) if str(req["run"]) == TaskStatus.RUNNING.value: TaskService.filter_delete([Task.doc_id == id]) @@ -383,7 +384,7 @@ def rename(): return get_json_result( data=False, message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR + code=settings.RetCode.AUTHENTICATION_ERROR ) try: e, doc = 
DocumentService.get_by_id(req["doc_id"]) @@ -394,7 +395,7 @@ def rename(): return get_json_result( data=False, message="The extension of file can't be changed", - code=RetCode.ARGUMENT_ERROR) + code=settings.RetCode.ARGUMENT_ERROR) for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): if d.name == req["name"]: return get_data_error_result( @@ -450,7 +451,7 @@ def change_parser(): return get_json_result( data=False, message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR + code=settings.RetCode.AUTHENTICATION_ERROR ) try: e, doc = DocumentService.get_by_id(req["doc_id"]) @@ -483,8 +484,8 @@ def change_parser(): tenant_id = DocumentService.get_tenant_id(req["doc_id"]) if not tenant_id: return get_data_error_result(message="Tenant not found!") - if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): - docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) + if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): + settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) return get_json_result(data=True) except Exception as e: @@ -509,13 +510,13 @@ def get_image(image_id): def upload_and_parse(): if 'file' not in request.files: return get_json_result( - data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) + data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR) file_objs = request.files.getlist('file') for file_obj in file_objs: if file_obj.filename == '': return get_json_result( - data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) + data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR) doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id) @@ -529,7 +530,7 @@ def parse(): if url: if not is_valid_url(url): return get_json_result( - data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR) + data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR) download_path = os.path.join(get_project_base_directory(), "logs/downloads") os.makedirs(download_path, exist_ok=True) from selenium.webdriver import Chrome, ChromeOptions @@ -553,7 +554,7 @@ def parse(): if 'file' not in request.files: return get_json_result( - data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) + data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR) file_objs = request.files.getlist('file') txt = FileService.parse_docs(file_objs, current_user.id) diff --git a/api/apps/file2document_app.py b/api/apps/file2document_app.py index 97209b780..1eb797e89 100644 --- a/api/apps/file2document_app.py +++ b/api/apps/file2document_app.py @@ -24,7 +24,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va from api.utils import get_uuid from api.db import FileType from api.db.services.document_service import DocumentService -from api.settings import RetCode +from api import settings from api.utils.api_utils import get_json_result @@ -100,7 +100,7 @@ def rm(): file_ids = req["file_ids"] if not file_ids: return get_json_result( - data=False, message='Lack of "Files ID"', code=RetCode.ARGUMENT_ERROR) + data=False, message='Lack of "Files ID"', code=settings.RetCode.ARGUMENT_ERROR) try: for file_id in file_ids: informs = File2DocumentService.get_by_file_id(file_id) diff --git a/api/apps/file_app.py b/api/apps/file_app.py index 318a98fc7..bf112ab3d 100644 --- a/api/apps/file_app.py +++ b/api/apps/file_app.py 
@@ -28,7 +28,7 @@ from api.utils import get_uuid from api.db import FileType, FileSource from api.db.services import duplicate_name from api.db.services.file_service import FileService -from api.settings import RetCode +from api import settings from api.utils.api_utils import get_json_result from api.utils.file_utils import filename_type from rag.utils.storage_factory import STORAGE_IMPL @@ -46,13 +46,13 @@ def upload(): if 'file' not in request.files: return get_json_result( - data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) + data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR) file_objs = request.files.getlist('file') for file_obj in file_objs: if file_obj.filename == '': return get_json_result( - data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) + data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR) file_res = [] try: for file_obj in file_objs: @@ -134,7 +134,7 @@ def create(): try: if not FileService.is_parent_folder_exist(pf_id): return get_json_result( - data=False, message="Parent Folder Doesn't Exist!", code=RetCode.OPERATING_ERROR) + data=False, message="Parent Folder Doesn't Exist!", code=settings.RetCode.OPERATING_ERROR) if FileService.query(name=req["name"], parent_id=pf_id): return get_data_error_result( message="Duplicated folder name in the same folder.") @@ -299,7 +299,7 @@ def rename(): return get_json_result( data=False, message="The extension of file can't be changed", - code=RetCode.ARGUMENT_ERROR) + code=settings.RetCode.ARGUMENT_ERROR) for file in FileService.query(name=req["name"], pf_id=file.parent_id): if file.name == req["name"]: return get_data_error_result( diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 8bbd1c7ca..f98fe51a0 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -26,9 +26,8 @@ from api.utils import get_uuid from api.db import StatusEnum, FileSource from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.db_models import File -from api.settings import RetCode from api.utils.api_utils import get_json_result -from api.settings import docStoreConn +from api import settings from rag.nlp import search @@ -68,13 +67,13 @@ def update(): return get_json_result( data=False, message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR + code=settings.RetCode.AUTHENTICATION_ERROR ) try: if not KnowledgebaseService.query( created_by=current_user.id, id=req["kb_id"]): return get_json_result( - data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR) + data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR) e, kb = KnowledgebaseService.get_by_id(req["kb_id"]) if not e: @@ -113,7 +112,7 @@ def detail(): else: return get_json_result( data=False, message='Only owner of knowledgebase authorized for this operation.', - code=RetCode.OPERATING_ERROR) + code=settings.RetCode.OPERATING_ERROR) kb = KnowledgebaseService.get_detail(kb_id) if not kb: return get_data_error_result( @@ -148,14 +147,14 @@ def rm(): return get_json_result( data=False, message='No authorization.', - code=RetCode.AUTHENTICATION_ERROR + code=settings.RetCode.AUTHENTICATION_ERROR ) try: kbs = KnowledgebaseService.query( created_by=current_user.id, id=req["kb_id"]) if not kbs: return get_json_result( - data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR) + data=False, message='Only owner of 
knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR) for doc in DocumentService.query(kb_id=req["kb_id"]): if not DocumentService.remove_document(doc, kbs[0].tenant_id): @@ -170,7 +169,7 @@ def rm(): message="Database error (Knowledgebase removal)!") tenants = UserTenantService.query(user_id=current_user.id) for tenant in tenants: - docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"]) + settings.docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"]) return get_json_result(data=True) except Exception as e: return server_error_response(e) diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 487c70e2b..9b5901765 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -19,7 +19,7 @@ import json from flask import request from flask_login import login_required, current_user from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService -from api.settings import LIGHTEN +from api import settings from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.db import StatusEnum, LLMType from api.db.db_models import TenantLLM @@ -333,7 +333,7 @@ def my_llms(): @login_required def list_app(): self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"] - weighted = ["Youdao","FastEmbed", "BAAI"] if LIGHTEN != 0 else [] + weighted = ["Youdao","FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else [] model_type = request.args.get("model_type") try: objs = TenantLLMService.query(tenant_id=current_user.id) diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py index aad48cb59..0081077ce 100644 --- a/api/apps/sdk/chat.py +++ b/api/apps/sdk/chat.py @@ -14,7 +14,7 @@ # limitations under the License. 
# from flask import request -from api.settings import RetCode +from api import settings from api.db import StatusEnum from api.db.services.dialog_service import DialogService from api.db.services.knowledgebase_service import KnowledgebaseService @@ -44,7 +44,7 @@ def create(tenant_id): kbs = KnowledgebaseService.get_by_ids(ids) embd_count = list(set([kb.embd_id for kb in kbs])) if len(embd_count) != 1: - return get_result(message='Datasets use different embedding models."',code=RetCode.AUTHENTICATION_ERROR) + return get_result(message='Datasets use different embedding models."',code=settings.RetCode.AUTHENTICATION_ERROR) req["kb_ids"] = ids # llm llm = req.get("llm") @@ -173,7 +173,7 @@ def update(tenant_id,chat_id): if len(embd_count) != 1 : return get_result( message='Datasets use different embedding models."', - code=RetCode.AUTHENTICATION_ERROR) + code=settings.RetCode.AUTHENTICATION_ERROR) req["kb_ids"] = ids llm = req.get("llm") if llm: diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index 047ab8279..61cc3dd84 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -23,7 +23,7 @@ from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import TenantLLMService, LLMService from api.db.services.user_service import TenantService -from api.settings import RetCode +from api import settings from api.utils import get_uuid from api.utils.api_utils import ( get_result, @@ -255,7 +255,7 @@ def delete(tenant_id): File2DocumentService.delete_by_document_id(doc.id) if not KnowledgebaseService.delete_by_id(id): return get_error_data_result(message="Delete dataset error.(Database error)") - return get_result(code=RetCode.SUCCESS) + return get_result(code=settings.RetCode.SUCCESS) @manager.route("/datasets/", methods=["PUT"]) @@ -424,7 +424,7 @@ def update(tenant_id, dataset_id): ) if not KnowledgebaseService.update_by_id(kb.id, req): return get_error_data_result(message="Update dataset error.(Database error)") - return get_result(code=RetCode.SUCCESS) + return get_result(code=settings.RetCode.SUCCESS) @manager.route("/datasets", methods=["GET"]) diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index d2e6ee6f1..8e13f5c6d 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -18,7 +18,7 @@ from flask import request, jsonify from api.db import LLMType, ParserType from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle -from api.settings import retrievaler, kg_retrievaler, RetCode +from api import settings from api.utils.api_utils import validate_request, build_error_result, apikey_required @@ -37,14 +37,14 @@ def retrieval(tenant_id): e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: - return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND) + return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND) if kb.tenant_id != tenant_id: - return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND) + return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) - retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler + retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler ranks = retr.retrieval( 
question, embd_mdl, @@ -72,6 +72,6 @@ def retrieval(tenant_id): if str(e).find("not_found") > 0: return build_error_result( message='No chunk found! Check the chunk status please!', - code=RetCode.NOT_FOUND + code=settings.RetCode.NOT_FOUND ) - return build_error_result(message=str(e), code=RetCode.SERVER_ERROR) + return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR) diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index f9eb84271..faae6333c 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -21,7 +21,7 @@ from rag.app.qa import rmPrefix, beAdoc from rag.nlp import rag_tokenizer from api.db import LLMType, ParserType from api.db.services.llm_service import TenantLLMService -from api.settings import kg_retrievaler +from api import settings import hashlib import re from api.utils.api_utils import token_required @@ -37,11 +37,10 @@ from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.settings import RetCode, retrievaler +from api import settings from api.utils.api_utils import construct_json_result, get_parser_config from rag.nlp import search from rag.utils import rmSpace -from api.settings import docStoreConn from rag.utils.storage_factory import STORAGE_IMPL import os @@ -109,13 +108,13 @@ def upload(dataset_id, tenant_id): """ if "file" not in request.files: return get_error_data_result( - message="No file part!", code=RetCode.ARGUMENT_ERROR + message="No file part!", code=settings.RetCode.ARGUMENT_ERROR ) file_objs = request.files.getlist("file") for file_obj in file_objs: if file_obj.filename == "": return get_result( - message="No file selected!", code=RetCode.ARGUMENT_ERROR + message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR ) # total size total_size = 0 @@ -127,14 +126,14 @@ def upload(dataset_id, tenant_id): if total_size > MAX_TOTAL_FILE_SIZE: return get_result( message=f"Total file size exceeds 10MB limit! 
({total_size / (1024 * 1024):.2f} MB)", - code=RetCode.ARGUMENT_ERROR, + code=settings.RetCode.ARGUMENT_ERROR, ) e, kb = KnowledgebaseService.get_by_id(dataset_id) if not e: raise LookupError(f"Can't find the dataset with ID {dataset_id}!") err, files = FileService.upload_document(kb, file_objs, tenant_id) if err: - return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR) + return get_result(message="\n".join(err), code=settings.RetCode.SERVER_ERROR) # rename key's name renamed_doc_list = [] for file in files: @@ -221,12 +220,12 @@ def update_doc(tenant_id, dataset_id, document_id): if "name" in req and req["name"] != doc.name: if ( - pathlib.Path(req["name"].lower()).suffix - != pathlib.Path(doc.name.lower()).suffix + pathlib.Path(req["name"].lower()).suffix + != pathlib.Path(doc.name.lower()).suffix ): return get_result( message="The extension of file can't be changed", - code=RetCode.ARGUMENT_ERROR, + code=settings.RetCode.ARGUMENT_ERROR, ) for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): if d.name == req["name"]: @@ -292,7 +291,7 @@ def update_doc(tenant_id, dataset_id, document_id): ) if not e: return get_error_data_result(message="Document not found!") - docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) + settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) return get_result() @@ -349,7 +348,7 @@ def download(tenant_id, dataset_id, document_id): file_stream = STORAGE_IMPL.get(doc_id, doc_location) if not file_stream: return construct_json_result( - message="This file is empty.", code=RetCode.DATA_ERROR + message="This file is empty.", code=settings.RetCode.DATA_ERROR ) file = BytesIO(file_stream) # Use send_file with a proper filename and MIME type @@ -582,7 +581,7 @@ def delete(tenant_id, dataset_id): errors += str(e) if errors: - return get_result(message=errors, code=RetCode.SERVER_ERROR) + return get_result(message=errors, code=settings.RetCode.SERVER_ERROR) return get_result() @@ -644,7 +643,7 @@ def parse(tenant_id, dataset_id): info["chunk_num"] = 0 info["token_num"] = 0 DocumentService.update_by_id(id, info) - docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id) + settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id) TaskService.filter_delete([Task.doc_id == id]) e, doc = DocumentService.get_by_id(id) doc = doc.to_dict() @@ -708,7 +707,7 @@ def stop_parsing(tenant_id, dataset_id): ) info = {"run": "2", "progress": 0, "chunk_num": 0} DocumentService.update_by_id(id, info) - docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) + settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) return get_result() @@ -828,8 +827,9 @@ def list_chunks(tenant_id, dataset_id, document_id): res = {"total": 0, "chunks": [], "doc": renamed_doc} origin_chunks = [] - if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id): - sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True) + if settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id): + sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, + highlight=True) res["total"] = sres.total sign = 0 for id in sres.ids: @@ -1003,7 +1003,7 @@ def add_chunk(tenant_id, dataset_id, document_id): v, c = embd_mdl.encode([doc.name, req["content"]]) v = 0.1 * v[0] + 0.9 * v[1] d["q_%d_vec" % len(v)] = v.tolist() - 
docStoreConn.insert([d], search.index_name(tenant_id), dataset_id) + settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id) DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0) # rename keys @@ -1078,7 +1078,7 @@ def rm_chunk(tenant_id, dataset_id, document_id): condition = {"doc_id": document_id} if "chunk_ids" in req: condition["id"] = req["chunk_ids"] - chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id) + chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id) if chunk_number != 0: DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0) if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]): @@ -1143,7 +1143,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id): schema: type: object """ - chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id]) + chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id]) if chunk is None: return get_error_data_result(f"Can't find this chunk {chunk_id}") if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): @@ -1187,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id): v, c = embd_mdl.encode([doc.name, d["content_with_weight"]]) v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] d["q_%d_vec" % len(v)] = v.tolist() - docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id) + settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id) return get_result() @@ -1285,7 +1285,7 @@ def retrieval_test(tenant_id): if len(embd_nms) != 1: return get_result( message='Datasets use different embedding models."', - code=RetCode.AUTHENTICATION_ERROR, + code=settings.RetCode.AUTHENTICATION_ERROR, ) if "question" not in req: return get_error_data_result("`question` is required.") @@ -1326,7 +1326,7 @@ def retrieval_test(tenant_id): chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT) question += keyword_extraction(chat_mdl, question) - retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler + retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler ranks = retr.retrieval( question, embd_mdl, @@ -1366,6 +1366,6 @@ def retrieval_test(tenant_id): if str(e).find("not_found") > 0: return get_result( message="No chunk found! 
Check the chunk status please!", - code=RetCode.DATA_ERROR, + code=settings.RetCode.DATA_ERROR, ) - return server_error_response(e) \ No newline at end of file + return server_error_response(e) diff --git a/api/apps/system_app.py b/api/apps/system_app.py index 964aa0002..5e67ad43e 100644 --- a/api/apps/system_app.py +++ b/api/apps/system_app.py @@ -22,7 +22,7 @@ from api.db.db_models import APIToken from api.db.services.api_service import APITokenService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.user_service import UserTenantService -from api.settings import DATABASE_TYPE +from api import settings from api.utils import current_timestamp, datetime_format from api.utils.api_utils import ( get_json_result, @@ -31,7 +31,6 @@ from api.utils.api_utils import ( generate_confirmation_token, ) from api.versions import get_ragflow_version -from api.settings import docStoreConn from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE from timeit import default_timer as timer @@ -98,7 +97,7 @@ def status(): res = {} st = timer() try: - res["doc_store"] = docStoreConn.health() + res["doc_store"] = settings.docStoreConn.health() res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0) except Exception as e: res["doc_store"] = { @@ -128,13 +127,13 @@ def status(): try: KnowledgebaseService.get_by_id("x") res["database"] = { - "database": DATABASE_TYPE.lower(), + "database": settings.DATABASE_TYPE.lower(), "status": "green", "elapsed": "{:.1f}".format((timer() - st) * 1000.0), } except Exception as e: res["database"] = { - "database": DATABASE_TYPE.lower(), + "database": settings.DATABASE_TYPE.lower(), "status": "red", "elapsed": "{:.1f}".format((timer() - st) * 1000.0), "error": str(e), diff --git a/api/apps/user_app.py b/api/apps/user_app.py index 08ab1227d..e6b576860 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -38,20 +38,7 @@ from api.utils import ( datetime_format, ) from api.db import UserTenantRole, FileType -from api.settings import ( - RetCode, - GITHUB_OAUTH, - FEISHU_OAUTH, - CHAT_MDL, - EMBEDDING_MDL, - ASR_MDL, - IMAGE2TEXT_MDL, - PARSERS, - API_KEY, - LLM_FACTORY, - LLM_BASE_URL, - RERANK_MDL, -) +from api import settings from api.db.services.user_service import UserService, TenantService, UserTenantService from api.db.services.file_service import FileService from api.utils.api_utils import get_json_result, construct_response @@ -90,7 +77,7 @@ def login(): """ if not request.json: return get_json_result( - data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!" + data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!" 
)
 
     email = request.json.get("email", "")
@@ -98,7 +85,7 @@ def login():
     if not users:
         return get_json_result(
             data=False,
-            code=RetCode.AUTHENTICATION_ERROR,
+            code=settings.RetCode.AUTHENTICATION_ERROR,
             message=f"Email: {email} is not registered!",
         )
 
@@ -107,7 +94,7 @@ def login():
         password = decrypt(password)
     except BaseException:
         return get_json_result(
-            data=False, code=RetCode.SERVER_ERROR, message="Fail to crypt password"
+            data=False, code=settings.RetCode.SERVER_ERROR, message="Failed to decrypt password"
         )
 
     user = UserService.query_user(email, password)
@@ -123,7 +110,7 @@ def login():
     else:
         return get_json_result(
             data=False,
-            code=RetCode.AUTHENTICATION_ERROR,
+            code=settings.RetCode.AUTHENTICATION_ERROR,
             message="Email and password do not match!",
         )
 
@@ -150,10 +137,10 @@ def github_callback():
     import requests
 
     res = requests.post(
-        GITHUB_OAUTH.get("url"),
+        settings.GITHUB_OAUTH.get("url"),
         data={
-            "client_id": GITHUB_OAUTH.get("client_id"),
-            "client_secret": GITHUB_OAUTH.get("secret_key"),
+            "client_id": settings.GITHUB_OAUTH.get("client_id"),
+            "client_secret": settings.GITHUB_OAUTH.get("secret_key"),
             "code": request.args.get("code"),
         },
         headers={"Accept": "application/json"},
@@ -235,11 +222,11 @@ def feishu_callback():
     import requests
 
     app_access_token_res = requests.post(
-        FEISHU_OAUTH.get("app_access_token_url"),
+        settings.FEISHU_OAUTH.get("app_access_token_url"),
         data=json.dumps(
             {
-                "app_id": FEISHU_OAUTH.get("app_id"),
-                "app_secret": FEISHU_OAUTH.get("app_secret"),
+                "app_id": settings.FEISHU_OAUTH.get("app_id"),
+                "app_secret": settings.FEISHU_OAUTH.get("app_secret"),
             }
         ),
         headers={"Content-Type": "application/json; charset=utf-8"},
@@ -249,10 +236,10 @@ def feishu_callback():
         return redirect("/?error=%s" % app_access_token_res)
 
     res = requests.post(
-        FEISHU_OAUTH.get("user_access_token_url"),
+        settings.FEISHU_OAUTH.get("user_access_token_url"),
         data=json.dumps(
             {
-                "grant_type": FEISHU_OAUTH.get("grant_type"),
+                "grant_type": settings.FEISHU_OAUTH.get("grant_type"),
                 "code": request.args.get("code"),
             }
         ),
@@ -405,11 +392,11 @@ def setting_user():
     if request_data.get("password"):
         new_password = request_data.get("new_password")
         if not check_password_hash(
-            current_user.password, decrypt(request_data["password"])
+                current_user.password, decrypt(request_data["password"])
         ):
             return get_json_result(
                 data=False,
-                code=RetCode.AUTHENTICATION_ERROR,
+                code=settings.RetCode.AUTHENTICATION_ERROR,
                 message="Password error!",
             )
 
@@ -438,7 +425,7 @@ def setting_user():
     except Exception as e:
         logging.exception(e)
         return get_json_result(
-            data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR
+            data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR
         )
 
 
@@ -497,12 +484,12 @@ def user_register(user_id, user):
     tenant = {
         "id": user_id,
         "name": user["nickname"] + "'s Kingdom",
-        "llm_id": CHAT_MDL,
-        "embd_id": EMBEDDING_MDL,
-        "asr_id": ASR_MDL,
-        "parser_ids": PARSERS,
-        "img2txt_id": IMAGE2TEXT_MDL,
-        "rerank_id": RERANK_MDL,
+        "llm_id": settings.CHAT_MDL,
+        "embd_id": settings.EMBEDDING_MDL,
+        "asr_id": settings.ASR_MDL,
+        "parser_ids": settings.PARSERS,
+        "img2txt_id": settings.IMAGE2TEXT_MDL,
+        "rerank_id": settings.RERANK_MDL,
     }
     usr_tenant = {
         "tenant_id": user_id,
@@ -522,15 +509,15 @@ def user_register(user_id, user):
         "location": "",
     }
     tenant_llm = []
-    for llm in LLMService.query(fid=LLM_FACTORY):
+    for llm in LLMService.query(fid=settings.LLM_FACTORY):
         tenant_llm.append(
             {
                 "tenant_id": user_id,
-                "llm_factory": LLM_FACTORY,
+                
"llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type, - "api_key": API_KEY, - "api_base": LLM_BASE_URL, + "api_key": settings.API_KEY, + "api_base": settings.LLM_BASE_URL, } ) @@ -582,7 +569,7 @@ def user_add(): return get_json_result( data=False, message=f"Invalid email address: {email_address}!", - code=RetCode.OPERATING_ERROR, + code=settings.RetCode.OPERATING_ERROR, ) # Check if the email address is already used @@ -590,7 +577,7 @@ def user_add(): return get_json_result( data=False, message=f"Email: {email_address} has already registered!", - code=RetCode.OPERATING_ERROR, + code=settings.RetCode.OPERATING_ERROR, ) # Construct user info data @@ -625,7 +612,7 @@ def user_add(): return get_json_result( data=False, message=f"User registration failure, error: {str(e)}", - code=RetCode.EXCEPTION_ERROR, + code=settings.RetCode.EXCEPTION_ERROR, ) diff --git a/api/db/db_models.py b/api/db/db_models.py index 7c6325456..be3aafb1f 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -31,7 +31,7 @@ from peewee import ( ) from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase from api.db import SerializedType, ParserType -from api.settings import DATABASE, SECRET_KEY, DATABASE_TYPE +from api import settings from api import utils def singleton(cls, *args, **kw): @@ -62,7 +62,7 @@ class TextFieldType(Enum): class LongTextField(TextField): - field_type = TextFieldType[DATABASE_TYPE.upper()].value + field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value class JSONField(LongTextField): @@ -282,9 +282,9 @@ class DatabaseMigrator(Enum): @singleton class BaseDataBase: def __init__(self): - database_config = DATABASE.copy() + database_config = settings.DATABASE.copy() db_name = database_config.pop("name") - self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config) + self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config) logging.info('init database on cluster mode successfully') class PostgresDatabaseLock: @@ -385,7 +385,7 @@ class DatabaseLock(Enum): DB = BaseDataBase().database_connection -DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value +DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value def close_connection(): @@ -476,7 +476,7 @@ class User(DataBaseModel, UserMixin): return self.email def get_id(self): - jwt = Serializer(secret_key=SECRET_KEY) + jwt = Serializer(secret_key=settings.SECRET_KEY) return jwt.dumps(str(self.access_token)) class Meta: @@ -977,7 +977,7 @@ class CanvasTemplate(DataBaseModel): def migrate_db(): with DB.transaction(): - migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB) + migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB) try: migrate( migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", diff --git a/api/db/init_data.py b/api/db/init_data.py index c04169ad3..f1d468a2a 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -29,7 +29,7 @@ from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle from api.db.services.user_service import TenantService, UserTenantService -from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY, LLM_BASE_URL +from api import settings from api.utils.file_utils import 
get_project_base_directory
@@ -51,11 +51,11 @@ def init_superuser():
     tenant = {
         "id": user_info["id"],
         "name": user_info["nickname"] + "'s Kingdom",
-        "llm_id": CHAT_MDL,
-        "embd_id": EMBEDDING_MDL,
-        "asr_id": ASR_MDL,
-        "parser_ids": PARSERS,
-        "img2txt_id": IMAGE2TEXT_MDL
+        "llm_id": settings.CHAT_MDL,
+        "embd_id": settings.EMBEDDING_MDL,
+        "asr_id": settings.ASR_MDL,
+        "parser_ids": settings.PARSERS,
+        "img2txt_id": settings.IMAGE2TEXT_MDL
     }
     usr_tenant = {
         "tenant_id": user_info["id"],
@@ -64,10 +64,11 @@ def init_superuser():
         "role": UserTenantRole.OWNER
     }
     tenant_llm = []
-    for llm in LLMService.query(fid=LLM_FACTORY):
+    for llm in LLMService.query(fid=settings.LLM_FACTORY):
         tenant_llm.append(
-            {"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
-             "api_key": API_KEY, "api_base": LLM_BASE_URL})
+            {"tenant_id": user_info["id"], "llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name,
+             "model_type": llm.model_type,
+             "api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL})
 
     if not UserService.save(**user_info):
         logging.error("can't init admin.")
@@ -80,7 +81,7 @@ def init_superuser():
     chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
     msg = chat_mdl.chat(system="", history=[
-        {"role": "user", "content": "Hello!"}], gen_conf={})
+                        {"role": "user", "content": "Hello!"}], gen_conf={})
     if msg.find("ERROR: ") == 0:
         logging.error(
             "'{}' doesn't work. {}".format(
@@ -179,7 +180,7 @@ def init_web_data():
     start_time = time.time()
 
     init_llm_factory()
-    #if not UserService.get_all().count():
+    # if not UserService.get_all().count():
     #     init_superuser()
 
     add_graph_templates()
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index c586309b9..31f948673 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -27,7 +27,7 @@ from api.db.db_models import Dialog, Conversation,DB
 from api.db.services.common_service import CommonService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
-from api.settings import retrievaler, kg_retrievaler
+from api import settings
 from rag.app.resume import forbidden_select_fields4resume
 from rag.nlp.search import index_name
 from rag.utils import rmSpace, num_tokens_from_string, encoder
@@ -152,7 +152,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
 
     is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
-    retr = retrievaler if not is_kg else kg_retrievaler
+    retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
 
     questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
     attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
@@ -342,7 +342,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
         logging.debug(f"{question} get SQL(refined): {sql}")
         tried_times += 1
-        return retrievaler.sql_retrieval(sql, format="json"), sql
+        return settings.retrievaler.sql_retrieval(sql, format="json"), sql
 
     tbl, sql = get_table()
     if tbl is None:
@@ -596,7 +596,7 @@ def ask(question, kb_ids, tenant_id):
     embd_nms = list(set([kb.embd_id for kb in kbs]))
 
     is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
-    retr = retrievaler if not is_kg else kg_retrievaler
+    retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
 
     embd_mdl = LLMBundle(tenant_id, 
LLMType.EMBEDDING, embd_nms[0]) chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index a2681ba29..b5c758d57 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -26,7 +26,7 @@ from io import BytesIO from peewee import fn from api.db.db_utils import bulk_insert_into_db -from api.settings import docStoreConn +from api import settings from api.utils import current_timestamp, get_format_time, get_uuid from graphrag.mind_map_extractor import MindMapExtractor from rag.settings import SVR_QUEUE_NAME @@ -108,7 +108,7 @@ class DocumentService(CommonService): @classmethod @DB.connection_context() def remove_document(cls, doc, tenant_id): - docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) + settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) cls.clear_chunk_num(doc.id) return cls.delete_by_id(doc.id) @@ -553,10 +553,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): d["q_%d_vec" % len(v)] = v for b in range(0, len(cks), es_bulk_size): if try_create_idx: - if not docStoreConn.indexExist(idxnm, kb_id): - docStoreConn.createIdx(idxnm, kb_id, len(vects[0])) + if not settings.docStoreConn.indexExist(idxnm, kb_id): + settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0])) try_create_idx = False - docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) + settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) DocumentService.increment_chunk_num( doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) diff --git a/api/ragflow_server.py b/api/ragflow_server.py index f832bdf53..09dd9fd33 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -33,12 +33,10 @@ import traceback from concurrent.futures import ThreadPoolExecutor from werkzeug.serving import run_simple +from api import settings from api.apps import app from api.db.runtime_config import RuntimeConfig from api.db.services.document_service import DocumentService -from api.settings import ( - HOST, HTTP_PORT -) from api import utils from api.db.db_models import init_database_tables as init_web_db @@ -72,6 +70,7 @@ if __name__ == '__main__': f'project base: {utils.file_utils.get_project_base_directory()}' ) show_configs() + settings.init_settings() # init db init_web_db() @@ -96,7 +95,7 @@ if __name__ == '__main__': logging.info("run on debug mode") RuntimeConfig.init_env() - RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT) + RuntimeConfig.init_config(JOB_SERVER_HOST=settings.HOST_IP, HTTP_PORT=settings.HOST_PORT) thread = ThreadPoolExecutor(max_workers=1) thread.submit(update_progress) @@ -105,8 +104,8 @@ if __name__ == '__main__': try: logging.info("RAGFlow HTTP server start...") run_simple( - hostname=HOST, - port=HTTP_PORT, + hostname=settings.HOST_IP, + port=settings.HOST_PORT, application=app, threaded=True, use_reloader=RuntimeConfig.DEBUG, diff --git a/api/settings.py b/api/settings.py index e08970468..0abf548e8 100644 --- a/api/settings.py +++ b/api/settings.py @@ -30,114 +30,157 @@ LIGHTEN = int(os.environ.get('LIGHTEN', "0")) REQUEST_WAIT_SEC = 2 REQUEST_MAX_WAIT_SEC = 300 - -LLM = get_base_config("user_default_llm", {}) -LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") -LLM_BASE_URL = LLM.get("base_url") - -CHAT_MDL = EMBEDDING_MDL = RERANK_MDL = ASR_MDL = IMAGE2TEXT_MDL = "" -if not LIGHTEN: - default_llm = { - "Tongyi-Qianwen": { - "chat_model": "qwen-plus", - 
"embedding_model": "text-embedding-v2", - "image2text_model": "qwen-vl-max", - "asr_model": "paraformer-realtime-8k-v1", - }, - "OpenAI": { - "chat_model": "gpt-3.5-turbo", - "embedding_model": "text-embedding-ada-002", - "image2text_model": "gpt-4-vision-preview", - "asr_model": "whisper-1", - }, - "Azure-OpenAI": { - "chat_model": "gpt-35-turbo", - "embedding_model": "text-embedding-ada-002", - "image2text_model": "gpt-4-vision-preview", - "asr_model": "whisper-1", - }, - "ZHIPU-AI": { - "chat_model": "glm-3-turbo", - "embedding_model": "embedding-2", - "image2text_model": "glm-4v", - "asr_model": "", - }, - "Ollama": { - "chat_model": "qwen-14B-chat", - "embedding_model": "flag-embedding", - "image2text_model": "", - "asr_model": "", - }, - "Moonshot": { - "chat_model": "moonshot-v1-8k", - "embedding_model": "", - "image2text_model": "", - "asr_model": "", - }, - "DeepSeek": { - "chat_model": "deepseek-chat", - "embedding_model": "", - "image2text_model": "", - "asr_model": "", - }, - "VolcEngine": { - "chat_model": "", - "embedding_model": "", - "image2text_model": "", - "asr_model": "", - }, - "BAAI": { - "chat_model": "", - "embedding_model": "BAAI/bge-large-zh-v1.5", - "image2text_model": "", - "asr_model": "", - "rerank_model": "BAAI/bge-reranker-v2-m3", - } - } - - if LLM_FACTORY: - CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}" - ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}" - IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}" - EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI" - RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI" - -API_KEY = LLM.get("api_key", "") -PARSERS = LLM.get( - "parsers", - "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email") - -HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") -HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") - -SECRET_KEY = get_base_config( - RAG_FLOW_SERVICE_NAME, - {}).get("secret_key", str(date.today())) +LLM = None +LLM_FACTORY = None +LLM_BASE_URL = None +CHAT_MDL = "" +EMBEDDING_MDL = "" +RERANK_MDL = "" +ASR_MDL = "" +IMAGE2TEXT_MDL = "" +API_KEY = None +PARSERS = None +HOST_IP = None +HOST_PORT = None +SECRET_KEY = None DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') DATABASE = decrypt_database_config(name=DATABASE_TYPE) # authentication -AUTHENTICATION_CONF = get_base_config("authentication", {}) +AUTHENTICATION_CONF = None # client -CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get( - "client", {}).get( - "switch", False) -HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key") -GITHUB_OAUTH = get_base_config("oauth", {}).get("github") -FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu") +CLIENT_AUTHENTICATION = None +HTTP_APP_KEY = None +GITHUB_OAUTH = None +FEISHU_OAUTH = None -DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch") -if DOC_ENGINE == "elasticsearch": - docStoreConn = rag.utils.es_conn.ESConnection() -elif DOC_ENGINE == "infinity": - docStoreConn = rag.utils.infinity_conn.InfinityConnection() -else: - raise Exception(f"Not supported doc engine: {DOC_ENGINE}") +DOC_ENGINE = None +docStoreConn = None -retrievaler = search.Dealer(docStoreConn) -kg_retrievaler = kg_search.KGSearch(docStoreConn) +retrievaler = None +kg_retrievaler = None + + +def init_settings(): + global LLM, LLM_FACTORY, 
LLM_BASE_URL + LLM = get_base_config("user_default_llm", {}) + LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") + LLM_BASE_URL = LLM.get("base_url") + + global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL + if not LIGHTEN: + default_llm = { + "Tongyi-Qianwen": { + "chat_model": "qwen-plus", + "embedding_model": "text-embedding-v2", + "image2text_model": "qwen-vl-max", + "asr_model": "paraformer-realtime-8k-v1", + }, + "OpenAI": { + "chat_model": "gpt-3.5-turbo", + "embedding_model": "text-embedding-ada-002", + "image2text_model": "gpt-4-vision-preview", + "asr_model": "whisper-1", + }, + "Azure-OpenAI": { + "chat_model": "gpt-35-turbo", + "embedding_model": "text-embedding-ada-002", + "image2text_model": "gpt-4-vision-preview", + "asr_model": "whisper-1", + }, + "ZHIPU-AI": { + "chat_model": "glm-3-turbo", + "embedding_model": "embedding-2", + "image2text_model": "glm-4v", + "asr_model": "", + }, + "Ollama": { + "chat_model": "qwen-14B-chat", + "embedding_model": "flag-embedding", + "image2text_model": "", + "asr_model": "", + }, + "Moonshot": { + "chat_model": "moonshot-v1-8k", + "embedding_model": "", + "image2text_model": "", + "asr_model": "", + }, + "DeepSeek": { + "chat_model": "deepseek-chat", + "embedding_model": "", + "image2text_model": "", + "asr_model": "", + }, + "VolcEngine": { + "chat_model": "", + "embedding_model": "", + "image2text_model": "", + "asr_model": "", + }, + "BAAI": { + "chat_model": "", + "embedding_model": "BAAI/bge-large-zh-v1.5", + "image2text_model": "", + "asr_model": "", + "rerank_model": "BAAI/bge-reranker-v2-m3", + } + } + + if LLM_FACTORY: + CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}" + ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}" + IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}" + EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI" + RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI" + + global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY + API_KEY = LLM.get("api_key", "") + PARSERS = LLM.get( + "parsers", + "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email") + + HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") + HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") + + SECRET_KEY = get_base_config( + RAG_FLOW_SERVICE_NAME, + {}).get("secret_key", str(date.today())) + + global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH + # authentication + AUTHENTICATION_CONF = get_base_config("authentication", {}) + + # client + CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get( + "client", {}).get( + "switch", False) + HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key") + GITHUB_OAUTH = get_base_config("oauth", {}).get("github") + FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu") + + global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler + DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch") + if DOC_ENGINE == "elasticsearch": + docStoreConn = rag.utils.es_conn.ESConnection() + elif DOC_ENGINE == "infinity": + docStoreConn = rag.utils.infinity_conn.InfinityConnection() + else: + raise Exception(f"Not supported doc engine: {DOC_ENGINE}") + + retrievaler = search.Dealer(docStoreConn) + kg_retrievaler = kg_search.KGSearch(docStoreConn) + +def get_host_ip(): + global 
HOST_IP
+    return HOST_IP
+
+
+def get_host_port():
+    global HOST_PORT
+    return HOST_PORT
 
 
 class CustomEnum(Enum):
diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py
index 37f4a3c92..7d1fc7725 100644
--- a/api/utils/api_utils.py
+++ b/api/utils/api_utils.py
@@ -34,11 +34,7 @@ from itsdangerous import URLSafeTimedSerializer
 from werkzeug.http import HTTP_STATUS_CODES
 
 from api.db.db_models import APIToken
-from api.settings import (
-    REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
-    CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
-)
-from api.settings import RetCode
+from api import settings
 from api.utils import CustomJSONEncoder, get_uuid
 from api.utils import json_dumps
 
@@ -59,13 +57,13 @@ def request(**kwargs):
             {}).items()}
     prepped = requests.Request(**kwargs).prepare()
 
-    if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
+    if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
         timestamp = str(round(time() * 1000))
         nonce = str(uuid1())
-        signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([
+        signature = b64encode(HMAC(settings.SECRET_KEY.encode('ascii'), b'\n'.join([
             timestamp.encode('ascii'),
             nonce.encode('ascii'),
-            HTTP_APP_KEY.encode('ascii'),
+            settings.HTTP_APP_KEY.encode('ascii'),
             prepped.path_url.encode('ascii'),
             prepped.body if kwargs.get('json') else b'',
             urlencode(
@@ -79,7 +77,7 @@ def request(**kwargs):
     prepped.headers.update({
         'TIMESTAMP': timestamp,
         'NONCE': nonce,
-        'APP-KEY': HTTP_APP_KEY,
+        'APP-KEY': settings.HTTP_APP_KEY,
         'SIGNATURE': signature,
     })
 
@@ -89,7 +87,7 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
     """Calculate the exponential backoff wait time."""
     # Will be zero if factor equals 0
-    countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
+    countdown = min(settings.REQUEST_MAX_WAIT_SEC, settings.REQUEST_WAIT_SEC * (2 ** retries))
     # Full jitter according to
     # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
     if full_jitter:
@@ -98,7 +96,7 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
     return max(0, countdown)
 
 
-def get_data_error_result(code=RetCode.DATA_ERROR,
+def get_data_error_result(code=settings.RetCode.DATA_ERROR,
                           message='Sorry! 
Data missing!'): import re result_dict = { @@ -126,8 +124,8 @@ def server_error_response(e): pass if len(e.args) > 1: return get_json_result( - code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) - return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) + code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) + return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e)) def error_response(response_code, message=None): @@ -168,7 +166,7 @@ def validate_request(*args, **kwargs): error_string += "required argument values: {}".format( ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) return get_json_result( - code=RetCode.ARGUMENT_ERROR, message=error_string) + code=settings.RetCode.ARGUMENT_ERROR, message=error_string) return func(*_args, **_kwargs) return decorated_function @@ -193,7 +191,7 @@ def send_file_in_mem(data, filename): return send_file(f, as_attachment=True, attachment_filename=filename) -def get_json_result(code=RetCode.SUCCESS, message='success', data=None): +def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None): response = {"code": code, "message": message, "data": data} return jsonify(response) @@ -204,7 +202,7 @@ def apikey_required(func): objs = APIToken.query(token=token) if not objs: return build_error_result( - message='API-KEY is invalid!', code=RetCode.FORBIDDEN + message='API-KEY is invalid!', code=settings.RetCode.FORBIDDEN ) kwargs['tenant_id'] = objs[0].tenant_id return func(*args, **kwargs) @@ -212,14 +210,14 @@ def apikey_required(func): return decorated_function -def build_error_result(code=RetCode.FORBIDDEN, message='success'): +def build_error_result(code=settings.RetCode.FORBIDDEN, message='success'): response = {"code": code, "message": message} response = jsonify(response) response.status_code = code return response -def construct_response(code=RetCode.SUCCESS, +def construct_response(code=settings.RetCode.SUCCESS, message='success', data=None, auth=None): result_dict = {"code": code, "message": message, "data": data} response_dict = {} @@ -239,7 +237,7 @@ def construct_response(code=RetCode.SUCCESS, return response -def construct_result(code=RetCode.DATA_ERROR, message='data is missing'): +def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'): import re result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)} response = {} @@ -251,7 +249,7 @@ def construct_result(code=RetCode.DATA_ERROR, message='data is missing'): return jsonify(response) -def construct_json_result(code=RetCode.SUCCESS, message='success', data=None): +def construct_json_result(code=settings.RetCode.SUCCESS, message='success', data=None): if data is None: return jsonify({"code": code, "message": message}) else: @@ -262,12 +260,12 @@ def construct_error_response(e): logging.exception(e) try: if e.code == 401: - return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e)) + return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e)) except BaseException: pass if len(e.args) > 1: - return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) - return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) + return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) + return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e)) def token_required(func): 
@@ -280,7 +278,7 @@ def token_required(func): objs = APIToken.query(token=token) if not objs: return get_json_result( - data=False, message='Token is not valid!', code=RetCode.AUTHENTICATION_ERROR + data=False, message='Token is not valid!', code=settings.RetCode.AUTHENTICATION_ERROR ) kwargs['tenant_id'] = objs[0].tenant_id return func(*args, **kwargs) @@ -288,7 +286,7 @@ def token_required(func): return decorated_function -def get_result(code=RetCode.SUCCESS, message="", data=None): +def get_result(code=settings.RetCode.SUCCESS, message="", data=None): if code == 0: if data is not None: response = {"code": code, "data": data} @@ -299,7 +297,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None): return jsonify(response) -def get_error_data_result(message='Sorry! Data missing!', code=RetCode.DATA_ERROR, +def get_error_data_result(message='Sorry! Data missing!', code=settings.RetCode.DATA_ERROR, ): import re result_dict = { diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index 2f2bde76b..2f70ba7f1 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -24,7 +24,7 @@ import numpy as np from timeit import default_timer as timer from pypdf import PdfReader as pdf2_read -from api.settings import LIGHTEN +from api import settings from api.utils.file_utils import get_project_base_directory from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer from rag.nlp import rag_tokenizer @@ -41,7 +41,7 @@ class RAGFlowPdfParser: self.tbl_det = TableStructureRecognizer() self.updown_cnt_mdl = xgb.Booster() - if not LIGHTEN: + if not settings.LIGHTEN: try: import torch if torch.cuda.is_available(): diff --git a/graphrag/claim_extractor.py b/graphrag/claim_extractor.py index d3495a791..b202e50a2 100644 --- a/graphrag/claim_extractor.py +++ b/graphrag/claim_extractor.py @@ -252,13 +252,13 @@ if __name__ == "__main__": from api.db import LLMType from api.db.services.llm_service import LLMBundle - from api.settings import retrievaler + from api import settings from api.db.services.knowledgebase_service import KnowledgebaseService kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id) ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) - docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])] + docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])] info = { "input_text": docs, "entity_specs": "organization, person", diff --git a/graphrag/smoke.py b/graphrag/smoke.py index 3d0ae370a..3a0115333 100644 --- a/graphrag/smoke.py +++ b/graphrag/smoke.py @@ -30,14 +30,14 @@ if __name__ == "__main__": from api.db import LLMType from api.db.services.llm_service import LLMBundle - from api.settings import retrievaler + from api import settings from api.db.services.knowledgebase_service import KnowledgebaseService kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id) ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) docs = [d["content_with_weight"] for d in - retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])] + settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])] graph = ex(docs) er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT)) diff --git a/rag/benchmark.py b/rag/benchmark.py index 
b7d536efc..2688d3630 100644 --- a/rag/benchmark.py +++ b/rag/benchmark.py @@ -23,7 +23,7 @@ from collections import defaultdict from api.db import LLMType from api.db.services.llm_service import LLMBundle from api.db.services.knowledgebase_service import KnowledgebaseService -from api.settings import retrievaler, docStoreConn +from api import settings from api.utils import get_uuid from rag.nlp import tokenize, search from ranx import evaluate @@ -52,7 +52,7 @@ class Benchmark: run = defaultdict(dict) query_list = list(qrels.keys()) for query in query_list: - ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30, + ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30, 0.0, self.vector_similarity_weight) if len(ranks["chunks"]) == 0: print(f"deleted query: {query}") @@ -81,9 +81,9 @@ class Benchmark: def init_index(self, vector_size: int): if self.initialized_index: return - if docStoreConn.indexExist(self.index_name, self.kb_id): - docStoreConn.deleteIdx(self.index_name, self.kb_id) - docStoreConn.createIdx(self.index_name, self.kb_id, vector_size) + if settings.docStoreConn.indexExist(self.index_name, self.kb_id): + settings.docStoreConn.deleteIdx(self.index_name, self.kb_id) + settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size) self.initialized_index = True def ms_marco_index(self, file_path, index_name): @@ -118,13 +118,13 @@ class Benchmark: docs_count += len(docs) docs, vector_size = self.embedding(docs) self.init_index(vector_size) - docStoreConn.insert(docs, self.index_name, self.kb_id) + settings.docStoreConn.insert(docs, self.index_name, self.kb_id) docs = [] if docs: docs, vector_size = self.embedding(docs) self.init_index(vector_size) - docStoreConn.insert(docs, self.index_name, self.kb_id) + settings.docStoreConn.insert(docs, self.index_name, self.kb_id) return qrels, texts def trivia_qa_index(self, file_path, index_name): @@ -159,12 +159,12 @@ class Benchmark: docs_count += len(docs) docs, vector_size = self.embedding(docs) self.init_index(vector_size) - docStoreConn.insert(docs,self.index_name) + settings.docStoreConn.insert(docs,self.index_name) docs = [] docs, vector_size = self.embedding(docs) self.init_index(vector_size) - docStoreConn.insert(docs, self.index_name) + settings.docStoreConn.insert(docs, self.index_name) return qrels, texts def miracl_index(self, file_path, corpus_path, index_name): @@ -214,12 +214,12 @@ class Benchmark: docs_count += len(docs) docs, vector_size = self.embedding(docs) self.init_index(vector_size) - docStoreConn.insert(docs, self.index_name) + settings.docStoreConn.insert(docs, self.index_name) docs = [] docs, vector_size = self.embedding(docs) self.init_index(vector_size) - docStoreConn.insert(docs, self.index_name) + settings.docStoreConn.insert(docs, self.index_name) return qrels, texts def save_results(self, qrels, run, texts, dataset, file_path): diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index b167b75e9..de65ed691 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -28,7 +28,7 @@ from openai import OpenAI import numpy as np import asyncio -from api.settings import LIGHTEN +from api import settings from api.utils.file_utils import get_home_cache_dir from rag.utils import num_tokens_from_string, truncate import google.generativeai as genai @@ -60,7 +60,7 @@ class DefaultEmbedding(Base): ^_- """ - if not LIGHTEN and not DefaultEmbedding._model: + if not settings.LIGHTEN and not 
DefaultEmbedding._model: with DefaultEmbedding._model_lock: from FlagEmbedding import FlagModel import torch @@ -248,7 +248,7 @@ class FastEmbed(Base): threads: Optional[int] = None, **kwargs, ): - if not LIGHTEN and not FastEmbed._model: + if not settings.LIGHTEN and not FastEmbed._model: from fastembed import TextEmbedding self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) @@ -294,7 +294,7 @@ class YoudaoEmbed(Base): _client = None def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs): - if not LIGHTEN and not YoudaoEmbed._client: + if not settings.LIGHTEN and not YoudaoEmbed._client: from BCEmbedding import EmbeddingModel as qanthing try: logging.info("LOADING BCE...") diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index 45bd10aeb..cac7b342d 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -23,7 +23,7 @@ import os from abc import ABC import numpy as np -from api.settings import LIGHTEN +from api import settings from api.utils.file_utils import get_home_cache_dir from rag.utils import num_tokens_from_string, truncate import json @@ -57,7 +57,7 @@ class DefaultRerank(Base): ^_- """ - if not LIGHTEN and not DefaultRerank._model: + if not settings.LIGHTEN and not DefaultRerank._model: import torch from FlagEmbedding import FlagReranker with DefaultRerank._model_lock: @@ -121,7 +121,7 @@ class YoudaoRerank(DefaultRerank): _model_lock = threading.Lock() def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs): - if not LIGHTEN and not YoudaoRerank._model: + if not settings.LIGHTEN and not YoudaoRerank._model: from BCEmbedding import RerankerModel with YoudaoRerank._model_lock: if not YoudaoRerank._model: diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index fefd7958c..76568d927 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -16,6 +16,7 @@ import logging import sys from api.utils.log_utils import initRootLogger + CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1] initRootLogger(f"task_executor_{CONSUMER_NO}") for module in ["pdfminer"]: @@ -49,9 +50,10 @@ from api.db.services.document_service import DocumentService from api.db.services.llm_service import LLMBundle from api.db.services.task_service import TaskService from api.db.services.file2document_service import File2DocumentService -from api.settings import retrievaler, docStoreConn +from api import settings from api.db.db_models import close_connection -from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email +from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \ + knowledge_graph, email from rag.nlp import search, rag_tokenizer from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME @@ -88,6 +90,7 @@ PENDING_TASKS = 0 HEAD_CREATED_AT = "" HEAD_DETAIL = "" + def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."): global PAYLOAD if prog is not None and prog < 0: @@ -171,7 +174,8 @@ def build(row): "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"])) except TimeoutError: callback(-1, "Internal server error: Fetch file from minio timeout. 
Could you try it again.") - logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"])) + logging.exception( + "Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"])) return except Exception as e: if re.search("(No such file|not found)", str(e)): @@ -188,7 +192,7 @@ def build(row): logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"])) except Exception as e: callback(-1, "Internal server error while chunking: %s" % - str(e).replace("'", "")) + str(e).replace("'", "")) logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"])) return @@ -226,7 +230,8 @@ def build(row): STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue()) el += timer() - st except Exception: - logging.exception("Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"])) + logging.exception( + "Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"])) d["img_id"] = "{}-{}".format(row["kb_id"], d["id"]) del d["image"] @@ -241,7 +246,7 @@ def build(row): d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"], row["parser_config"]["auto_keywords"]).split(",") d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) - callback(msg="Keywords generation completed in {:.2f}s".format(timer()-st)) + callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st)) if row["parser_config"].get("auto_questions", 0): st = timer() @@ -255,14 +260,14 @@ def build(row): d["content_ltks"] += " " + qst if "content_sm_ltks" in d: d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst) - callback(msg="Question generation completed in {:.2f}s".format(timer()-st)) + callback(msg="Question generation completed in {:.2f}s".format(timer() - st)) return docs def init_kb(row, vector_size: int): idxnm = search.index_name(row["tenant_id"]) - return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size) + return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size) def embedding(docs, mdl, parser_config=None, callback=None): @@ -313,7 +318,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None): vector_size = len(vts[0]) vctr_nm = "q_%d_vec" % vector_size chunks = [] - for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]): + for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], + fields=["content_with_weight", vctr_nm]): chunks.append((d["content_with_weight"], np.array(d[vctr_nm]))) raptor = Raptor( @@ -384,7 +390,8 @@ def main(): # TODO: exception handler ## set_progress(r["did"], -1, "ERROR: ") callback( - msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), timer() - st) + msg="Finished slicing files ({} chunks in {:.2f}s). 
Start to embedding the content.".format(len(cks), + timer() - st) ) st = timer() try: @@ -403,18 +410,18 @@ def main(): es_r = "" es_bulk_size = 4 for b in range(0, len(cks), es_bulk_size): - es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"]) + es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"]) if b % 128 == 0: callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="") logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st)) if es_r: callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!") - docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) + settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) logging.error('Insert chunk error: ' + str(es_r)) else: if TaskService.do_cancel(r["id"]): - docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) + settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) continue callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st)) callback(1., "Done!") @@ -435,7 +442,7 @@ def report_status(): if PENDING_TASKS > 0: head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME) if head_info is not None: - seconds = int(head_info[0].split("-")[0])/1000 + seconds = int(head_info[0].split("-")[0]) / 1000 HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat() HEAD_DETAIL = head_info[1] @@ -452,7 +459,7 @@ def report_status(): REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp()) logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}") - expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60*30) + expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30) if expired > 0: REDIS_CONN.zpopmin(CONSUMER_NAME, expired) except Exception: