Move settings initialization after module init phase (#3438)

### What problem does this PR solve?

1. Module init won't connect database any more.
2. Config in settings need to be used with settings.CONFIG_NAME

### Type of change

- [x] Refactoring

Signed-off-by: jinhai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai 2024-11-15 17:30:56 +08:00 committed by GitHub
parent ac033b62cf
commit 1e90a1bf36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 452 additions and 411 deletions

View File

@ -19,7 +19,7 @@ import pandas as pd
from api.db import LLMType from api.db import LLMType
from api.db.services.dialog_service import message_fit_in from api.db.services.dialog_service import message_fit_in
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler from api import settings
from agent.component.base import ComponentBase, ComponentParamBase from agent.component.base import ComponentBase, ComponentParamBase
@ -63,14 +63,16 @@ class Generate(ComponentBase):
component_name = "Generate" component_name = "Generate"
def get_dependent_components(self): def get_dependent_components(self):
cpnts = [para["component_id"] for para in self._param.parameters if para.get("component_id") and para["component_id"].lower().find("answer") < 0] cpnts = [para["component_id"] for para in self._param.parameters if
para.get("component_id") and para["component_id"].lower().find("answer") < 0]
return cpnts return cpnts
def set_cite(self, retrieval_res, answer): def set_cite(self, retrieval_res, answer):
retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True) retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
if "empty_response" in retrieval_res.columns: if "empty_response" in retrieval_res.columns:
retrieval_res["empty_response"].fillna("", inplace=True) retrieval_res["empty_response"].fillna("", inplace=True)
answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()], answer, idx = settings.retrievaler.insert_citations(answer,
[ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
[ck["vector"] for _, ck in retrieval_res.iterrows()], [ck["vector"] for _, ck in retrieval_res.iterrows()],
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
self._canvas.get_embedding_model()), tkweight=0.7, self._canvas.get_embedding_model()), tkweight=0.7,
@ -127,12 +129,14 @@ class Generate(ComponentBase):
else: else:
if cpn.component_name.lower() == "retrieval": if cpn.component_name.lower() == "retrieval":
retrieval_res.append(out) retrieval_res.append(out)
kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]]) kwargs[para["key"]] = " - " + "\n - ".join(
[o if isinstance(o, str) else str(o) for o in out["content"]])
self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]}) self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
if retrieval_res: if retrieval_res:
retrieval_res = pd.concat(retrieval_res, ignore_index=True) retrieval_res = pd.concat(retrieval_res, ignore_index=True)
else: retrieval_res = pd.DataFrame([]) else:
retrieval_res = pd.DataFrame([])
for n, v in kwargs.items(): for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt) prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)

View File

@ -21,7 +21,7 @@ import pandas as pd
from api.db import LLMType from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler from api import settings
from agent.component.base import ComponentBase, ComponentParamBase from agent.component.base import ComponentBase, ComponentParamBase
@ -67,7 +67,7 @@ class Retrieval(ComponentBase, ABC):
if self._param.rerank_id: if self._param.rerank_id:
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id) rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
kbinfos = retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids, kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
1, self._param.top_n, 1, self._param.top_n,
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight, self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
aggs=False, rerank_mdl=rerank_mdl) aggs=False, rerank_mdl=rerank_mdl)

View File

@ -30,8 +30,7 @@ from api.utils import CustomJSONEncoder, commands
from flask_session import Session from flask_session import Session
from flask_login import LoginManager from flask_login import LoginManager
from api.settings import SECRET_KEY from api import settings
from api.settings import API_VERSION
from api.utils.api_utils import server_error_response from api.utils.api_utils import server_error_response
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
@ -78,7 +77,6 @@ app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder app.json_encoder = CustomJSONEncoder
app.errorhandler(Exception)(server_error_response) app.errorhandler(Exception)(server_error_response)
## convince for dev and debug ## convince for dev and debug
# app.config["LOGIN_DISABLED"] = True # app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False app.config["SESSION_PERMANENT"] = False
@ -121,7 +119,7 @@ def register_page(page_path):
spec.loader.exec_module(page) spec.loader.exec_module(page)
page_name = getattr(page, "page_name", page_name) page_name = getattr(page, "page_name", page_name)
url_prefix = ( url_prefix = (
f"/api/{API_VERSION}" if "/sdk/" in path else f"/{API_VERSION}/{page_name}" f"/api/{settings.API_VERSION}" if "/sdk/" in path else f"/{settings.API_VERSION}/{page_name}"
) )
app.register_blueprint(page.manager, url_prefix=url_prefix) app.register_blueprint(page.manager, url_prefix=url_prefix)
@ -141,7 +139,7 @@ client_urls_prefix = [
@login_manager.request_loader @login_manager.request_loader
def load_user(web_request): def load_user(web_request):
jwt = Serializer(secret_key=SECRET_KEY) jwt = Serializer(secret_key=settings.SECRET_KEY)
authorization = web_request.headers.get("Authorization") authorization = web_request.headers.get("Authorization")
if authorization: if authorization:
try: try:

View File

@ -32,7 +32,7 @@ from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import queue_tasks, TaskService from api.db.services.task_service import queue_tasks, TaskService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.settings import RetCode, retrievaler from api import settings
from api.utils import get_uuid, current_timestamp, datetime_format from api.utils import get_uuid, current_timestamp, datetime_format
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \ from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
generate_confirmation_token generate_confirmation_token
@ -141,7 +141,7 @@ def set_conversation():
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json req = request.json
try: try:
if objs[0].source == "agent": if objs[0].source == "agent":
@ -183,7 +183,7 @@ def completion():
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json req = request.json
e, conv = API4ConversationService.get_by_id(req["conversation_id"]) e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e: if not e:
@ -347,7 +347,7 @@ def get(conversation_id):
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
try: try:
e, conv = API4ConversationService.get_by_id(conversation_id) e, conv = API4ConversationService.get_by_id(conversation_id)
@ -357,7 +357,7 @@ def get(conversation_id):
conv = conv.to_dict() conv = conv.to_dict()
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token: if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
return get_json_result(data=False, message='Token is not valid for this conversation_id!"', return get_json_result(data=False, message='Token is not valid for this conversation_id!"',
code=RetCode.AUTHENTICATION_ERROR) code=settings.RetCode.AUTHENTICATION_ERROR)
for referenct_i in conv['reference']: for referenct_i in conv['reference']:
if referenct_i is None or len(referenct_i) == 0: if referenct_i is None or len(referenct_i) == 0:
@ -378,7 +378,7 @@ def upload():
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
kb_name = request.form.get("kb_name").strip() kb_name = request.form.get("kb_name").strip()
tenant_id = objs[0].tenant_id tenant_id = objs[0].tenant_id
@ -394,12 +394,12 @@ def upload():
if 'file' not in request.files: if 'file' not in request.files:
return get_json_result( return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file = request.files['file'] file = request.files['file']
if file.filename == '': if file.filename == '':
return get_json_result( return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
root_folder = FileService.get_root_folder(tenant_id) root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"] pf_id = root_folder["id"]
@ -490,17 +490,17 @@ def upload_parse():
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
if 'file' not in request.files: if 'file' not in request.files:
return get_json_result( return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file') file_objs = request.files.getlist('file')
for file_obj in file_objs: for file_obj in file_objs:
if file_obj.filename == '': if file_obj.filename == '':
return get_json_result( return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id) doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
return get_json_result(data=doc_ids) return get_json_result(data=doc_ids)
@ -513,7 +513,7 @@ def list_chunks():
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json req = request.json
@ -531,7 +531,7 @@ def list_chunks():
) )
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids) res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
res = [ res = [
{ {
"content": res_item["content_with_weight"], "content": res_item["content_with_weight"],
@ -553,7 +553,7 @@ def list_kb_docs():
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json req = request.json
tenant_id = objs[0].tenant_id tenant_id = objs[0].tenant_id
@ -585,6 +585,7 @@ def list_kb_docs():
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@manager.route('/document/infos', methods=['POST']) @manager.route('/document/infos', methods=['POST'])
@validate_request("doc_ids") @validate_request("doc_ids")
def docinfos(): def docinfos():
@ -592,7 +593,7 @@ def docinfos():
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json req = request.json
doc_ids = req["doc_ids"] doc_ids = req["doc_ids"]
docs = DocumentService.get_by_ids(doc_ids) docs = DocumentService.get_by_ids(doc_ids)
@ -606,7 +607,7 @@ def document_rm():
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
tenant_id = objs[0].tenant_id tenant_id = objs[0].tenant_id
req = request.json req = request.json
@ -653,7 +654,7 @@ def document_rm():
errors += str(e) errors += str(e)
if errors: if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR) return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
return get_json_result(data=True) return get_json_result(data=True)
@ -668,7 +669,7 @@ def completion_faq():
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
e, conv = API4ConversationService.get_by_id(req["conversation_id"]) e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e: if not e:
@ -805,7 +806,7 @@ def retrieval():
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR) data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json req = request.json
kb_ids = req.get("kb_id", []) kb_ids = req.get("kb_id", [])
@ -822,7 +823,8 @@ def retrieval():
embd_nms = list(set([kb.embd_id for kb in kbs])) embd_nms = list(set([kb.embd_id for kb in kbs]))
if len(embd_nms) != 1: if len(embd_nms) != 1:
return get_json_result( return get_json_result(
data=False, message='Knowledge bases use different embedding models or does not exist."', code=RetCode.AUTHENTICATION_ERROR) data=False, message='Knowledge bases use different embedding models or does not exist."',
code=settings.RetCode.AUTHENTICATION_ERROR)
embd_mdl = TenantLLMService.model_instance( embd_mdl = TenantLLMService.model_instance(
kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id) kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
@ -833,7 +835,7 @@ def retrieval():
if req.get("keyword", False): if req.get("keyword", False):
chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT) chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question) question += keyword_extraction(chat_mdl, question)
ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size, ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top, similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl) doc_ids, rerank_mdl=rerank_mdl)
for c in ranks["chunks"]: for c in ranks["chunks"]:
@ -843,5 +845,5 @@ def retrieval():
except Exception as e: except Exception as e:
if str(e).find("not_found") > 0: if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!', return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
code=RetCode.DATA_ERROR) code=settings.RetCode.DATA_ERROR)
return server_error_response(e) return server_error_response(e)

View File

@ -19,7 +19,7 @@ from functools import partial
from flask import request, Response from flask import request, Response
from flask_login import login_required, current_user from flask_login import login_required, current_user
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.settings import RetCode from api import settings
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from agent.canvas import Canvas from agent.canvas import Canvas
@ -36,7 +36,8 @@ def templates():
@login_required @login_required
def canvas_list(): def canvas_list():
return get_json_result(data=sorted([c.to_dict() for c in \ return get_json_result(data=sorted([c.to_dict() for c in \
UserCanvasService.query(user_id=current_user.id)], key=lambda x: x["update_time"]*-1) UserCanvasService.query(user_id=current_user.id)],
key=lambda x: x["update_time"] * -1)
) )
@ -48,7 +49,7 @@ def rm():
if not UserCanvasService.query(user_id=current_user.id, id=i): if not UserCanvasService.query(user_id=current_user.id, id=i):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR) code=settings.RetCode.OPERATING_ERROR)
UserCanvasService.delete_by_id(i) UserCanvasService.delete_by_id(i)
return get_json_result(data=True) return get_json_result(data=True)
@ -72,7 +73,7 @@ def save():
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]): if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR) code=settings.RetCode.OPERATING_ERROR)
UserCanvasService.update_by_id(req["id"], req) UserCanvasService.update_by_id(req["id"], req)
return get_json_result(data=req) return get_json_result(data=req)
@ -98,7 +99,7 @@ def run():
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]): if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR) code=settings.RetCode.OPERATING_ERROR)
if not isinstance(cvs.dsl, str): if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
@ -122,7 +123,8 @@ def run():
assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow." assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
if stream: if stream:
assert isinstance(answer, partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow." assert isinstance(answer,
partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
def sse(): def sse():
nonlocal answer, cvs nonlocal answer, cvs
@ -173,7 +175,7 @@ def reset():
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]): if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
return get_json_result( return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.', data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR) code=settings.RetCode.OPERATING_ERROR)
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id) canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
canvas.reset() canvas.reset()

View File

@ -29,11 +29,12 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn from api import settings
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
import hashlib import hashlib
import re import re
@manager.route('/list', methods=['POST']) @manager.route('/list', methods=['POST'])
@login_required @login_required
@validate_request("doc_id") @validate_request("doc_id")
@ -56,7 +57,7 @@ def list_chunk():
} }
if "available_int" in req: if "available_int" in req:
query["available_int"] = int(req["available_int"]) query["available_int"] = int(req["available_int"])
sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True) sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()} res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids: for id in sres.ids:
d = { d = {
@ -78,7 +79,7 @@ def list_chunk():
except Exception as e: except Exception as e:
if str(e).find("not_found") > 0: if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found!', return get_json_result(data=False, message='No chunk found!',
code=RetCode.DATA_ERROR) code=settings.RetCode.DATA_ERROR)
return server_error_response(e) return server_error_response(e)
@ -93,7 +94,7 @@ def get():
tenant_id = tenants[0].tenant_id tenant_id = tenants[0].tenant_id
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids) chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
if chunk is None: if chunk is None:
return server_error_response("Chunk not found") return server_error_response("Chunk not found")
k = [] k = []
@ -107,7 +108,7 @@ def get():
except Exception as e: except Exception as e:
if str(e).find("NotFoundError") >= 0: if str(e).find("NotFoundError") >= 0:
return get_json_result(data=False, message='Chunk not found!', return get_json_result(data=False, message='Chunk not found!',
code=RetCode.DATA_ERROR) code=settings.RetCode.DATA_ERROR)
return server_error_response(e) return server_error_response(e)
@ -154,7 +155,7 @@ def set():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist() d["q_%d_vec" % len(v)] = v.tolist()
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id) settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@ -169,7 +170,7 @@ def switch():
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])}, if not settings.docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
search.index_name(doc.tenant_id), doc.kb_id): search.index_name(doc.tenant_id), doc.kb_id):
return get_data_error_result(message="Index updating failure") return get_data_error_result(message="Index updating failure")
return get_json_result(data=True) return get_json_result(data=True)
@ -186,7 +187,7 @@ def rm():
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
if not e: if not e:
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id): if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
return get_data_error_result(message="Index updating failure") return get_data_error_result(message="Index updating failure")
deleted_chunk_ids = req["chunk_ids"] deleted_chunk_ids = req["chunk_ids"]
chunk_number = len(deleted_chunk_ids) chunk_number = len(deleted_chunk_ids)
@ -230,7 +231,7 @@ def create():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist() d["q_%d_vec" % len(v)] = v.tolist()
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id) settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
DocumentService.increment_chunk_num( DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0) doc.id, doc.kb_id, c, 1, 0)
@ -265,7 +266,7 @@ def retrieval_test():
else: else:
return get_json_result( return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.', data=False, message='Only owner of knowledgebase authorized for this operation.',
code=RetCode.OPERATING_ERROR) code=settings.RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
if not e: if not e:
@ -281,7 +282,7 @@ def retrieval_test():
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question) question += keyword_extraction(chat_mdl, question)
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size, ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top, similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight")) doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
@ -293,7 +294,7 @@ def retrieval_test():
except Exception as e: except Exception as e:
if str(e).find("not_found") > 0: if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!', return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
code=RetCode.DATA_ERROR) code=settings.RetCode.DATA_ERROR)
return server_error_response(e) return server_error_response(e)
@ -307,7 +308,7 @@ def knowledge_graph():
"doc_ids": [doc_id], "doc_ids": [doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"] "knowledge_graph_kwd": ["graph", "mind_map"]
} }
sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids) sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
obj = {"graph": {}, "mind_map": {}} obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]: for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"] ty = sres.field[id]["knowledge_graph_kwd"]
@ -336,4 +337,3 @@ def knowledge_graph():
obj[ty] = content_json obj[ty] = content_json
return get_json_result(data=obj) return get_json_result(data=obj)

View File

@ -25,7 +25,7 @@ from api.db import LLMType
from api.db.services.dialog_service import DialogService, ConversationService, chat, ask from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
from api.settings import RetCode, retrievaler from api import settings
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from graphrag.mind_map_extractor import MindMapExtractor from graphrag.mind_map_extractor import MindMapExtractor
@ -87,7 +87,7 @@ def get():
else: else:
return get_json_result( return get_json_result(
data=False, message='Only owner of conversation authorized for this operation.', data=False, message='Only owner of conversation authorized for this operation.',
code=RetCode.OPERATING_ERROR) code=settings.RetCode.OPERATING_ERROR)
conv = conv.to_dict() conv = conv.to_dict()
return get_json_result(data=conv) return get_json_result(data=conv)
except Exception as e: except Exception as e:
@ -110,7 +110,7 @@ def rm():
else: else:
return get_json_result( return get_json_result(
data=False, message='Only owner of conversation authorized for this operation.', data=False, message='Only owner of conversation authorized for this operation.',
code=RetCode.OPERATING_ERROR) code=settings.RetCode.OPERATING_ERROR)
ConversationService.delete_by_id(cid) ConversationService.delete_by_id(cid)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
@ -125,7 +125,7 @@ def list_convsersation():
if not DialogService.query(tenant_id=current_user.id, id=dialog_id): if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
return get_json_result( return get_json_result(
data=False, message='Only owner of dialog authorized for this operation.', data=False, message='Only owner of dialog authorized for this operation.',
code=RetCode.OPERATING_ERROR) code=settings.RetCode.OPERATING_ERROR)
convs = ConversationService.query( convs = ConversationService.query(
dialog_id=dialog_id, dialog_id=dialog_id,
order_by=ConversationService.model.create_time, order_by=ConversationService.model.create_time,
@ -297,6 +297,7 @@ def thumbup():
def ask_about(): def ask_about():
req = request.json req = request.json
uid = current_user.id uid = current_user.id
def stream(): def stream():
nonlocal req, uid nonlocal req, uid
try: try:
@ -329,7 +330,7 @@ def mindmap():
embd_mdl = TenantLLMService.model_instance( embd_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT) chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
ranks = retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12, ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
0.3, 0.3, aggs=False) 0.3, 0.3, aggs=False)
mindmap = MindMapExtractor(chat_mdl) mindmap = MindMapExtractor(chat_mdl)
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output

View File

@ -20,7 +20,7 @@ from api.db.services.dialog_service import DialogService
from api.db import StatusEnum from api.db import StatusEnum
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api.settings import RetCode from api import settings
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
@ -175,7 +175,7 @@ def rm():
else: else:
return get_json_result( return get_json_result(
data=False, message='Only owner of dialog authorized for this operation.', data=False, message='Only owner of dialog authorized for this operation.',
code=RetCode.OPERATING_ERROR) code=settings.RetCode.OPERATING_ERROR)
dialog_list.append({"id": id,"status":StatusEnum.INVALID.value}) dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
DialogService.update_many_by_id(dialog_list) DialogService.update_many_by_id(dialog_list)
return get_json_result(data=True) return get_json_result(data=True)

View File

@ -34,7 +34,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
from api.utils import get_uuid from api.utils import get_uuid
from api.db import FileType, TaskStatus, ParserType, FileSource from api.db import FileType, TaskStatus, ParserType, FileSource
from api.db.services.document_service import DocumentService, doc_upload_and_parse from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.settings import RetCode, docStoreConn from api import settings
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
@ -49,16 +49,16 @@ def upload():
kb_id = request.form.get("kb_id") kb_id = request.form.get("kb_id")
if not kb_id: if not kb_id:
return get_json_result( return get_json_result(
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
if 'file' not in request.files: if 'file' not in request.files:
return get_json_result( return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file') file_objs = request.files.getlist('file')
for file_obj in file_objs: for file_obj in file_objs:
if file_obj.filename == '': if file_obj.filename == '':
return get_json_result( return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id) e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e: if not e:
@ -67,7 +67,7 @@ def upload():
err, _ = FileService.upload_document(kb, file_objs, current_user.id) err, _ = FileService.upload_document(kb, file_objs, current_user.id)
if err: if err:
return get_json_result( return get_json_result(
data=False, message="\n".join(err), code=RetCode.SERVER_ERROR) data=False, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
return get_json_result(data=True) return get_json_result(data=True)
@ -78,12 +78,12 @@ def web_crawl():
kb_id = request.form.get("kb_id") kb_id = request.form.get("kb_id")
if not kb_id: if not kb_id:
return get_json_result( return get_json_result(
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
name = request.form.get("name") name = request.form.get("name")
url = request.form.get("url") url = request.form.get("url")
if not is_valid_url(url): if not is_valid_url(url):
return get_json_result( return get_json_result(
data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR) data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id) e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e: if not e:
raise LookupError("Can't find this knowledgebase!") raise LookupError("Can't find this knowledgebase!")
@ -145,7 +145,7 @@ def create():
kb_id = req["kb_id"] kb_id = req["kb_id"]
if not kb_id: if not kb_id:
return get_json_result( return get_json_result(
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
try: try:
e, kb = KnowledgebaseService.get_by_id(kb_id) e, kb = KnowledgebaseService.get_by_id(kb_id)
@ -179,7 +179,7 @@ def list_docs():
kb_id = request.args.get("kb_id") kb_id = request.args.get("kb_id")
if not kb_id: if not kb_id:
return get_json_result( return get_json_result(
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants: for tenant in tenants:
if KnowledgebaseService.query( if KnowledgebaseService.query(
@ -188,7 +188,7 @@ def list_docs():
else: else:
return get_json_result( return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.', data=False, message='Only owner of knowledgebase authorized for this operation.',
code=RetCode.OPERATING_ERROR) code=settings.RetCode.OPERATING_ERROR)
keywords = request.args.get("keywords", "") keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 1)) page_number = int(request.args.get("page", 1))
@ -218,7 +218,7 @@ def docinfos():
return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR code=settings.RetCode.AUTHENTICATION_ERROR
) )
docs = DocumentService.get_by_ids(doc_ids) docs = DocumentService.get_by_ids(doc_ids)
return get_json_result(data=list(docs.dicts())) return get_json_result(data=list(docs.dicts()))
@ -230,7 +230,7 @@ def thumbnails():
doc_ids = request.args.get("doc_ids").split(",") doc_ids = request.args.get("doc_ids").split(",")
if not doc_ids: if not doc_ids:
return get_json_result( return get_json_result(
data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR) data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)
try: try:
docs = DocumentService.get_thumbnails(doc_ids) docs = DocumentService.get_thumbnails(doc_ids)
@ -253,13 +253,13 @@ def change_status():
return get_json_result( return get_json_result(
data=False, data=False,
message='"Status" must be either 0 or 1!', message='"Status" must be either 0 or 1!',
code=RetCode.ARGUMENT_ERROR) code=settings.RetCode.ARGUMENT_ERROR)
if not DocumentService.accessible(req["doc_id"], current_user.id): if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR) code=settings.RetCode.AUTHENTICATION_ERROR)
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
@ -276,7 +276,8 @@ def change_status():
message="Database error (Document update)!") message="Database error (Document update)!")
status = int(req["status"]) status = int(req["status"])
docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id) settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status},
search.index_name(kb.tenant_id), doc.kb_id)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@ -295,7 +296,7 @@ def rm():
return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR code=settings.RetCode.AUTHENTICATION_ERROR
) )
root_folder = FileService.get_root_folder(current_user.id) root_folder = FileService.get_root_folder(current_user.id)
@ -326,7 +327,7 @@ def rm():
errors += str(e) errors += str(e)
if errors: if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR) return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
return get_json_result(data=True) return get_json_result(data=True)
@ -341,7 +342,7 @@ def run():
return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR code=settings.RetCode.AUTHENTICATION_ERROR
) )
try: try:
for id in req["doc_ids"]: for id in req["doc_ids"]:
@ -358,8 +359,8 @@ def run():
e, doc = DocumentService.get_by_id(id) e, doc = DocumentService.get_by_id(id)
if not e: if not e:
return get_data_error_result(message="Document not found!") return get_data_error_result(message="Document not found!")
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id) settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if str(req["run"]) == TaskStatus.RUNNING.value: if str(req["run"]) == TaskStatus.RUNNING.value:
TaskService.filter_delete([Task.doc_id == id]) TaskService.filter_delete([Task.doc_id == id])
@ -383,7 +384,7 @@ def rename():
return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR code=settings.RetCode.AUTHENTICATION_ERROR
) )
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
@ -394,7 +395,7 @@ def rename():
return get_json_result( return get_json_result(
data=False, data=False,
message="The extension of file can't be changed", message="The extension of file can't be changed",
code=RetCode.ARGUMENT_ERROR) code=settings.RetCode.ARGUMENT_ERROR)
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
if d.name == req["name"]: if d.name == req["name"]:
return get_data_error_result( return get_data_error_result(
@ -450,7 +451,7 @@ def change_parser():
return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR code=settings.RetCode.AUTHENTICATION_ERROR
) )
try: try:
e, doc = DocumentService.get_by_id(req["doc_id"]) e, doc = DocumentService.get_by_id(req["doc_id"])
@ -483,8 +484,8 @@ def change_parser():
tenant_id = DocumentService.get_tenant_id(req["doc_id"]) tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id: if not tenant_id:
return get_data_error_result(message="Tenant not found!") return get_data_error_result(message="Tenant not found!")
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
@ -509,13 +510,13 @@ def get_image(image_id):
def upload_and_parse(): def upload_and_parse():
if 'file' not in request.files: if 'file' not in request.files:
return get_json_result( return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file') file_objs = request.files.getlist('file')
for file_obj in file_objs: for file_obj in file_objs:
if file_obj.filename == '': if file_obj.filename == '':
return get_json_result( return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id) doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
@ -529,7 +530,7 @@ def parse():
if url: if url:
if not is_valid_url(url): if not is_valid_url(url):
return get_json_result( return get_json_result(
data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR) data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
download_path = os.path.join(get_project_base_directory(), "logs/downloads") download_path = os.path.join(get_project_base_directory(), "logs/downloads")
os.makedirs(download_path, exist_ok=True) os.makedirs(download_path, exist_ok=True)
from selenium.webdriver import Chrome, ChromeOptions from selenium.webdriver import Chrome, ChromeOptions
@ -553,7 +554,7 @@ def parse():
if 'file' not in request.files: if 'file' not in request.files:
return get_json_result( return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file') file_objs = request.files.getlist('file')
txt = FileService.parse_docs(file_objs, current_user.id) txt = FileService.parse_docs(file_objs, current_user.id)

View File

@ -24,7 +24,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
from api.utils import get_uuid from api.utils import get_uuid
from api.db import FileType from api.db import FileType
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.settings import RetCode from api import settings
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
@ -100,7 +100,7 @@ def rm():
file_ids = req["file_ids"] file_ids = req["file_ids"]
if not file_ids: if not file_ids:
return get_json_result( return get_json_result(
data=False, message='Lack of "Files ID"', code=RetCode.ARGUMENT_ERROR) data=False, message='Lack of "Files ID"', code=settings.RetCode.ARGUMENT_ERROR)
try: try:
for file_id in file_ids: for file_id in file_ids:
informs = File2DocumentService.get_by_file_id(file_id) informs = File2DocumentService.get_by_file_id(file_id)

View File

@ -28,7 +28,7 @@ from api.utils import get_uuid
from api.db import FileType, FileSource from api.db import FileType, FileSource
from api.db.services import duplicate_name from api.db.services import duplicate_name
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.settings import RetCode from api import settings
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from api.utils.file_utils import filename_type from api.utils.file_utils import filename_type
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
@ -46,13 +46,13 @@ def upload():
if 'file' not in request.files: if 'file' not in request.files:
return get_json_result( return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file') file_objs = request.files.getlist('file')
for file_obj in file_objs: for file_obj in file_objs:
if file_obj.filename == '': if file_obj.filename == '':
return get_json_result( return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
file_res = [] file_res = []
try: try:
for file_obj in file_objs: for file_obj in file_objs:
@ -134,7 +134,7 @@ def create():
try: try:
if not FileService.is_parent_folder_exist(pf_id): if not FileService.is_parent_folder_exist(pf_id):
return get_json_result( return get_json_result(
data=False, message="Parent Folder Doesn't Exist!", code=RetCode.OPERATING_ERROR) data=False, message="Parent Folder Doesn't Exist!", code=settings.RetCode.OPERATING_ERROR)
if FileService.query(name=req["name"], parent_id=pf_id): if FileService.query(name=req["name"], parent_id=pf_id):
return get_data_error_result( return get_data_error_result(
message="Duplicated folder name in the same folder.") message="Duplicated folder name in the same folder.")
@ -299,7 +299,7 @@ def rename():
return get_json_result( return get_json_result(
data=False, data=False,
message="The extension of file can't be changed", message="The extension of file can't be changed",
code=RetCode.ARGUMENT_ERROR) code=settings.RetCode.ARGUMENT_ERROR)
for file in FileService.query(name=req["name"], pf_id=file.parent_id): for file in FileService.query(name=req["name"], pf_id=file.parent_id):
if file.name == req["name"]: if file.name == req["name"]:
return get_data_error_result( return get_data_error_result(

View File

@ -26,9 +26,8 @@ from api.utils import get_uuid
from api.db import StatusEnum, FileSource from api.db import StatusEnum, FileSource
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File from api.db.db_models import File
from api.settings import RetCode
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
from api.settings import docStoreConn from api import settings
from rag.nlp import search from rag.nlp import search
@ -68,13 +67,13 @@ def update():
return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR code=settings.RetCode.AUTHENTICATION_ERROR
) )
try: try:
if not KnowledgebaseService.query( if not KnowledgebaseService.query(
created_by=current_user.id, id=req["kb_id"]): created_by=current_user.id, id=req["kb_id"]):
return get_json_result( return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR) data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(req["kb_id"]) e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
if not e: if not e:
@ -113,7 +112,7 @@ def detail():
else: else:
return get_json_result( return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.', data=False, message='Only owner of knowledgebase authorized for this operation.',
code=RetCode.OPERATING_ERROR) code=settings.RetCode.OPERATING_ERROR)
kb = KnowledgebaseService.get_detail(kb_id) kb = KnowledgebaseService.get_detail(kb_id)
if not kb: if not kb:
return get_data_error_result( return get_data_error_result(
@ -148,14 +147,14 @@ def rm():
return get_json_result( return get_json_result(
data=False, data=False,
message='No authorization.', message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR code=settings.RetCode.AUTHENTICATION_ERROR
) )
try: try:
kbs = KnowledgebaseService.query( kbs = KnowledgebaseService.query(
created_by=current_user.id, id=req["kb_id"]) created_by=current_user.id, id=req["kb_id"])
if not kbs: if not kbs:
return get_json_result( return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR) data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
for doc in DocumentService.query(kb_id=req["kb_id"]): for doc in DocumentService.query(kb_id=req["kb_id"]):
if not DocumentService.remove_document(doc, kbs[0].tenant_id): if not DocumentService.remove_document(doc, kbs[0].tenant_id):
@ -170,7 +169,7 @@ def rm():
message="Database error (Knowledgebase removal)!") message="Database error (Knowledgebase removal)!")
tenants = UserTenantService.query(user_id=current_user.id) tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants: for tenant in tenants:
docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"]) settings.docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)

View File

@ -19,7 +19,7 @@ import json
from flask import request from flask import request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
from api.settings import LIGHTEN from api import settings
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db import StatusEnum, LLMType from api.db import StatusEnum, LLMType
from api.db.db_models import TenantLLM from api.db.db_models import TenantLLM
@ -333,7 +333,7 @@ def my_llms():
@login_required @login_required
def list_app(): def list_app():
self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"] self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
weighted = ["Youdao","FastEmbed", "BAAI"] if LIGHTEN != 0 else [] weighted = ["Youdao","FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
model_type = request.args.get("model_type") model_type = request.args.get("model_type")
try: try:
objs = TenantLLMService.query(tenant_id=current_user.id) objs = TenantLLMService.query(tenant_id=current_user.id)

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
# #
from flask import request from flask import request
from api.settings import RetCode from api import settings
from api.db import StatusEnum from api.db import StatusEnum
from api.db.services.dialog_service import DialogService from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
@ -44,7 +44,7 @@ def create(tenant_id):
kbs = KnowledgebaseService.get_by_ids(ids) kbs = KnowledgebaseService.get_by_ids(ids)
embd_count = list(set([kb.embd_id for kb in kbs])) embd_count = list(set([kb.embd_id for kb in kbs]))
if len(embd_count) != 1: if len(embd_count) != 1:
return get_result(message='Datasets use different embedding models."',code=RetCode.AUTHENTICATION_ERROR) return get_result(message='Datasets use different embedding models."',code=settings.RetCode.AUTHENTICATION_ERROR)
req["kb_ids"] = ids req["kb_ids"] = ids
# llm # llm
llm = req.get("llm") llm = req.get("llm")
@ -173,7 +173,7 @@ def update(tenant_id,chat_id):
if len(embd_count) != 1 : if len(embd_count) != 1 :
return get_result( return get_result(
message='Datasets use different embedding models."', message='Datasets use different embedding models."',
code=RetCode.AUTHENTICATION_ERROR) code=settings.RetCode.AUTHENTICATION_ERROR)
req["kb_ids"] = ids req["kb_ids"] = ids
llm = req.get("llm") llm = req.get("llm")
if llm: if llm:

View File

@ -23,7 +23,7 @@ from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService, LLMService from api.db.services.llm_service import TenantLLMService, LLMService
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from api.settings import RetCode from api import settings
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import ( from api.utils.api_utils import (
get_result, get_result,
@ -255,7 +255,7 @@ def delete(tenant_id):
File2DocumentService.delete_by_document_id(doc.id) File2DocumentService.delete_by_document_id(doc.id)
if not KnowledgebaseService.delete_by_id(id): if not KnowledgebaseService.delete_by_id(id):
return get_error_data_result(message="Delete dataset error.(Database error)") return get_error_data_result(message="Delete dataset error.(Database error)")
return get_result(code=RetCode.SUCCESS) return get_result(code=settings.RetCode.SUCCESS)
@manager.route("/datasets/<dataset_id>", methods=["PUT"]) @manager.route("/datasets/<dataset_id>", methods=["PUT"])
@ -424,7 +424,7 @@ def update(tenant_id, dataset_id):
) )
if not KnowledgebaseService.update_by_id(kb.id, req): if not KnowledgebaseService.update_by_id(kb.id, req):
return get_error_data_result(message="Update dataset error.(Database error)") return get_error_data_result(message="Update dataset error.(Database error)")
return get_result(code=RetCode.SUCCESS) return get_result(code=settings.RetCode.SUCCESS)
@manager.route("/datasets", methods=["GET"]) @manager.route("/datasets", methods=["GET"])

View File

@ -18,7 +18,7 @@ from flask import request, jsonify
from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler, kg_retrievaler, RetCode from api import settings
from api.utils.api_utils import validate_request, build_error_result, apikey_required from api.utils.api_utils import validate_request, build_error_result, apikey_required
@ -37,14 +37,14 @@ def retrieval(tenant_id):
e, kb = KnowledgebaseService.get_by_id(kb_id) e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e: if not e:
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND) return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
if kb.tenant_id != tenant_id: if kb.tenant_id != tenant_id:
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND) return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval( ranks = retr.retrieval(
question, question,
embd_mdl, embd_mdl,
@ -72,6 +72,6 @@ def retrieval(tenant_id):
if str(e).find("not_found") > 0: if str(e).find("not_found") > 0:
return build_error_result( return build_error_result(
message='No chunk found! Check the chunk status please!', message='No chunk found! Check the chunk status please!',
code=RetCode.NOT_FOUND code=settings.RetCode.NOT_FOUND
) )
return build_error_result(message=str(e), code=RetCode.SERVER_ERROR) return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)

View File

@ -21,7 +21,7 @@ from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import rag_tokenizer from rag.nlp import rag_tokenizer
from api.db import LLMType, ParserType from api.db import LLMType, ParserType
from api.db.services.llm_service import TenantLLMService from api.db.services.llm_service import TenantLLMService
from api.settings import kg_retrievaler from api import settings
import hashlib import hashlib
import re import re
from api.utils.api_utils import token_required from api.utils.api_utils import token_required
@ -37,11 +37,10 @@ from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import RetCode, retrievaler from api import settings
from api.utils.api_utils import construct_json_result, get_parser_config from api.utils.api_utils import construct_json_result, get_parser_config
from rag.nlp import search from rag.nlp import search
from rag.utils import rmSpace from rag.utils import rmSpace
from api.settings import docStoreConn
from rag.utils.storage_factory import STORAGE_IMPL from rag.utils.storage_factory import STORAGE_IMPL
import os import os
@ -109,13 +108,13 @@ def upload(dataset_id, tenant_id):
""" """
if "file" not in request.files: if "file" not in request.files:
return get_error_data_result( return get_error_data_result(
message="No file part!", code=RetCode.ARGUMENT_ERROR message="No file part!", code=settings.RetCode.ARGUMENT_ERROR
) )
file_objs = request.files.getlist("file") file_objs = request.files.getlist("file")
for file_obj in file_objs: for file_obj in file_objs:
if file_obj.filename == "": if file_obj.filename == "":
return get_result( return get_result(
message="No file selected!", code=RetCode.ARGUMENT_ERROR message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR
) )
# total size # total size
total_size = 0 total_size = 0
@ -127,14 +126,14 @@ def upload(dataset_id, tenant_id):
if total_size > MAX_TOTAL_FILE_SIZE: if total_size > MAX_TOTAL_FILE_SIZE:
return get_result( return get_result(
message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)", message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)",
code=RetCode.ARGUMENT_ERROR, code=settings.RetCode.ARGUMENT_ERROR,
) )
e, kb = KnowledgebaseService.get_by_id(dataset_id) e, kb = KnowledgebaseService.get_by_id(dataset_id)
if not e: if not e:
raise LookupError(f"Can't find the dataset with ID {dataset_id}!") raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
err, files = FileService.upload_document(kb, file_objs, tenant_id) err, files = FileService.upload_document(kb, file_objs, tenant_id)
if err: if err:
return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR) return get_result(message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
# rename key's name # rename key's name
renamed_doc_list = [] renamed_doc_list = []
for file in files: for file in files:
@ -226,7 +225,7 @@ def update_doc(tenant_id, dataset_id, document_id):
): ):
return get_result( return get_result(
message="The extension of file can't be changed", message="The extension of file can't be changed",
code=RetCode.ARGUMENT_ERROR, code=settings.RetCode.ARGUMENT_ERROR,
) )
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
if d.name == req["name"]: if d.name == req["name"]:
@ -292,7 +291,7 @@ def update_doc(tenant_id, dataset_id, document_id):
) )
if not e: if not e:
return get_error_data_result(message="Document not found!") return get_error_data_result(message="Document not found!")
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
return get_result() return get_result()
@ -349,7 +348,7 @@ def download(tenant_id, dataset_id, document_id):
file_stream = STORAGE_IMPL.get(doc_id, doc_location) file_stream = STORAGE_IMPL.get(doc_id, doc_location)
if not file_stream: if not file_stream:
return construct_json_result( return construct_json_result(
message="This file is empty.", code=RetCode.DATA_ERROR message="This file is empty.", code=settings.RetCode.DATA_ERROR
) )
file = BytesIO(file_stream) file = BytesIO(file_stream)
# Use send_file with a proper filename and MIME type # Use send_file with a proper filename and MIME type
@ -582,7 +581,7 @@ def delete(tenant_id, dataset_id):
errors += str(e) errors += str(e)
if errors: if errors:
return get_result(message=errors, code=RetCode.SERVER_ERROR) return get_result(message=errors, code=settings.RetCode.SERVER_ERROR)
return get_result() return get_result()
@ -644,7 +643,7 @@ def parse(tenant_id, dataset_id):
info["chunk_num"] = 0 info["chunk_num"] = 0
info["token_num"] = 0 info["token_num"] = 0
DocumentService.update_by_id(id, info) DocumentService.update_by_id(id, info)
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id) settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
TaskService.filter_delete([Task.doc_id == id]) TaskService.filter_delete([Task.doc_id == id])
e, doc = DocumentService.get_by_id(id) e, doc = DocumentService.get_by_id(id)
doc = doc.to_dict() doc = doc.to_dict()
@ -708,7 +707,7 @@ def stop_parsing(tenant_id, dataset_id):
) )
info = {"run": "2", "progress": 0, "chunk_num": 0} info = {"run": "2", "progress": 0, "chunk_num": 0}
DocumentService.update_by_id(id, info) DocumentService.update_by_id(id, info)
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id) settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
return get_result() return get_result()
@ -828,8 +827,9 @@ def list_chunks(tenant_id, dataset_id, document_id):
res = {"total": 0, "chunks": [], "doc": renamed_doc} res = {"total": 0, "chunks": [], "doc": renamed_doc}
origin_chunks = [] origin_chunks = []
if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id): if settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True) sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None,
highlight=True)
res["total"] = sres.total res["total"] = sres.total
sign = 0 sign = 0
for id in sres.ids: for id in sres.ids:
@ -1003,7 +1003,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
v, c = embd_mdl.encode([doc.name, req["content"]]) v, c = embd_mdl.encode([doc.name, req["content"]])
v = 0.1 * v[0] + 0.9 * v[1] v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist() d["q_%d_vec" % len(v)] = v.tolist()
docStoreConn.insert([d], search.index_name(tenant_id), dataset_id) settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0) DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
# rename keys # rename keys
@ -1078,7 +1078,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
condition = {"doc_id": document_id} condition = {"doc_id": document_id}
if "chunk_ids" in req: if "chunk_ids" in req:
condition["id"] = req["chunk_ids"] condition["id"] = req["chunk_ids"]
chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id) chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
if chunk_number != 0: if chunk_number != 0:
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0) DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]): if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
@ -1143,7 +1143,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
schema: schema:
type: object type: object
""" """
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id]) chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
if chunk is None: if chunk is None:
return get_error_data_result(f"Can't find this chunk {chunk_id}") return get_error_data_result(f"Can't find this chunk {chunk_id}")
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
@ -1187,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]]) v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist() d["q_%d_vec" % len(v)] = v.tolist()
docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id) settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
return get_result() return get_result()
@ -1285,7 +1285,7 @@ def retrieval_test(tenant_id):
if len(embd_nms) != 1: if len(embd_nms) != 1:
return get_result( return get_result(
message='Datasets use different embedding models."', message='Datasets use different embedding models."',
code=RetCode.AUTHENTICATION_ERROR, code=settings.RetCode.AUTHENTICATION_ERROR,
) )
if "question" not in req: if "question" not in req:
return get_error_data_result("`question` is required.") return get_error_data_result("`question` is required.")
@ -1326,7 +1326,7 @@ def retrieval_test(tenant_id):
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT) chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question) question += keyword_extraction(chat_mdl, question)
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval( ranks = retr.retrieval(
question, question,
embd_mdl, embd_mdl,
@ -1366,6 +1366,6 @@ def retrieval_test(tenant_id):
if str(e).find("not_found") > 0: if str(e).find("not_found") > 0:
return get_result( return get_result(
message="No chunk found! Check the chunk status please!", message="No chunk found! Check the chunk status please!",
code=RetCode.DATA_ERROR, code=settings.RetCode.DATA_ERROR,
) )
return server_error_response(e) return server_error_response(e)

View File

@ -22,7 +22,7 @@ from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService from api.db.services.api_service import APITokenService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.settings import DATABASE_TYPE from api import settings
from api.utils import current_timestamp, datetime_format from api.utils import current_timestamp, datetime_format
from api.utils.api_utils import ( from api.utils.api_utils import (
get_json_result, get_json_result,
@ -31,7 +31,6 @@ from api.utils.api_utils import (
generate_confirmation_token, generate_confirmation_token,
) )
from api.versions import get_ragflow_version from api.versions import get_ragflow_version
from api.settings import docStoreConn
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
from timeit import default_timer as timer from timeit import default_timer as timer
@ -98,7 +97,7 @@ def status():
res = {} res = {}
st = timer() st = timer()
try: try:
res["doc_store"] = docStoreConn.health() res["doc_store"] = settings.docStoreConn.health()
res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0) res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
except Exception as e: except Exception as e:
res["doc_store"] = { res["doc_store"] = {
@ -128,13 +127,13 @@ def status():
try: try:
KnowledgebaseService.get_by_id("x") KnowledgebaseService.get_by_id("x")
res["database"] = { res["database"] = {
"database": DATABASE_TYPE.lower(), "database": settings.DATABASE_TYPE.lower(),
"status": "green", "status": "green",
"elapsed": "{:.1f}".format((timer() - st) * 1000.0), "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
} }
except Exception as e: except Exception as e:
res["database"] = { res["database"] = {
"database": DATABASE_TYPE.lower(), "database": settings.DATABASE_TYPE.lower(),
"status": "red", "status": "red",
"elapsed": "{:.1f}".format((timer() - st) * 1000.0), "elapsed": "{:.1f}".format((timer() - st) * 1000.0),
"error": str(e), "error": str(e),

View File

@ -38,20 +38,7 @@ from api.utils import (
datetime_format, datetime_format,
) )
from api.db import UserTenantRole, FileType from api.db import UserTenantRole, FileType
from api.settings import ( from api import settings
RetCode,
GITHUB_OAUTH,
FEISHU_OAUTH,
CHAT_MDL,
EMBEDDING_MDL,
ASR_MDL,
IMAGE2TEXT_MDL,
PARSERS,
API_KEY,
LLM_FACTORY,
LLM_BASE_URL,
RERANK_MDL,
)
from api.db.services.user_service import UserService, TenantService, UserTenantService from api.db.services.user_service import UserService, TenantService, UserTenantService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result, construct_response from api.utils.api_utils import get_json_result, construct_response
@ -90,7 +77,7 @@ def login():
""" """
if not request.json: if not request.json:
return get_json_result( return get_json_result(
data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!" data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
) )
email = request.json.get("email", "") email = request.json.get("email", "")
@ -98,7 +85,7 @@ def login():
if not users: if not users:
return get_json_result( return get_json_result(
data=False, data=False,
code=RetCode.AUTHENTICATION_ERROR, code=settings.RetCode.AUTHENTICATION_ERROR,
message=f"Email: {email} is not registered!", message=f"Email: {email} is not registered!",
) )
@ -107,7 +94,7 @@ def login():
password = decrypt(password) password = decrypt(password)
except BaseException: except BaseException:
return get_json_result( return get_json_result(
data=False, code=RetCode.SERVER_ERROR, message="Fail to crypt password" data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password"
) )
user = UserService.query_user(email, password) user = UserService.query_user(email, password)
@ -123,7 +110,7 @@ def login():
else: else:
return get_json_result( return get_json_result(
data=False, data=False,
code=RetCode.AUTHENTICATION_ERROR, code=settings.RetCode.AUTHENTICATION_ERROR,
message="Email and password do not match!", message="Email and password do not match!",
) )
@ -150,10 +137,10 @@ def github_callback():
import requests import requests
res = requests.post( res = requests.post(
GITHUB_OAUTH.get("url"), settings.GITHUB_OAUTH.get("url"),
data={ data={
"client_id": GITHUB_OAUTH.get("client_id"), "client_id": settings.GITHUB_OAUTH.get("client_id"),
"client_secret": GITHUB_OAUTH.get("secret_key"), "client_secret": settings.GITHUB_OAUTH.get("secret_key"),
"code": request.args.get("code"), "code": request.args.get("code"),
}, },
headers={"Accept": "application/json"}, headers={"Accept": "application/json"},
@ -235,11 +222,11 @@ def feishu_callback():
import requests import requests
app_access_token_res = requests.post( app_access_token_res = requests.post(
FEISHU_OAUTH.get("app_access_token_url"), settings.FEISHU_OAUTH.get("app_access_token_url"),
data=json.dumps( data=json.dumps(
{ {
"app_id": FEISHU_OAUTH.get("app_id"), "app_id": settings.FEISHU_OAUTH.get("app_id"),
"app_secret": FEISHU_OAUTH.get("app_secret"), "app_secret": settings.FEISHU_OAUTH.get("app_secret"),
} }
), ),
headers={"Content-Type": "application/json; charset=utf-8"}, headers={"Content-Type": "application/json; charset=utf-8"},
@ -249,10 +236,10 @@ def feishu_callback():
return redirect("/?error=%s" % app_access_token_res) return redirect("/?error=%s" % app_access_token_res)
res = requests.post( res = requests.post(
FEISHU_OAUTH.get("user_access_token_url"), settings.FEISHU_OAUTH.get("user_access_token_url"),
data=json.dumps( data=json.dumps(
{ {
"grant_type": FEISHU_OAUTH.get("grant_type"), "grant_type": settings.FEISHU_OAUTH.get("grant_type"),
"code": request.args.get("code"), "code": request.args.get("code"),
} }
), ),
@ -409,7 +396,7 @@ def setting_user():
): ):
return get_json_result( return get_json_result(
data=False, data=False,
code=RetCode.AUTHENTICATION_ERROR, code=settings.RetCode.AUTHENTICATION_ERROR,
message="Password error!", message="Password error!",
) )
@ -438,7 +425,7 @@ def setting_user():
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)
return get_json_result( return get_json_result(
data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR
) )
@ -497,12 +484,12 @@ def user_register(user_id, user):
tenant = { tenant = {
"id": user_id, "id": user_id,
"name": user["nickname"] + "s Kingdom", "name": user["nickname"] + "s Kingdom",
"llm_id": CHAT_MDL, "llm_id": settings.CHAT_MDL,
"embd_id": EMBEDDING_MDL, "embd_id": settings.EMBEDDING_MDL,
"asr_id": ASR_MDL, "asr_id": settings.ASR_MDL,
"parser_ids": PARSERS, "parser_ids": settings.PARSERS,
"img2txt_id": IMAGE2TEXT_MDL, "img2txt_id": settings.IMAGE2TEXT_MDL,
"rerank_id": RERANK_MDL, "rerank_id": settings.RERANK_MDL,
} }
usr_tenant = { usr_tenant = {
"tenant_id": user_id, "tenant_id": user_id,
@ -522,15 +509,15 @@ def user_register(user_id, user):
"location": "", "location": "",
} }
tenant_llm = [] tenant_llm = []
for llm in LLMService.query(fid=LLM_FACTORY): for llm in LLMService.query(fid=settings.LLM_FACTORY):
tenant_llm.append( tenant_llm.append(
{ {
"tenant_id": user_id, "tenant_id": user_id,
"llm_factory": LLM_FACTORY, "llm_factory": settings.LLM_FACTORY,
"llm_name": llm.llm_name, "llm_name": llm.llm_name,
"model_type": llm.model_type, "model_type": llm.model_type,
"api_key": API_KEY, "api_key": settings.API_KEY,
"api_base": LLM_BASE_URL, "api_base": settings.LLM_BASE_URL,
} }
) )
@ -582,7 +569,7 @@ def user_add():
return get_json_result( return get_json_result(
data=False, data=False,
message=f"Invalid email address: {email_address}!", message=f"Invalid email address: {email_address}!",
code=RetCode.OPERATING_ERROR, code=settings.RetCode.OPERATING_ERROR,
) )
# Check if the email address is already used # Check if the email address is already used
@ -590,7 +577,7 @@ def user_add():
return get_json_result( return get_json_result(
data=False, data=False,
message=f"Email: {email_address} has already registered!", message=f"Email: {email_address} has already registered!",
code=RetCode.OPERATING_ERROR, code=settings.RetCode.OPERATING_ERROR,
) )
# Construct user info data # Construct user info data
@ -625,7 +612,7 @@ def user_add():
return get_json_result( return get_json_result(
data=False, data=False,
message=f"User registration failure, error: {str(e)}", message=f"User registration failure, error: {str(e)}",
code=RetCode.EXCEPTION_ERROR, code=settings.RetCode.EXCEPTION_ERROR,
) )

View File

@ -31,7 +31,7 @@ from peewee import (
) )
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
from api.db import SerializedType, ParserType from api.db import SerializedType, ParserType
from api.settings import DATABASE, SECRET_KEY, DATABASE_TYPE from api import settings
from api import utils from api import utils
def singleton(cls, *args, **kw): def singleton(cls, *args, **kw):
@ -62,7 +62,7 @@ class TextFieldType(Enum):
class LongTextField(TextField): class LongTextField(TextField):
field_type = TextFieldType[DATABASE_TYPE.upper()].value field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value
class JSONField(LongTextField): class JSONField(LongTextField):
@ -282,9 +282,9 @@ class DatabaseMigrator(Enum):
@singleton @singleton
class BaseDataBase: class BaseDataBase:
def __init__(self): def __init__(self):
database_config = DATABASE.copy() database_config = settings.DATABASE.copy()
db_name = database_config.pop("name") db_name = database_config.pop("name")
self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config) self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
logging.info('init database on cluster mode successfully') logging.info('init database on cluster mode successfully')
class PostgresDatabaseLock: class PostgresDatabaseLock:
@ -385,7 +385,7 @@ class DatabaseLock(Enum):
DB = BaseDataBase().database_connection DB = BaseDataBase().database_connection
DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value
def close_connection(): def close_connection():
@ -476,7 +476,7 @@ class User(DataBaseModel, UserMixin):
return self.email return self.email
def get_id(self): def get_id(self):
jwt = Serializer(secret_key=SECRET_KEY) jwt = Serializer(secret_key=settings.SECRET_KEY)
return jwt.dumps(str(self.access_token)) return jwt.dumps(str(self.access_token))
class Meta: class Meta:
@ -977,7 +977,7 @@ class CanvasTemplate(DataBaseModel):
def migrate_db(): def migrate_db():
with DB.transaction(): with DB.transaction():
migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB) migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
try: try:
migrate( migrate(
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="", migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",

View File

@ -29,7 +29,7 @@ from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY, LLM_BASE_URL from api import settings
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
@ -51,11 +51,11 @@ def init_superuser():
tenant = { tenant = {
"id": user_info["id"], "id": user_info["id"],
"name": user_info["nickname"] + "s Kingdom", "name": user_info["nickname"] + "s Kingdom",
"llm_id": CHAT_MDL, "llm_id": settings.CHAT_MDL,
"embd_id": EMBEDDING_MDL, "embd_id": settings.EMBEDDING_MDL,
"asr_id": ASR_MDL, "asr_id": settings.ASR_MDL,
"parser_ids": PARSERS, "parser_ids": settings.PARSERS,
"img2txt_id": IMAGE2TEXT_MDL "img2txt_id": settings.IMAGE2TEXT_MDL
} }
usr_tenant = { usr_tenant = {
"tenant_id": user_info["id"], "tenant_id": user_info["id"],
@ -64,10 +64,11 @@ def init_superuser():
"role": UserTenantRole.OWNER "role": UserTenantRole.OWNER
} }
tenant_llm = [] tenant_llm = []
for llm in LLMService.query(fid=LLM_FACTORY): for llm in LLMService.query(fid=settings.LLM_FACTORY):
tenant_llm.append( tenant_llm.append(
{"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type, {"tenant_id": user_info["id"], "llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name,
"api_key": API_KEY, "api_base": LLM_BASE_URL}) "model_type": llm.model_type,
"api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL})
if not UserService.save(**user_info): if not UserService.save(**user_info):
logging.error("can't init admin.") logging.error("can't init admin.")

View File

@ -27,7 +27,7 @@ from api.db.db_models import Dialog, Conversation,DB
from api.db.services.common_service import CommonService from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.settings import retrievaler, kg_retrievaler from api import settings
from rag.app.resume import forbidden_select_fields4resume from rag.app.resume import forbidden_select_fields4resume
from rag.nlp.search import index_name from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder from rag.utils import rmSpace, num_tokens_from_string, encoder
@ -152,7 +152,7 @@ def chat(dialog, messages, stream=True, **kwargs):
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []} return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
retr = retrievaler if not is_kg else kg_retrievaler retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
questions = [m["content"] for m in messages if m["role"] == "user"][-3:] questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
@ -342,7 +342,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
logging.debug(f"{question} get SQL(refined): {sql}") logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1 tried_times += 1
return retrievaler.sql_retrieval(sql, format="json"), sql return settings.retrievaler.sql_retrieval(sql, format="json"), sql
tbl, sql = get_table() tbl, sql = get_table()
if tbl is None: if tbl is None:
@ -596,7 +596,7 @@ def ask(question, kb_ids, tenant_id):
embd_nms = list(set([kb.embd_id for kb in kbs])) embd_nms = list(set([kb.embd_id for kb in kbs]))
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs]) is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
retr = retrievaler if not is_kg else kg_retrievaler retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0]) embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)

View File

@ -26,7 +26,7 @@ from io import BytesIO
from peewee import fn from peewee import fn
from api.db.db_utils import bulk_insert_into_db from api.db.db_utils import bulk_insert_into_db
from api.settings import docStoreConn from api import settings
from api.utils import current_timestamp, get_format_time, get_uuid from api.utils import current_timestamp, get_format_time, get_uuid
from graphrag.mind_map_extractor import MindMapExtractor from graphrag.mind_map_extractor import MindMapExtractor
from rag.settings import SVR_QUEUE_NAME from rag.settings import SVR_QUEUE_NAME
@ -108,7 +108,7 @@ class DocumentService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def remove_document(cls, doc, tenant_id): def remove_document(cls, doc, tenant_id):
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
cls.clear_chunk_num(doc.id) cls.clear_chunk_num(doc.id)
return cls.delete_by_id(doc.id) return cls.delete_by_id(doc.id)
@ -553,10 +553,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
d["q_%d_vec" % len(v)] = v d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size): for b in range(0, len(cks), es_bulk_size):
if try_create_idx: if try_create_idx:
if not docStoreConn.indexExist(idxnm, kb_id): if not settings.docStoreConn.indexExist(idxnm, kb_id):
docStoreConn.createIdx(idxnm, kb_id, len(vects[0])) settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
try_create_idx = False try_create_idx = False
docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
DocumentService.increment_chunk_num( DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

View File

@ -33,12 +33,10 @@ import traceback
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from werkzeug.serving import run_simple from werkzeug.serving import run_simple
from api import settings
from api.apps import app from api.apps import app
from api.db.runtime_config import RuntimeConfig from api.db.runtime_config import RuntimeConfig
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.settings import (
HOST, HTTP_PORT
)
from api import utils from api import utils
from api.db.db_models import init_database_tables as init_web_db from api.db.db_models import init_database_tables as init_web_db
@ -72,6 +70,7 @@ if __name__ == '__main__':
f'project base: {utils.file_utils.get_project_base_directory()}' f'project base: {utils.file_utils.get_project_base_directory()}'
) )
show_configs() show_configs()
settings.init_settings()
# init db # init db
init_web_db() init_web_db()
@ -96,7 +95,7 @@ if __name__ == '__main__':
logging.info("run on debug mode") logging.info("run on debug mode")
RuntimeConfig.init_env() RuntimeConfig.init_env()
RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT) RuntimeConfig.init_config(JOB_SERVER_HOST=settings.HOST_IP, HTTP_PORT=settings.HOST_PORT)
thread = ThreadPoolExecutor(max_workers=1) thread = ThreadPoolExecutor(max_workers=1)
thread.submit(update_progress) thread.submit(update_progress)
@ -105,8 +104,8 @@ if __name__ == '__main__':
try: try:
logging.info("RAGFlow HTTP server start...") logging.info("RAGFlow HTTP server start...")
run_simple( run_simple(
hostname=HOST, hostname=settings.HOST_IP,
port=HTTP_PORT, port=settings.HOST_PORT,
application=app, application=app,
threaded=True, threaded=True,
use_reloader=RuntimeConfig.DEBUG, use_reloader=RuntimeConfig.DEBUG,

View File

@ -30,12 +30,46 @@ LIGHTEN = int(os.environ.get('LIGHTEN', "0"))
REQUEST_WAIT_SEC = 2 REQUEST_WAIT_SEC = 2
REQUEST_MAX_WAIT_SEC = 300 REQUEST_MAX_WAIT_SEC = 300
LLM = None
LLM_FACTORY = None
LLM_BASE_URL = None
CHAT_MDL = ""
EMBEDDING_MDL = ""
RERANK_MDL = ""
ASR_MDL = ""
IMAGE2TEXT_MDL = ""
API_KEY = None
PARSERS = None
HOST_IP = None
HOST_PORT = None
SECRET_KEY = None
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
# authentication
AUTHENTICATION_CONF = None
# client
CLIENT_AUTHENTICATION = None
HTTP_APP_KEY = None
GITHUB_OAUTH = None
FEISHU_OAUTH = None
DOC_ENGINE = None
docStoreConn = None
retrievaler = None
kg_retrievaler = None
def init_settings():
global LLM, LLM_FACTORY, LLM_BASE_URL
LLM = get_base_config("user_default_llm", {}) LLM = get_base_config("user_default_llm", {})
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen") LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
LLM_BASE_URL = LLM.get("base_url") LLM_BASE_URL = LLM.get("base_url")
CHAT_MDL = EMBEDDING_MDL = RERANK_MDL = ASR_MDL = IMAGE2TEXT_MDL = "" global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
if not LIGHTEN: if not LIGHTEN:
default_llm = { default_llm = {
"Tongyi-Qianwen": { "Tongyi-Qianwen": {
@ -102,21 +136,20 @@ if not LIGHTEN:
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI" EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI"
RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI" RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI"
global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
API_KEY = LLM.get("api_key", "") API_KEY = LLM.get("api_key", "")
PARSERS = LLM.get( PARSERS = LLM.get(
"parsers", "parsers",
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email") "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1") HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port") HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
SECRET_KEY = get_base_config( SECRET_KEY = get_base_config(
RAG_FLOW_SERVICE_NAME, RAG_FLOW_SERVICE_NAME,
{}).get("secret_key", str(date.today())) {}).get("secret_key", str(date.today()))
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql') global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
# authentication # authentication
AUTHENTICATION_CONF = get_base_config("authentication", {}) AUTHENTICATION_CONF = get_base_config("authentication", {})
@ -128,6 +161,7 @@ HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
GITHUB_OAUTH = get_base_config("oauth", {}).get("github") GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu") FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler
DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch") DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
if DOC_ENGINE == "elasticsearch": if DOC_ENGINE == "elasticsearch":
docStoreConn = rag.utils.es_conn.ESConnection() docStoreConn = rag.utils.es_conn.ESConnection()
@ -139,6 +173,15 @@ else:
retrievaler = search.Dealer(docStoreConn) retrievaler = search.Dealer(docStoreConn)
kg_retrievaler = kg_search.KGSearch(docStoreConn) kg_retrievaler = kg_search.KGSearch(docStoreConn)
def get_host_ip():
global HOST_IP
return HOST_IP
def get_host_port():
global HOST_PORT
return HOST_PORT
class CustomEnum(Enum): class CustomEnum(Enum):
@classmethod @classmethod

View File

@ -34,11 +34,9 @@ from itsdangerous import URLSafeTimedSerializer
from werkzeug.http import HTTP_STATUS_CODES from werkzeug.http import HTTP_STATUS_CODES
from api.db.db_models import APIToken from api.db.db_models import APIToken
from api.settings import ( from api import settings
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY from api import settings
)
from api.settings import RetCode
from api.utils import CustomJSONEncoder, get_uuid from api.utils import CustomJSONEncoder, get_uuid
from api.utils import json_dumps from api.utils import json_dumps
@ -59,13 +57,13 @@ def request(**kwargs):
{}).items()} {}).items()}
prepped = requests.Request(**kwargs).prepare() prepped = requests.Request(**kwargs).prepare()
if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY: if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
timestamp = str(round(time() * 1000)) timestamp = str(round(time() * 1000))
nonce = str(uuid1()) nonce = str(uuid1())
signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([ signature = b64encode(HMAC(settings.SECRET_KEY.encode('ascii'), b'\n'.join([
timestamp.encode('ascii'), timestamp.encode('ascii'),
nonce.encode('ascii'), nonce.encode('ascii'),
HTTP_APP_KEY.encode('ascii'), settings.HTTP_APP_KEY.encode('ascii'),
prepped.path_url.encode('ascii'), prepped.path_url.encode('ascii'),
prepped.body if kwargs.get('json') else b'', prepped.body if kwargs.get('json') else b'',
urlencode( urlencode(
@ -79,7 +77,7 @@ def request(**kwargs):
prepped.headers.update({ prepped.headers.update({
'TIMESTAMP': timestamp, 'TIMESTAMP': timestamp,
'NONCE': nonce, 'NONCE': nonce,
'APP-KEY': HTTP_APP_KEY, 'APP-KEY': settings.HTTP_APP_KEY,
'SIGNATURE': signature, 'SIGNATURE': signature,
}) })
@ -89,7 +87,7 @@ def request(**kwargs):
def get_exponential_backoff_interval(retries, full_jitter=False): def get_exponential_backoff_interval(retries, full_jitter=False):
"""Calculate the exponential backoff wait time.""" """Calculate the exponential backoff wait time."""
# Will be zero if factor equals 0 # Will be zero if factor equals 0
countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries)) countdown = min(settings.REQUEST_MAX_WAIT_SEC, settings.REQUEST_WAIT_SEC * (2 ** retries))
# Full jitter according to # Full jitter according to
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ # https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
if full_jitter: if full_jitter:
@ -98,7 +96,7 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
return max(0, countdown) return max(0, countdown)
def get_data_error_result(code=RetCode.DATA_ERROR, def get_data_error_result(code=settings.RetCode.DATA_ERROR,
message='Sorry! Data missing!'): message='Sorry! Data missing!'):
import re import re
result_dict = { result_dict = {
@ -126,8 +124,8 @@ def server_error_response(e):
pass pass
if len(e.args) > 1: if len(e.args) > 1:
return get_json_result( return get_json_result(
code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
def error_response(response_code, message=None): def error_response(response_code, message=None):
@ -168,7 +166,7 @@ def validate_request(*args, **kwargs):
error_string += "required argument values: {}".format( error_string += "required argument values: {}".format(
",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) ",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
return get_json_result( return get_json_result(
code=RetCode.ARGUMENT_ERROR, message=error_string) code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
return func(*_args, **_kwargs) return func(*_args, **_kwargs)
return decorated_function return decorated_function
@ -193,7 +191,7 @@ def send_file_in_mem(data, filename):
return send_file(f, as_attachment=True, attachment_filename=filename) return send_file(f, as_attachment=True, attachment_filename=filename)
def get_json_result(code=RetCode.SUCCESS, message='success', data=None): def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
response = {"code": code, "message": message, "data": data} response = {"code": code, "message": message, "data": data}
return jsonify(response) return jsonify(response)
@ -204,7 +202,7 @@ def apikey_required(func):
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return build_error_result( return build_error_result(
message='API-KEY is invalid!', code=RetCode.FORBIDDEN message='API-KEY is invalid!', code=settings.RetCode.FORBIDDEN
) )
kwargs['tenant_id'] = objs[0].tenant_id kwargs['tenant_id'] = objs[0].tenant_id
return func(*args, **kwargs) return func(*args, **kwargs)
@ -212,14 +210,14 @@ def apikey_required(func):
return decorated_function return decorated_function
def build_error_result(code=RetCode.FORBIDDEN, message='success'): def build_error_result(code=settings.RetCode.FORBIDDEN, message='success'):
response = {"code": code, "message": message} response = {"code": code, "message": message}
response = jsonify(response) response = jsonify(response)
response.status_code = code response.status_code = code
return response return response
def construct_response(code=RetCode.SUCCESS, def construct_response(code=settings.RetCode.SUCCESS,
message='success', data=None, auth=None): message='success', data=None, auth=None):
result_dict = {"code": code, "message": message, "data": data} result_dict = {"code": code, "message": message, "data": data}
response_dict = {} response_dict = {}
@ -239,7 +237,7 @@ def construct_response(code=RetCode.SUCCESS,
return response return response
def construct_result(code=RetCode.DATA_ERROR, message='data is missing'): def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'):
import re import re
result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)} result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
response = {} response = {}
@ -251,7 +249,7 @@ def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
return jsonify(response) return jsonify(response)
def construct_json_result(code=RetCode.SUCCESS, message='success', data=None): def construct_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
if data is None: if data is None:
return jsonify({"code": code, "message": message}) return jsonify({"code": code, "message": message})
else: else:
@ -262,12 +260,12 @@ def construct_error_response(e):
logging.exception(e) logging.exception(e)
try: try:
if e.code == 401: if e.code == 401:
return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e)) return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e))
except BaseException: except BaseException:
pass pass
if len(e.args) > 1: if len(e.args) > 1:
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1]) return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e)) return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
def token_required(func): def token_required(func):
@ -280,7 +278,7 @@ def token_required(func):
objs = APIToken.query(token=token) objs = APIToken.query(token=token)
if not objs: if not objs:
return get_json_result( return get_json_result(
data=False, message='Token is not valid!', code=RetCode.AUTHENTICATION_ERROR data=False, message='Token is not valid!', code=settings.RetCode.AUTHENTICATION_ERROR
) )
kwargs['tenant_id'] = objs[0].tenant_id kwargs['tenant_id'] = objs[0].tenant_id
return func(*args, **kwargs) return func(*args, **kwargs)
@ -288,7 +286,7 @@ def token_required(func):
return decorated_function return decorated_function
def get_result(code=RetCode.SUCCESS, message="", data=None): def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
if code == 0: if code == 0:
if data is not None: if data is not None:
response = {"code": code, "data": data} response = {"code": code, "data": data}
@ -299,7 +297,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None):
return jsonify(response) return jsonify(response)
def get_error_data_result(message='Sorry! Data missing!', code=RetCode.DATA_ERROR, def get_error_data_result(message='Sorry! Data missing!', code=settings.RetCode.DATA_ERROR,
): ):
import re import re
result_dict = { result_dict = {

View File

@ -24,7 +24,7 @@ import numpy as np
from timeit import default_timer as timer from timeit import default_timer as timer
from pypdf import PdfReader as pdf2_read from pypdf import PdfReader as pdf2_read
from api.settings import LIGHTEN from api import settings
from api.utils.file_utils import get_project_base_directory from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
from rag.nlp import rag_tokenizer from rag.nlp import rag_tokenizer
@ -41,7 +41,7 @@ class RAGFlowPdfParser:
self.tbl_det = TableStructureRecognizer() self.tbl_det = TableStructureRecognizer()
self.updown_cnt_mdl = xgb.Booster() self.updown_cnt_mdl = xgb.Booster()
if not LIGHTEN: if not settings.LIGHTEN:
try: try:
import torch import torch
if torch.cuda.is_available(): if torch.cuda.is_available():

View File

@ -252,13 +252,13 @@ if __name__ == "__main__":
from api.db import LLMType from api.db import LLMType
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler from api import settings
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id) kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])] docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
info = { info = {
"input_text": docs, "input_text": docs,
"entity_specs": "organization, person", "entity_specs": "organization, person",

View File

@ -30,14 +30,14 @@ if __name__ == "__main__":
from api.db import LLMType from api.db import LLMType
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler from api import settings
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id) kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT)) ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in docs = [d["content_with_weight"] for d in
retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])] settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
graph = ex(docs) graph = ex(docs)
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT)) er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))

View File

@ -23,7 +23,7 @@ from collections import defaultdict
from api.db import LLMType from api.db import LLMType
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import retrievaler, docStoreConn from api import settings
from api.utils import get_uuid from api.utils import get_uuid
from rag.nlp import tokenize, search from rag.nlp import tokenize, search
from ranx import evaluate from ranx import evaluate
@ -52,7 +52,7 @@ class Benchmark:
run = defaultdict(dict) run = defaultdict(dict)
query_list = list(qrels.keys()) query_list = list(qrels.keys())
for query in query_list: for query in query_list:
ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30, ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
0.0, self.vector_similarity_weight) 0.0, self.vector_similarity_weight)
if len(ranks["chunks"]) == 0: if len(ranks["chunks"]) == 0:
print(f"deleted query: {query}") print(f"deleted query: {query}")
@ -81,9 +81,9 @@ class Benchmark:
def init_index(self, vector_size: int): def init_index(self, vector_size: int):
if self.initialized_index: if self.initialized_index:
return return
if docStoreConn.indexExist(self.index_name, self.kb_id): if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
docStoreConn.deleteIdx(self.index_name, self.kb_id) settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
docStoreConn.createIdx(self.index_name, self.kb_id, vector_size) settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
self.initialized_index = True self.initialized_index = True
def ms_marco_index(self, file_path, index_name): def ms_marco_index(self, file_path, index_name):
@ -118,13 +118,13 @@ class Benchmark:
docs_count += len(docs) docs_count += len(docs)
docs, vector_size = self.embedding(docs) docs, vector_size = self.embedding(docs)
self.init_index(vector_size) self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name, self.kb_id) settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
docs = [] docs = []
if docs: if docs:
docs, vector_size = self.embedding(docs) docs, vector_size = self.embedding(docs)
self.init_index(vector_size) self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name, self.kb_id) settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
return qrels, texts return qrels, texts
def trivia_qa_index(self, file_path, index_name): def trivia_qa_index(self, file_path, index_name):
@ -159,12 +159,12 @@ class Benchmark:
docs_count += len(docs) docs_count += len(docs)
docs, vector_size = self.embedding(docs) docs, vector_size = self.embedding(docs)
self.init_index(vector_size) self.init_index(vector_size)
docStoreConn.insert(docs,self.index_name) settings.docStoreConn.insert(docs,self.index_name)
docs = [] docs = []
docs, vector_size = self.embedding(docs) docs, vector_size = self.embedding(docs)
self.init_index(vector_size) self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name) settings.docStoreConn.insert(docs, self.index_name)
return qrels, texts return qrels, texts
def miracl_index(self, file_path, corpus_path, index_name): def miracl_index(self, file_path, corpus_path, index_name):
@ -214,12 +214,12 @@ class Benchmark:
docs_count += len(docs) docs_count += len(docs)
docs, vector_size = self.embedding(docs) docs, vector_size = self.embedding(docs)
self.init_index(vector_size) self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name) settings.docStoreConn.insert(docs, self.index_name)
docs = [] docs = []
docs, vector_size = self.embedding(docs) docs, vector_size = self.embedding(docs)
self.init_index(vector_size) self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name) settings.docStoreConn.insert(docs, self.index_name)
return qrels, texts return qrels, texts
def save_results(self, qrels, run, texts, dataset, file_path): def save_results(self, qrels, run, texts, dataset, file_path):

View File

@ -28,7 +28,7 @@ from openai import OpenAI
import numpy as np import numpy as np
import asyncio import asyncio
from api.settings import LIGHTEN from api import settings
from api.utils.file_utils import get_home_cache_dir from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate from rag.utils import num_tokens_from_string, truncate
import google.generativeai as genai import google.generativeai as genai
@ -60,7 +60,7 @@ class DefaultEmbedding(Base):
^_- ^_-
""" """
if not LIGHTEN and not DefaultEmbedding._model: if not settings.LIGHTEN and not DefaultEmbedding._model:
with DefaultEmbedding._model_lock: with DefaultEmbedding._model_lock:
from FlagEmbedding import FlagModel from FlagEmbedding import FlagModel
import torch import torch
@ -248,7 +248,7 @@ class FastEmbed(Base):
threads: Optional[int] = None, threads: Optional[int] = None,
**kwargs, **kwargs,
): ):
if not LIGHTEN and not FastEmbed._model: if not settings.LIGHTEN and not FastEmbed._model:
from fastembed import TextEmbedding from fastembed import TextEmbedding
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs) self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
@ -294,7 +294,7 @@ class YoudaoEmbed(Base):
_client = None _client = None
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs): def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
if not LIGHTEN and not YoudaoEmbed._client: if not settings.LIGHTEN and not YoudaoEmbed._client:
from BCEmbedding import EmbeddingModel as qanthing from BCEmbedding import EmbeddingModel as qanthing
try: try:
logging.info("LOADING BCE...") logging.info("LOADING BCE...")

View File

@ -23,7 +23,7 @@ import os
from abc import ABC from abc import ABC
import numpy as np import numpy as np
from api.settings import LIGHTEN from api import settings
from api.utils.file_utils import get_home_cache_dir from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate from rag.utils import num_tokens_from_string, truncate
import json import json
@ -57,7 +57,7 @@ class DefaultRerank(Base):
^_- ^_-
""" """
if not LIGHTEN and not DefaultRerank._model: if not settings.LIGHTEN and not DefaultRerank._model:
import torch import torch
from FlagEmbedding import FlagReranker from FlagEmbedding import FlagReranker
with DefaultRerank._model_lock: with DefaultRerank._model_lock:
@ -121,7 +121,7 @@ class YoudaoRerank(DefaultRerank):
_model_lock = threading.Lock() _model_lock = threading.Lock()
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs): def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
if not LIGHTEN and not YoudaoRerank._model: if not settings.LIGHTEN and not YoudaoRerank._model:
from BCEmbedding import RerankerModel from BCEmbedding import RerankerModel
with YoudaoRerank._model_lock: with YoudaoRerank._model_lock:
if not YoudaoRerank._model: if not YoudaoRerank._model:

View File

@ -16,6 +16,7 @@
import logging import logging
import sys import sys
from api.utils.log_utils import initRootLogger from api.utils.log_utils import initRootLogger
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1] CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
initRootLogger(f"task_executor_{CONSUMER_NO}") initRootLogger(f"task_executor_{CONSUMER_NO}")
for module in ["pdfminer"]: for module in ["pdfminer"]:
@ -49,9 +50,10 @@ from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService from api.db.services.task_service import TaskService
from api.db.services.file2document_service import File2DocumentService from api.db.services.file2document_service import File2DocumentService
from api.settings import retrievaler, docStoreConn from api import settings
from api.db.db_models import close_connection from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
knowledge_graph, email
from rag.nlp import search, rag_tokenizer from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
@ -88,6 +90,7 @@ PENDING_TASKS = 0
HEAD_CREATED_AT = "" HEAD_CREATED_AT = ""
HEAD_DETAIL = "" HEAD_DETAIL = ""
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."): def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
global PAYLOAD global PAYLOAD
if prog is not None and prog < 0: if prog is not None and prog < 0:
@ -171,7 +174,8 @@ def build(row):
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"])) "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
except TimeoutError: except TimeoutError:
callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.") callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"])) logging.exception(
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
return return
except Exception as e: except Exception as e:
if re.search("(No such file|not found)", str(e)): if re.search("(No such file|not found)", str(e)):
@ -226,7 +230,8 @@ def build(row):
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue()) STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
el += timer() - st el += timer() - st
except Exception: except Exception:
logging.exception("Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"])) logging.exception(
"Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
d["img_id"] = "{}-{}".format(row["kb_id"], d["id"]) d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
del d["image"] del d["image"]
@ -262,7 +267,7 @@ def build(row):
def init_kb(row, vector_size: int): def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"]) idxnm = search.index_name(row["tenant_id"])
return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size) return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
def embedding(docs, mdl, parser_config=None, callback=None): def embedding(docs, mdl, parser_config=None, callback=None):
@ -313,7 +318,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
vector_size = len(vts[0]) vector_size = len(vts[0])
vctr_nm = "q_%d_vec" % vector_size vctr_nm = "q_%d_vec" % vector_size
chunks = [] chunks = []
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]): for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm]):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm]))) chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
raptor = Raptor( raptor = Raptor(
@ -384,7 +390,8 @@ def main():
# TODO: exception handler # TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ") ## set_progress(r["did"], -1, "ERROR: ")
callback( callback(
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), timer() - st) msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks),
timer() - st)
) )
st = timer() st = timer()
try: try:
@ -403,18 +410,18 @@ def main():
es_r = "" es_r = ""
es_bulk_size = 4 es_bulk_size = 4
for b in range(0, len(cks), es_bulk_size): for b in range(0, len(cks), es_bulk_size):
es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"]) es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
if b % 128 == 0: if b % 128 == 0:
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="") callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st)) logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
if es_r: if es_r:
callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!") callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
logging.error('Insert chunk error: ' + str(es_r)) logging.error('Insert chunk error: ' + str(es_r))
else: else:
if TaskService.do_cancel(r["id"]): if TaskService.do_cancel(r["id"]):
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
continue continue
callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st)) callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
callback(1., "Done!") callback(1., "Done!")