Move settings initialization after module init phase (#3438)

### What problem does this PR solve?

1. Module initialization no longer connects to the database.
2. Config values defined in settings must now be accessed via the module, e.g. settings.CONFIG_NAME, instead of being imported individually.

### Type of change

- [x] Refactoring

Signed-off-by: jinhai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai 2024-11-15 17:30:56 +08:00 committed by GitHub
parent ac033b62cf
commit 1e90a1bf36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 452 additions and 411 deletions

View File

@ -19,7 +19,7 @@ import pandas as pd
from api.db import LLMType
from api.db.services.dialog_service import message_fit_in
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from api import settings
from agent.component.base import ComponentBase, ComponentParamBase
@ -63,14 +63,16 @@ class Generate(ComponentBase):
component_name = "Generate"
def get_dependent_components(self):
cpnts = [para["component_id"] for para in self._param.parameters if para.get("component_id") and para["component_id"].lower().find("answer") < 0]
cpnts = [para["component_id"] for para in self._param.parameters if
para.get("component_id") and para["component_id"].lower().find("answer") < 0]
return cpnts
def set_cite(self, retrieval_res, answer):
retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
if "empty_response" in retrieval_res.columns:
retrieval_res["empty_response"].fillna("", inplace=True)
answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
answer, idx = settings.retrievaler.insert_citations(answer,
[ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
[ck["vector"] for _, ck in retrieval_res.iterrows()],
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
self._canvas.get_embedding_model()), tkweight=0.7,
@ -127,12 +129,14 @@ class Generate(ComponentBase):
else:
if cpn.component_name.lower() == "retrieval":
retrieval_res.append(out)
kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
kwargs[para["key"]] = " - " + "\n - ".join(
[o if isinstance(o, str) else str(o) for o in out["content"]])
self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
if retrieval_res:
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
else: retrieval_res = pd.DataFrame([])
else:
retrieval_res = pd.DataFrame([])
for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)

View File

@ -21,7 +21,7 @@ import pandas as pd
from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from api import settings
from agent.component.base import ComponentBase, ComponentParamBase
@ -67,7 +67,7 @@ class Retrieval(ComponentBase, ABC):
if self._param.rerank_id:
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
kbinfos = retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
1, self._param.top_n,
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
aggs=False, rerank_mdl=rerank_mdl)

View File

@ -30,8 +30,7 @@ from api.utils import CustomJSONEncoder, commands
from flask_session import Session
from flask_login import LoginManager
from api.settings import SECRET_KEY
from api.settings import API_VERSION
from api import settings
from api.utils.api_utils import server_error_response
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
@ -78,7 +77,6 @@ app.url_map.strict_slashes = False
app.json_encoder = CustomJSONEncoder
app.errorhandler(Exception)(server_error_response)
## convenience for dev and debug
# app.config["LOGIN_DISABLED"] = True
app.config["SESSION_PERMANENT"] = False
@ -121,7 +119,7 @@ def register_page(page_path):
spec.loader.exec_module(page)
page_name = getattr(page, "page_name", page_name)
url_prefix = (
f"/api/{API_VERSION}" if "/sdk/" in path else f"/{API_VERSION}/{page_name}"
f"/api/{settings.API_VERSION}" if "/sdk/" in path else f"/{settings.API_VERSION}/{page_name}"
)
app.register_blueprint(page.manager, url_prefix=url_prefix)
@ -141,7 +139,7 @@ client_urls_prefix = [
@login_manager.request_loader
def load_user(web_request):
jwt = Serializer(secret_key=SECRET_KEY)
jwt = Serializer(secret_key=settings.SECRET_KEY)
authorization = web_request.headers.get("Authorization")
if authorization:
try:

View File

@ -32,7 +32,7 @@ from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import queue_tasks, TaskService
from api.db.services.user_service import UserTenantService
from api.settings import RetCode, retrievaler
from api import settings
from api.utils import get_uuid, current_timestamp, datetime_format
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
generate_confirmation_token
@ -141,7 +141,7 @@ def set_conversation():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
try:
if objs[0].source == "agent":
@ -183,7 +183,7 @@ def completion():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
@ -347,7 +347,7 @@ def get(conversation_id):
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
try:
e, conv = API4ConversationService.get_by_id(conversation_id)
@ -357,7 +357,7 @@ def get(conversation_id):
conv = conv.to_dict()
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
return get_json_result(data=False, message='Token is not valid for this conversation_id!"',
code=RetCode.AUTHENTICATION_ERROR)
code=settings.RetCode.AUTHENTICATION_ERROR)
for referenct_i in conv['reference']:
if referenct_i is None or len(referenct_i) == 0:
@ -378,7 +378,7 @@ def upload():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
kb_name = request.form.get("kb_name").strip()
tenant_id = objs[0].tenant_id
@ -394,12 +394,12 @@ def upload():
if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file = request.files['file']
if file.filename == '':
return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
root_folder = FileService.get_root_folder(tenant_id)
pf_id = root_folder["id"]
@ -490,17 +490,17 @@ def upload_parse():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
return get_json_result(data=doc_ids)
@ -513,7 +513,7 @@ def list_chunks():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
@ -531,7 +531,7 @@ def list_chunks():
)
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
res = [
{
"content": res_item["content_with_weight"],
@ -553,7 +553,7 @@ def list_kb_docs():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
tenant_id = objs[0].tenant_id
@ -585,6 +585,7 @@ def list_kb_docs():
except Exception as e:
return server_error_response(e)
@manager.route('/document/infos', methods=['POST'])
@validate_request("doc_ids")
def docinfos():
@ -592,7 +593,7 @@ def docinfos():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
doc_ids = req["doc_ids"]
docs = DocumentService.get_by_ids(doc_ids)
@ -606,7 +607,7 @@ def document_rm():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
tenant_id = objs[0].tenant_id
req = request.json
@ -653,7 +654,7 @@ def document_rm():
errors += str(e)
if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
return get_json_result(data=True)
@ -668,7 +669,7 @@ def completion_faq():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
if not e:
@ -805,7 +806,7 @@ def retrieval():
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
req = request.json
kb_ids = req.get("kb_id", [])
@ -822,7 +823,8 @@ def retrieval():
embd_nms = list(set([kb.embd_id for kb in kbs]))
if len(embd_nms) != 1:
return get_json_result(
data=False, message='Knowledge bases use different embedding models or does not exist."', code=RetCode.AUTHENTICATION_ERROR)
data=False, message='Knowledge bases use different embedding models or does not exist."',
code=settings.RetCode.AUTHENTICATION_ERROR)
embd_mdl = TenantLLMService.model_instance(
kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
@ -833,7 +835,7 @@ def retrieval():
if req.get("keyword", False):
chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl)
for c in ranks["chunks"]:
@ -843,5 +845,5 @@ def retrieval():
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
code=RetCode.DATA_ERROR)
code=settings.RetCode.DATA_ERROR)
return server_error_response(e)

View File

@ -19,7 +19,7 @@ from functools import partial
from flask import request, Response
from flask_login import login_required, current_user
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.settings import RetCode
from api import settings
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from agent.canvas import Canvas
@ -36,7 +36,8 @@ def templates():
@login_required
def canvas_list():
return get_json_result(data=sorted([c.to_dict() for c in \
UserCanvasService.query(user_id=current_user.id)], key=lambda x: x["update_time"]*-1)
UserCanvasService.query(user_id=current_user.id)],
key=lambda x: x["update_time"] * -1)
)
@ -48,7 +49,7 @@ def rm():
if not UserCanvasService.query(user_id=current_user.id, id=i):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
UserCanvasService.delete_by_id(i)
return get_json_result(data=True)
@ -72,7 +73,7 @@ def save():
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
UserCanvasService.update_by_id(req["id"], req)
return get_json_result(data=req)
@ -98,7 +99,7 @@ def run():
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
@ -122,7 +123,8 @@ def run():
assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
if stream:
assert isinstance(answer, partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
assert isinstance(answer,
partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
def sse():
nonlocal answer, cvs
@ -173,7 +175,7 @@ def reset():
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
return get_json_result(
data=False, message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
canvas.reset()

View File

@ -29,11 +29,12 @@ from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import UserTenantService
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db.services.document_service import DocumentService
from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn
from api import settings
from api.utils.api_utils import get_json_result
import hashlib
import re
@manager.route('/list', methods=['POST'])
@login_required
@validate_request("doc_id")
@ -56,7 +57,7 @@ def list_chunk():
}
if "available_int" in req:
query["available_int"] = int(req["available_int"])
sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids:
d = {
@ -78,7 +79,7 @@ def list_chunk():
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found!',
code=RetCode.DATA_ERROR)
code=settings.RetCode.DATA_ERROR)
return server_error_response(e)
@ -93,7 +94,7 @@ def get():
tenant_id = tenants[0].tenant_id
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
if chunk is None:
return server_error_response("Chunk not found")
k = []
@ -107,7 +108,7 @@ def get():
except Exception as e:
if str(e).find("NotFoundError") >= 0:
return get_json_result(data=False, message='Chunk not found!',
code=RetCode.DATA_ERROR)
code=settings.RetCode.DATA_ERROR)
return server_error_response(e)
@ -154,7 +155,7 @@ def set():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -169,7 +170,7 @@ def switch():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
if not settings.docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
search.index_name(doc.tenant_id), doc.kb_id):
return get_data_error_result(message="Index updating failure")
return get_json_result(data=True)
@ -186,7 +187,7 @@ def rm():
e, doc = DocumentService.get_by_id(req["doc_id"])
if not e:
return get_data_error_result(message="Document not found!")
if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
return get_data_error_result(message="Index updating failure")
deleted_chunk_ids = req["chunk_ids"]
chunk_number = len(deleted_chunk_ids)
@ -230,7 +231,7 @@ def create():
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
DocumentService.increment_chunk_num(
doc.id, doc.kb_id, c, 1, 0)
@ -265,7 +266,7 @@ def retrieval_test():
else:
return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
if not e:
@ -281,7 +282,7 @@ def retrieval_test():
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
similarity_threshold, vector_similarity_weight, top,
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
@ -293,7 +294,7 @@ def retrieval_test():
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
code=RetCode.DATA_ERROR)
code=settings.RetCode.DATA_ERROR)
return server_error_response(e)
@ -307,7 +308,7 @@ def knowledge_graph():
"doc_ids": [doc_id],
"knowledge_graph_kwd": ["graph", "mind_map"]
}
sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids)
sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
obj = {"graph": {}, "mind_map": {}}
for id in sres.ids[:2]:
ty = sres.field[id]["knowledge_graph_kwd"]
@ -336,4 +337,3 @@ def knowledge_graph():
obj[ty] = content_json
return get_json_result(data=obj)

View File

@ -25,7 +25,7 @@ from api.db import LLMType
from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
from api.settings import RetCode, retrievaler
from api import settings
from api.utils.api_utils import get_json_result
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from graphrag.mind_map_extractor import MindMapExtractor
@ -87,7 +87,7 @@ def get():
else:
return get_json_result(
data=False, message='Only owner of conversation authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
conv = conv.to_dict()
return get_json_result(data=conv)
except Exception as e:
@ -110,7 +110,7 @@ def rm():
else:
return get_json_result(
data=False, message='Only owner of conversation authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
ConversationService.delete_by_id(cid)
return get_json_result(data=True)
except Exception as e:
@ -125,7 +125,7 @@ def list_convsersation():
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
return get_json_result(
data=False, message='Only owner of dialog authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
convs = ConversationService.query(
dialog_id=dialog_id,
order_by=ConversationService.model.create_time,
@ -297,6 +297,7 @@ def thumbup():
def ask_about():
req = request.json
uid = current_user.id
def stream():
nonlocal req, uid
try:
@ -329,7 +330,7 @@ def mindmap():
embd_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
ranks = retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
0.3, 0.3, aggs=False)
mindmap = MindMapExtractor(chat_mdl)
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output

View File

@ -20,7 +20,7 @@ from api.db.services.dialog_service import DialogService
from api.db import StatusEnum
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import RetCode
from api import settings
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
@ -175,7 +175,7 @@ def rm():
else:
return get_json_result(
data=False, message='Only owner of dialog authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
DialogService.update_many_by_id(dialog_list)
return get_json_result(data=True)

View File

@ -34,7 +34,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
from api.utils import get_uuid
from api.db import FileType, TaskStatus, ParserType, FileSource
from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.settings import RetCode, docStoreConn
from api import settings
from api.utils.api_utils import get_json_result
from rag.utils.storage_factory import STORAGE_IMPL
from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
@ -49,16 +49,16 @@ def upload():
kb_id = request.form.get("kb_id")
if not kb_id:
return get_json_result(
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
@ -67,7 +67,7 @@ def upload():
err, _ = FileService.upload_document(kb, file_objs, current_user.id)
if err:
return get_json_result(
data=False, message="\n".join(err), code=RetCode.SERVER_ERROR)
data=False, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
return get_json_result(data=True)
@ -78,12 +78,12 @@ def web_crawl():
kb_id = request.form.get("kb_id")
if not kb_id:
return get_json_result(
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
name = request.form.get("name")
url = request.form.get("url")
if not is_valid_url(url):
return get_json_result(
data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
raise LookupError("Can't find this knowledgebase!")
@ -145,7 +145,7 @@ def create():
kb_id = req["kb_id"]
if not kb_id:
return get_json_result(
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
try:
e, kb = KnowledgebaseService.get_by_id(kb_id)
@ -179,7 +179,7 @@ def list_docs():
kb_id = request.args.get("kb_id")
if not kb_id:
return get_json_result(
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if KnowledgebaseService.query(
@ -188,7 +188,7 @@ def list_docs():
else:
return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
keywords = request.args.get("keywords", "")
page_number = int(request.args.get("page", 1))
@ -218,7 +218,7 @@ def docinfos():
return get_json_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
code=settings.RetCode.AUTHENTICATION_ERROR
)
docs = DocumentService.get_by_ids(doc_ids)
return get_json_result(data=list(docs.dicts()))
@ -230,7 +230,7 @@ def thumbnails():
doc_ids = request.args.get("doc_ids").split(",")
if not doc_ids:
return get_json_result(
data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)
try:
docs = DocumentService.get_thumbnails(doc_ids)
@ -253,13 +253,13 @@ def change_status():
return get_json_result(
data=False,
message='"Status" must be either 0 or 1!',
code=RetCode.ARGUMENT_ERROR)
code=settings.RetCode.ARGUMENT_ERROR)
if not DocumentService.accessible(req["doc_id"], current_user.id):
return get_json_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR)
code=settings.RetCode.AUTHENTICATION_ERROR)
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
@ -276,7 +276,8 @@ def change_status():
message="Database error (Document update)!")
status = int(req["status"])
docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status},
search.index_name(kb.tenant_id), doc.kb_id)
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
@ -295,7 +296,7 @@ def rm():
return get_json_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
code=settings.RetCode.AUTHENTICATION_ERROR
)
root_folder = FileService.get_root_folder(current_user.id)
@ -326,7 +327,7 @@ def rm():
errors += str(e)
if errors:
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
return get_json_result(data=True)
@ -341,7 +342,7 @@ def run():
return get_json_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
code=settings.RetCode.AUTHENTICATION_ERROR
)
try:
for id in req["doc_ids"]:
@ -358,8 +359,8 @@ def run():
e, doc = DocumentService.get_by_id(id)
if not e:
return get_data_error_result(message="Document not found!")
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
if str(req["run"]) == TaskStatus.RUNNING.value:
TaskService.filter_delete([Task.doc_id == id])
@ -383,7 +384,7 @@ def rename():
return get_json_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
code=settings.RetCode.AUTHENTICATION_ERROR
)
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
@ -394,7 +395,7 @@ def rename():
return get_json_result(
data=False,
message="The extension of file can't be changed",
code=RetCode.ARGUMENT_ERROR)
code=settings.RetCode.ARGUMENT_ERROR)
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
if d.name == req["name"]:
return get_data_error_result(
@ -450,7 +451,7 @@ def change_parser():
return get_json_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
code=settings.RetCode.AUTHENTICATION_ERROR
)
try:
e, doc = DocumentService.get_by_id(req["doc_id"])
@ -483,8 +484,8 @@ def change_parser():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(message="Tenant not found!")
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
return get_json_result(data=True)
except Exception as e:
@ -509,13 +510,13 @@ def get_image(image_id):
def upload_and_parse():
if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
@ -529,7 +530,7 @@ def parse():
if url:
if not is_valid_url(url):
return get_json_result(
data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
download_path = os.path.join(get_project_base_directory(), "logs/downloads")
os.makedirs(download_path, exist_ok=True)
from selenium.webdriver import Chrome, ChromeOptions
@ -553,7 +554,7 @@ def parse():
if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file')
txt = FileService.parse_docs(file_objs, current_user.id)

View File

@ -24,7 +24,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
from api.utils import get_uuid
from api.db import FileType
from api.db.services.document_service import DocumentService
from api.settings import RetCode
from api import settings
from api.utils.api_utils import get_json_result
@ -100,7 +100,7 @@ def rm():
file_ids = req["file_ids"]
if not file_ids:
return get_json_result(
data=False, message='Lack of "Files ID"', code=RetCode.ARGUMENT_ERROR)
data=False, message='Lack of "Files ID"', code=settings.RetCode.ARGUMENT_ERROR)
try:
for file_id in file_ids:
informs = File2DocumentService.get_by_file_id(file_id)

View File

@ -28,7 +28,7 @@ from api.utils import get_uuid
from api.db import FileType, FileSource
from api.db.services import duplicate_name
from api.db.services.file_service import FileService
from api.settings import RetCode
from api import settings
from api.utils.api_utils import get_json_result
from api.utils.file_utils import filename_type
from rag.utils.storage_factory import STORAGE_IMPL
@ -46,13 +46,13 @@ def upload():
if 'file' not in request.files:
return get_json_result(
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
file_objs = request.files.getlist('file')
for file_obj in file_objs:
if file_obj.filename == '':
return get_json_result(
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
file_res = []
try:
for file_obj in file_objs:
@ -134,7 +134,7 @@ def create():
try:
if not FileService.is_parent_folder_exist(pf_id):
return get_json_result(
data=False, message="Parent Folder Doesn't Exist!", code=RetCode.OPERATING_ERROR)
data=False, message="Parent Folder Doesn't Exist!", code=settings.RetCode.OPERATING_ERROR)
if FileService.query(name=req["name"], parent_id=pf_id):
return get_data_error_result(
message="Duplicated folder name in the same folder.")
@ -299,7 +299,7 @@ def rename():
return get_json_result(
data=False,
message="The extension of file can't be changed",
code=RetCode.ARGUMENT_ERROR)
code=settings.RetCode.ARGUMENT_ERROR)
for file in FileService.query(name=req["name"], pf_id=file.parent_id):
if file.name == req["name"]:
return get_data_error_result(

View File

@ -26,9 +26,8 @@ from api.utils import get_uuid
from api.db import StatusEnum, FileSource
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.db_models import File
from api.settings import RetCode
from api.utils.api_utils import get_json_result
from api.settings import docStoreConn
from api import settings
from rag.nlp import search
@ -68,13 +67,13 @@ def update():
return get_json_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
code=settings.RetCode.AUTHENTICATION_ERROR
)
try:
if not KnowledgebaseService.query(
created_by=current_user.id, id=req["kb_id"]):
return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
if not e:
@ -113,7 +112,7 @@ def detail():
else:
return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.',
code=RetCode.OPERATING_ERROR)
code=settings.RetCode.OPERATING_ERROR)
kb = KnowledgebaseService.get_detail(kb_id)
if not kb:
return get_data_error_result(
@ -148,14 +147,14 @@ def rm():
return get_json_result(
data=False,
message='No authorization.',
code=RetCode.AUTHENTICATION_ERROR
code=settings.RetCode.AUTHENTICATION_ERROR
)
try:
kbs = KnowledgebaseService.query(
created_by=current_user.id, id=req["kb_id"])
if not kbs:
return get_json_result(
data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
for doc in DocumentService.query(kb_id=req["kb_id"]):
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
@ -170,7 +169,7 @@ def rm():
message="Database error (Knowledgebase removal)!")
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
settings.docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)

View File

@ -19,7 +19,7 @@ import json
from flask import request
from flask_login import login_required, current_user
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
from api.settings import LIGHTEN
from api import settings
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db import StatusEnum, LLMType
from api.db.db_models import TenantLLM
@ -333,7 +333,7 @@ def my_llms():
@login_required
def list_app():
self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
weighted = ["Youdao","FastEmbed", "BAAI"] if LIGHTEN != 0 else []
weighted = ["Youdao","FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
model_type = request.args.get("model_type")
try:
objs = TenantLLMService.query(tenant_id=current_user.id)

View File

@ -14,7 +14,7 @@
# limitations under the License.
#
from flask import request
from api.settings import RetCode
from api import settings
from api.db import StatusEnum
from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService
@ -44,7 +44,7 @@ def create(tenant_id):
kbs = KnowledgebaseService.get_by_ids(ids)
embd_count = list(set([kb.embd_id for kb in kbs]))
if len(embd_count) != 1:
return get_result(message='Datasets use different embedding models."',code=RetCode.AUTHENTICATION_ERROR)
return get_result(message='Datasets use different embedding models."',code=settings.RetCode.AUTHENTICATION_ERROR)
req["kb_ids"] = ids
# llm
llm = req.get("llm")
@ -173,7 +173,7 @@ def update(tenant_id,chat_id):
if len(embd_count) != 1 :
return get_result(
message='Datasets use different embedding models."',
code=RetCode.AUTHENTICATION_ERROR)
code=settings.RetCode.AUTHENTICATION_ERROR)
req["kb_ids"] = ids
llm = req.get("llm")
if llm:

View File

@ -23,7 +23,7 @@ from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService, LLMService
from api.db.services.user_service import TenantService
from api.settings import RetCode
from api import settings
from api.utils import get_uuid
from api.utils.api_utils import (
get_result,
@ -255,7 +255,7 @@ def delete(tenant_id):
File2DocumentService.delete_by_document_id(doc.id)
if not KnowledgebaseService.delete_by_id(id):
return get_error_data_result(message="Delete dataset error.(Database error)")
return get_result(code=RetCode.SUCCESS)
return get_result(code=settings.RetCode.SUCCESS)
@manager.route("/datasets/<dataset_id>", methods=["PUT"])
@ -424,7 +424,7 @@ def update(tenant_id, dataset_id):
)
if not KnowledgebaseService.update_by_id(kb.id, req):
return get_error_data_result(message="Update dataset error.(Database error)")
return get_result(code=RetCode.SUCCESS)
return get_result(code=settings.RetCode.SUCCESS)
@manager.route("/datasets", methods=["GET"])

View File

@ -18,7 +18,7 @@ from flask import request, jsonify
from api.db import LLMType, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler, kg_retrievaler, RetCode
from api import settings
from api.utils.api_utils import validate_request, build_error_result, apikey_required
@ -37,14 +37,14 @@ def retrieval(tenant_id):
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
if kb.tenant_id != tenant_id:
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval(
question,
embd_mdl,
@ -72,6 +72,6 @@ def retrieval(tenant_id):
if str(e).find("not_found") > 0:
return build_error_result(
message='No chunk found! Check the chunk status please!',
code=RetCode.NOT_FOUND
code=settings.RetCode.NOT_FOUND
)
return build_error_result(message=str(e), code=RetCode.SERVER_ERROR)
return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)

View File

@ -21,7 +21,7 @@ from rag.app.qa import rmPrefix, beAdoc
from rag.nlp import rag_tokenizer
from api.db import LLMType, ParserType
from api.db.services.llm_service import TenantLLMService
from api.settings import kg_retrievaler
from api import settings
import hashlib
import re
from api.utils.api_utils import token_required
@ -37,11 +37,10 @@ from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import RetCode, retrievaler
from api import settings
from api.utils.api_utils import construct_json_result, get_parser_config
from rag.nlp import search
from rag.utils import rmSpace
from api.settings import docStoreConn
from rag.utils.storage_factory import STORAGE_IMPL
import os
@ -109,13 +108,13 @@ def upload(dataset_id, tenant_id):
"""
if "file" not in request.files:
return get_error_data_result(
message="No file part!", code=RetCode.ARGUMENT_ERROR
message="No file part!", code=settings.RetCode.ARGUMENT_ERROR
)
file_objs = request.files.getlist("file")
for file_obj in file_objs:
if file_obj.filename == "":
return get_result(
message="No file selected!", code=RetCode.ARGUMENT_ERROR
message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR
)
# total size
total_size = 0
@ -127,14 +126,14 @@ def upload(dataset_id, tenant_id):
if total_size > MAX_TOTAL_FILE_SIZE:
return get_result(
message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)",
code=RetCode.ARGUMENT_ERROR,
code=settings.RetCode.ARGUMENT_ERROR,
)
e, kb = KnowledgebaseService.get_by_id(dataset_id)
if not e:
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
err, files = FileService.upload_document(kb, file_objs, tenant_id)
if err:
return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
return get_result(message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
# rename key's name
renamed_doc_list = []
for file in files:
@ -226,7 +225,7 @@ def update_doc(tenant_id, dataset_id, document_id):
):
return get_result(
message="The extension of file can't be changed",
code=RetCode.ARGUMENT_ERROR,
code=settings.RetCode.ARGUMENT_ERROR,
)
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
if d.name == req["name"]:
@ -292,7 +291,7 @@ def update_doc(tenant_id, dataset_id, document_id):
)
if not e:
return get_error_data_result(message="Document not found!")
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
return get_result()
@ -349,7 +348,7 @@ def download(tenant_id, dataset_id, document_id):
file_stream = STORAGE_IMPL.get(doc_id, doc_location)
if not file_stream:
return construct_json_result(
message="This file is empty.", code=RetCode.DATA_ERROR
message="This file is empty.", code=settings.RetCode.DATA_ERROR
)
file = BytesIO(file_stream)
# Use send_file with a proper filename and MIME type
@ -582,7 +581,7 @@ def delete(tenant_id, dataset_id):
errors += str(e)
if errors:
return get_result(message=errors, code=RetCode.SERVER_ERROR)
return get_result(message=errors, code=settings.RetCode.SERVER_ERROR)
return get_result()
@ -644,7 +643,7 @@ def parse(tenant_id, dataset_id):
info["chunk_num"] = 0
info["token_num"] = 0
DocumentService.update_by_id(id, info)
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
TaskService.filter_delete([Task.doc_id == id])
e, doc = DocumentService.get_by_id(id)
doc = doc.to_dict()
@ -708,7 +707,7 @@ def stop_parsing(tenant_id, dataset_id):
)
info = {"run": "2", "progress": 0, "chunk_num": 0}
DocumentService.update_by_id(id, info)
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
return get_result()
@ -828,8 +827,9 @@ def list_chunks(tenant_id, dataset_id, document_id):
res = {"total": 0, "chunks": [], "doc": renamed_doc}
origin_chunks = []
if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
if settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None,
highlight=True)
res["total"] = sres.total
sign = 0
for id in sres.ids:
@ -1003,7 +1003,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
v, c = embd_mdl.encode([doc.name, req["content"]])
v = 0.1 * v[0] + 0.9 * v[1]
d["q_%d_vec" % len(v)] = v.tolist()
docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
# rename keys
@ -1078,7 +1078,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
condition = {"doc_id": document_id}
if "chunk_ids" in req:
condition["id"] = req["chunk_ids"]
chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
if chunk_number != 0:
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
@ -1143,7 +1143,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
schema:
type: object
"""
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
if chunk is None:
return get_error_data_result(f"Can't find this chunk {chunk_id}")
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
@ -1187,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
d["q_%d_vec" % len(v)] = v.tolist()
docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
return get_result()
@ -1285,7 +1285,7 @@ def retrieval_test(tenant_id):
if len(embd_nms) != 1:
return get_result(
message='Datasets use different embedding models."',
code=RetCode.AUTHENTICATION_ERROR,
code=settings.RetCode.AUTHENTICATION_ERROR,
)
if "question" not in req:
return get_error_data_result("`question` is required.")
@ -1326,7 +1326,7 @@ def retrieval_test(tenant_id):
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question)
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
ranks = retr.retrieval(
question,
embd_mdl,
@ -1366,6 +1366,6 @@ def retrieval_test(tenant_id):
if str(e).find("not_found") > 0:
return get_result(
message="No chunk found! Check the chunk status please!",
code=RetCode.DATA_ERROR,
code=settings.RetCode.DATA_ERROR,
)
return server_error_response(e)

View File

@ -22,7 +22,7 @@ from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import UserTenantService
from api.settings import DATABASE_TYPE
from api import settings
from api.utils import current_timestamp, datetime_format
from api.utils.api_utils import (
get_json_result,
@ -31,7 +31,6 @@ from api.utils.api_utils import (
generate_confirmation_token,
)
from api.versions import get_ragflow_version
from api.settings import docStoreConn
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
from timeit import default_timer as timer
@ -98,7 +97,7 @@ def status():
res = {}
st = timer()
try:
res["doc_store"] = docStoreConn.health()
res["doc_store"] = settings.docStoreConn.health()
res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
except Exception as e:
res["doc_store"] = {
@ -128,13 +127,13 @@ def status():
try:
KnowledgebaseService.get_by_id("x")
res["database"] = {
"database": DATABASE_TYPE.lower(),
"database": settings.DATABASE_TYPE.lower(),
"status": "green",
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
}
except Exception as e:
res["database"] = {
"database": DATABASE_TYPE.lower(),
"database": settings.DATABASE_TYPE.lower(),
"status": "red",
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
"error": str(e),

View File

@ -38,20 +38,7 @@ from api.utils import (
datetime_format,
)
from api.db import UserTenantRole, FileType
from api.settings import (
RetCode,
GITHUB_OAUTH,
FEISHU_OAUTH,
CHAT_MDL,
EMBEDDING_MDL,
ASR_MDL,
IMAGE2TEXT_MDL,
PARSERS,
API_KEY,
LLM_FACTORY,
LLM_BASE_URL,
RERANK_MDL,
)
from api import settings
from api.db.services.user_service import UserService, TenantService, UserTenantService
from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result, construct_response
@ -90,7 +77,7 @@ def login():
"""
if not request.json:
return get_json_result(
data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
)
email = request.json.get("email", "")
@ -98,7 +85,7 @@ def login():
if not users:
return get_json_result(
data=False,
code=RetCode.AUTHENTICATION_ERROR,
code=settings.RetCode.AUTHENTICATION_ERROR,
message=f"Email: {email} is not registered!",
)
@ -107,7 +94,7 @@ def login():
password = decrypt(password)
except BaseException:
return get_json_result(
data=False, code=RetCode.SERVER_ERROR, message="Fail to crypt password"
data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password"
)
user = UserService.query_user(email, password)
@ -123,7 +110,7 @@ def login():
else:
return get_json_result(
data=False,
code=RetCode.AUTHENTICATION_ERROR,
code=settings.RetCode.AUTHENTICATION_ERROR,
message="Email and password do not match!",
)
@ -150,10 +137,10 @@ def github_callback():
import requests
res = requests.post(
GITHUB_OAUTH.get("url"),
settings.GITHUB_OAUTH.get("url"),
data={
"client_id": GITHUB_OAUTH.get("client_id"),
"client_secret": GITHUB_OAUTH.get("secret_key"),
"client_id": settings.GITHUB_OAUTH.get("client_id"),
"client_secret": settings.GITHUB_OAUTH.get("secret_key"),
"code": request.args.get("code"),
},
headers={"Accept": "application/json"},
@ -235,11 +222,11 @@ def feishu_callback():
import requests
app_access_token_res = requests.post(
FEISHU_OAUTH.get("app_access_token_url"),
settings.FEISHU_OAUTH.get("app_access_token_url"),
data=json.dumps(
{
"app_id": FEISHU_OAUTH.get("app_id"),
"app_secret": FEISHU_OAUTH.get("app_secret"),
"app_id": settings.FEISHU_OAUTH.get("app_id"),
"app_secret": settings.FEISHU_OAUTH.get("app_secret"),
}
),
headers={"Content-Type": "application/json; charset=utf-8"},
@ -249,10 +236,10 @@ def feishu_callback():
return redirect("/?error=%s" % app_access_token_res)
res = requests.post(
FEISHU_OAUTH.get("user_access_token_url"),
settings.FEISHU_OAUTH.get("user_access_token_url"),
data=json.dumps(
{
"grant_type": FEISHU_OAUTH.get("grant_type"),
"grant_type": settings.FEISHU_OAUTH.get("grant_type"),
"code": request.args.get("code"),
}
),
@ -409,7 +396,7 @@ def setting_user():
):
return get_json_result(
data=False,
code=RetCode.AUTHENTICATION_ERROR,
code=settings.RetCode.AUTHENTICATION_ERROR,
message="Password error!",
)
@ -438,7 +425,7 @@ def setting_user():
except Exception as e:
logging.exception(e)
return get_json_result(
data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR
data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR
)
@ -497,12 +484,12 @@ def user_register(user_id, user):
tenant = {
"id": user_id,
"name": user["nickname"] + "s Kingdom",
"llm_id": CHAT_MDL,
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL,
"rerank_id": RERANK_MDL,
"llm_id": settings.CHAT_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL,
"rerank_id": settings.RERANK_MDL,
}
usr_tenant = {
"tenant_id": user_id,
@ -522,15 +509,15 @@ def user_register(user_id, user):
"location": "",
}
tenant_llm = []
for llm in LLMService.query(fid=LLM_FACTORY):
for llm in LLMService.query(fid=settings.LLM_FACTORY):
tenant_llm.append(
{
"tenant_id": user_id,
"llm_factory": LLM_FACTORY,
"llm_factory": settings.LLM_FACTORY,
"llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": API_KEY,
"api_base": LLM_BASE_URL,
"api_key": settings.API_KEY,
"api_base": settings.LLM_BASE_URL,
}
)
@ -582,7 +569,7 @@ def user_add():
return get_json_result(
data=False,
message=f"Invalid email address: {email_address}!",
code=RetCode.OPERATING_ERROR,
code=settings.RetCode.OPERATING_ERROR,
)
# Check if the email address is already used
@ -590,7 +577,7 @@ def user_add():
return get_json_result(
data=False,
message=f"Email: {email_address} has already registered!",
code=RetCode.OPERATING_ERROR,
code=settings.RetCode.OPERATING_ERROR,
)
# Construct user info data
@ -625,7 +612,7 @@ def user_add():
return get_json_result(
data=False,
message=f"User registration failure, error: {str(e)}",
code=RetCode.EXCEPTION_ERROR,
code=settings.RetCode.EXCEPTION_ERROR,
)

View File

@ -31,7 +31,7 @@ from peewee import (
)
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
from api.db import SerializedType, ParserType
from api.settings import DATABASE, SECRET_KEY, DATABASE_TYPE
from api import settings
from api import utils
def singleton(cls, *args, **kw):
@ -62,7 +62,7 @@ class TextFieldType(Enum):
class LongTextField(TextField):
field_type = TextFieldType[DATABASE_TYPE.upper()].value
field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value
class JSONField(LongTextField):
@ -282,9 +282,9 @@ class DatabaseMigrator(Enum):
@singleton
class BaseDataBase:
def __init__(self):
database_config = DATABASE.copy()
database_config = settings.DATABASE.copy()
db_name = database_config.pop("name")
self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config)
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
logging.info('init database on cluster mode successfully')
class PostgresDatabaseLock:
@ -385,7 +385,7 @@ class DatabaseLock(Enum):
DB = BaseDataBase().database_connection
DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value
DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value
def close_connection():
@ -476,7 +476,7 @@ class User(DataBaseModel, UserMixin):
return self.email
def get_id(self):
jwt = Serializer(secret_key=SECRET_KEY)
jwt = Serializer(secret_key=settings.SECRET_KEY)
return jwt.dumps(str(self.access_token))
class Meta:
@ -977,7 +977,7 @@ class CanvasTemplate(DataBaseModel):
def migrate_db():
with DB.transaction():
migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB)
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
try:
migrate(
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",

View File

@ -29,7 +29,7 @@ from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY, LLM_BASE_URL
from api import settings
from api.utils.file_utils import get_project_base_directory
@ -51,11 +51,11 @@ def init_superuser():
tenant = {
"id": user_info["id"],
"name": user_info["nickname"] + "s Kingdom",
"llm_id": CHAT_MDL,
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL
"llm_id": settings.CHAT_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL
}
usr_tenant = {
"tenant_id": user_info["id"],
@ -64,10 +64,11 @@ def init_superuser():
"role": UserTenantRole.OWNER
}
tenant_llm = []
for llm in LLMService.query(fid=LLM_FACTORY):
for llm in LLMService.query(fid=settings.LLM_FACTORY):
tenant_llm.append(
{"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
"api_key": API_KEY, "api_base": LLM_BASE_URL})
{"tenant_id": user_info["id"], "llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL})
if not UserService.save(**user_info):
logging.error("can't init admin.")

View File

@ -27,7 +27,7 @@ from api.db.db_models import Dialog, Conversation,DB
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.settings import retrievaler, kg_retrievaler
from api import settings
from rag.app.resume import forbidden_select_fields4resume
from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder
@ -152,7 +152,7 @@ def chat(dialog, messages, stream=True, **kwargs):
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
retr = retrievaler if not is_kg else kg_retrievaler
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
@ -342,7 +342,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1
return retrievaler.sql_retrieval(sql, format="json"), sql
return settings.retrievaler.sql_retrieval(sql, format="json"), sql
tbl, sql = get_table()
if tbl is None:
@ -596,7 +596,7 @@ def ask(question, kb_ids, tenant_id):
embd_nms = list(set([kb.embd_id for kb in kbs]))
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
retr = retrievaler if not is_kg else kg_retrievaler
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)

View File

@ -26,7 +26,7 @@ from io import BytesIO
from peewee import fn
from api.db.db_utils import bulk_insert_into_db
from api.settings import docStoreConn
from api import settings
from api.utils import current_timestamp, get_format_time, get_uuid
from graphrag.mind_map_extractor import MindMapExtractor
from rag.settings import SVR_QUEUE_NAME
@ -108,7 +108,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def remove_document(cls, doc, tenant_id):
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
cls.clear_chunk_num(doc.id)
return cls.delete_by_id(doc.id)
@ -553,10 +553,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not docStoreConn.indexExist(idxnm, kb_id):
docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
if not settings.docStoreConn.indexExist(idxnm, kb_id):
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
try_create_idx = False
docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

View File

@ -33,12 +33,10 @@ import traceback
from concurrent.futures import ThreadPoolExecutor
from werkzeug.serving import run_simple
from api import settings
from api.apps import app
from api.db.runtime_config import RuntimeConfig
from api.db.services.document_service import DocumentService
from api.settings import (
HOST, HTTP_PORT
)
from api import utils
from api.db.db_models import init_database_tables as init_web_db
@ -72,6 +70,7 @@ if __name__ == '__main__':
f'project base: {utils.file_utils.get_project_base_directory()}'
)
show_configs()
settings.init_settings()
# init db
init_web_db()
@ -96,7 +95,7 @@ if __name__ == '__main__':
logging.info("run on debug mode")
RuntimeConfig.init_env()
RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
RuntimeConfig.init_config(JOB_SERVER_HOST=settings.HOST_IP, HTTP_PORT=settings.HOST_PORT)
thread = ThreadPoolExecutor(max_workers=1)
thread.submit(update_progress)
@ -105,8 +104,8 @@ if __name__ == '__main__':
try:
logging.info("RAGFlow HTTP server start...")
run_simple(
hostname=HOST,
port=HTTP_PORT,
hostname=settings.HOST_IP,
port=settings.HOST_PORT,
application=app,
threaded=True,
use_reloader=RuntimeConfig.DEBUG,

View File

@ -30,12 +30,46 @@ LIGHTEN = int(os.environ.get('LIGHTEN', "0"))
REQUEST_WAIT_SEC = 2
REQUEST_MAX_WAIT_SEC = 300
LLM = None
LLM_FACTORY = None
LLM_BASE_URL = None
CHAT_MDL = ""
EMBEDDING_MDL = ""
RERANK_MDL = ""
ASR_MDL = ""
IMAGE2TEXT_MDL = ""
API_KEY = None
PARSERS = None
HOST_IP = None
HOST_PORT = None
SECRET_KEY = None
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
# authentication
AUTHENTICATION_CONF = None
# client
CLIENT_AUTHENTICATION = None
HTTP_APP_KEY = None
GITHUB_OAUTH = None
FEISHU_OAUTH = None
DOC_ENGINE = None
docStoreConn = None
retrievaler = None
kg_retrievaler = None
def init_settings():
global LLM, LLM_FACTORY, LLM_BASE_URL
LLM = get_base_config("user_default_llm", {})
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
LLM_BASE_URL = LLM.get("base_url")
CHAT_MDL = EMBEDDING_MDL = RERANK_MDL = ASR_MDL = IMAGE2TEXT_MDL = ""
global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
if not LIGHTEN:
default_llm = {
"Tongyi-Qianwen": {
@ -102,21 +136,20 @@ if not LIGHTEN:
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI"
RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI"
global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
API_KEY = LLM.get("api_key", "")
PARSERS = LLM.get(
"parsers",
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
SECRET_KEY = get_base_config(
RAG_FLOW_SERVICE_NAME,
{}).get("secret_key", str(date.today()))
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH
# authentication
AUTHENTICATION_CONF = get_base_config("authentication", {})
@ -128,6 +161,7 @@ HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler
DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
if DOC_ENGINE == "elasticsearch":
docStoreConn = rag.utils.es_conn.ESConnection()
@ -139,6 +173,15 @@ else:
retrievaler = search.Dealer(docStoreConn)
kg_retrievaler = kg_search.KGSearch(docStoreConn)
def get_host_ip():
global HOST_IP
return HOST_IP
def get_host_port():
global HOST_PORT
return HOST_PORT
class CustomEnum(Enum):
@classmethod

View File

@ -34,11 +34,9 @@ from itsdangerous import URLSafeTimedSerializer
from werkzeug.http import HTTP_STATUS_CODES
from api.db.db_models import APIToken
from api.settings import (
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
)
from api.settings import RetCode
from api import settings
from api import settings
from api.utils import CustomJSONEncoder, get_uuid
from api.utils import json_dumps
@ -59,13 +57,13 @@ def request(**kwargs):
{}).items()}
prepped = requests.Request(**kwargs).prepare()
if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
timestamp = str(round(time() * 1000))
nonce = str(uuid1())
signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([
signature = b64encode(HMAC(settings.SECRET_KEY.encode('ascii'), b'\n'.join([
timestamp.encode('ascii'),
nonce.encode('ascii'),
HTTP_APP_KEY.encode('ascii'),
settings.HTTP_APP_KEY.encode('ascii'),
prepped.path_url.encode('ascii'),
prepped.body if kwargs.get('json') else b'',
urlencode(
@ -79,7 +77,7 @@ def request(**kwargs):
prepped.headers.update({
'TIMESTAMP': timestamp,
'NONCE': nonce,
'APP-KEY': HTTP_APP_KEY,
'APP-KEY': settings.HTTP_APP_KEY,
'SIGNATURE': signature,
})
@ -89,7 +87,7 @@ def request(**kwargs):
def get_exponential_backoff_interval(retries, full_jitter=False):
"""Calculate the exponential backoff wait time."""
# Will be zero if factor equals 0
countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
countdown = min(settings.REQUEST_MAX_WAIT_SEC, settings.REQUEST_WAIT_SEC * (2 ** retries))
# Full jitter according to
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
if full_jitter:
@ -98,7 +96,7 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
return max(0, countdown)
def get_data_error_result(code=RetCode.DATA_ERROR,
def get_data_error_result(code=settings.RetCode.DATA_ERROR,
message='Sorry! Data missing!'):
import re
result_dict = {
@ -126,8 +124,8 @@ def server_error_response(e):
pass
if len(e.args) > 1:
return get_json_result(
code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
def error_response(response_code, message=None):
@ -168,7 +166,7 @@ def validate_request(*args, **kwargs):
error_string += "required argument values: {}".format(
",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
return get_json_result(
code=RetCode.ARGUMENT_ERROR, message=error_string)
code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
return func(*_args, **_kwargs)
return decorated_function
@ -193,7 +191,7 @@ def send_file_in_mem(data, filename):
return send_file(f, as_attachment=True, attachment_filename=filename)
def get_json_result(code=RetCode.SUCCESS, message='success', data=None):
def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
response = {"code": code, "message": message, "data": data}
return jsonify(response)
@ -204,7 +202,7 @@ def apikey_required(func):
objs = APIToken.query(token=token)
if not objs:
return build_error_result(
message='API-KEY is invalid!', code=RetCode.FORBIDDEN
message='API-KEY is invalid!', code=settings.RetCode.FORBIDDEN
)
kwargs['tenant_id'] = objs[0].tenant_id
return func(*args, **kwargs)
@ -212,14 +210,14 @@ def apikey_required(func):
return decorated_function
def build_error_result(code=RetCode.FORBIDDEN, message='success'):
def build_error_result(code=settings.RetCode.FORBIDDEN, message='success'):
response = {"code": code, "message": message}
response = jsonify(response)
response.status_code = code
return response
def construct_response(code=RetCode.SUCCESS,
def construct_response(code=settings.RetCode.SUCCESS,
message='success', data=None, auth=None):
result_dict = {"code": code, "message": message, "data": data}
response_dict = {}
@ -239,7 +237,7 @@ def construct_response(code=RetCode.SUCCESS,
return response
def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'):
import re
result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
response = {}
@ -251,7 +249,7 @@ def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
return jsonify(response)
def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
def construct_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
if data is None:
return jsonify({"code": code, "message": message})
else:
@ -262,12 +260,12 @@ def construct_error_response(e):
logging.exception(e)
try:
if e.code == 401:
return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e))
except BaseException:
pass
if len(e.args) > 1:
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
def token_required(func):
@ -280,7 +278,7 @@ def token_required(func):
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!', code=RetCode.AUTHENTICATION_ERROR
data=False, message='Token is not valid!', code=settings.RetCode.AUTHENTICATION_ERROR
)
kwargs['tenant_id'] = objs[0].tenant_id
return func(*args, **kwargs)
@ -288,7 +286,7 @@ def token_required(func):
return decorated_function
def get_result(code=RetCode.SUCCESS, message="", data=None):
def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
if code == 0:
if data is not None:
response = {"code": code, "data": data}
@ -299,7 +297,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None):
return jsonify(response)
def get_error_data_result(message='Sorry! Data missing!', code=RetCode.DATA_ERROR,
def get_error_data_result(message='Sorry! Data missing!', code=settings.RetCode.DATA_ERROR,
):
import re
result_dict = {

View File

@ -24,7 +24,7 @@ import numpy as np
from timeit import default_timer as timer
from pypdf import PdfReader as pdf2_read
from api.settings import LIGHTEN
from api import settings
from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
from rag.nlp import rag_tokenizer
@ -41,7 +41,7 @@ class RAGFlowPdfParser:
self.tbl_det = TableStructureRecognizer()
self.updown_cnt_mdl = xgb.Booster()
if not LIGHTEN:
if not settings.LIGHTEN:
try:
import torch
if torch.cuda.is_available():

View File

@ -252,13 +252,13 @@ if __name__ == "__main__":
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from api import settings
from api.db.services.knowledgebase_service import KnowledgebaseService
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
info = {
"input_text": docs,
"entity_specs": "organization, person",

View File

@ -30,14 +30,14 @@ if __name__ == "__main__":
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from api import settings
from api.db.services.knowledgebase_service import KnowledgebaseService
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in
retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
graph = ex(docs)
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))

View File

@ -23,7 +23,7 @@ from collections import defaultdict
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import retrievaler, docStoreConn
from api import settings
from api.utils import get_uuid
from rag.nlp import tokenize, search
from ranx import evaluate
@ -52,7 +52,7 @@ class Benchmark:
run = defaultdict(dict)
query_list = list(qrels.keys())
for query in query_list:
ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
0.0, self.vector_similarity_weight)
if len(ranks["chunks"]) == 0:
print(f"deleted query: {query}")
@ -81,9 +81,9 @@ class Benchmark:
def init_index(self, vector_size: int):
if self.initialized_index:
return
if docStoreConn.indexExist(self.index_name, self.kb_id):
docStoreConn.deleteIdx(self.index_name, self.kb_id)
docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
self.initialized_index = True
def ms_marco_index(self, file_path, index_name):
@ -118,13 +118,13 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name, self.kb_id)
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
docs = []
if docs:
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name, self.kb_id)
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
return qrels, texts
def trivia_qa_index(self, file_path, index_name):
@ -159,12 +159,12 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs,self.index_name)
settings.docStoreConn.insert(docs,self.index_name)
docs = []
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name)
settings.docStoreConn.insert(docs, self.index_name)
return qrels, texts
def miracl_index(self, file_path, corpus_path, index_name):
@ -214,12 +214,12 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name)
settings.docStoreConn.insert(docs, self.index_name)
docs = []
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name)
settings.docStoreConn.insert(docs, self.index_name)
return qrels, texts
def save_results(self, qrels, run, texts, dataset, file_path):

View File

@ -28,7 +28,7 @@ from openai import OpenAI
import numpy as np
import asyncio
from api.settings import LIGHTEN
from api import settings
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate
import google.generativeai as genai
@ -60,7 +60,7 @@ class DefaultEmbedding(Base):
^_-
"""
if not LIGHTEN and not DefaultEmbedding._model:
if not settings.LIGHTEN and not DefaultEmbedding._model:
with DefaultEmbedding._model_lock:
from FlagEmbedding import FlagModel
import torch
@ -248,7 +248,7 @@ class FastEmbed(Base):
threads: Optional[int] = None,
**kwargs,
):
if not LIGHTEN and not FastEmbed._model:
if not settings.LIGHTEN and not FastEmbed._model:
from fastembed import TextEmbedding
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
@ -294,7 +294,7 @@ class YoudaoEmbed(Base):
_client = None
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
if not LIGHTEN and not YoudaoEmbed._client:
if not settings.LIGHTEN and not YoudaoEmbed._client:
from BCEmbedding import EmbeddingModel as qanthing
try:
logging.info("LOADING BCE...")

View File

@ -23,7 +23,7 @@ import os
from abc import ABC
import numpy as np
from api.settings import LIGHTEN
from api import settings
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate
import json
@ -57,7 +57,7 @@ class DefaultRerank(Base):
^_-
"""
if not LIGHTEN and not DefaultRerank._model:
if not settings.LIGHTEN and not DefaultRerank._model:
import torch
from FlagEmbedding import FlagReranker
with DefaultRerank._model_lock:
@ -121,7 +121,7 @@ class YoudaoRerank(DefaultRerank):
_model_lock = threading.Lock()
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
if not LIGHTEN and not YoudaoRerank._model:
if not settings.LIGHTEN and not YoudaoRerank._model:
from BCEmbedding import RerankerModel
with YoudaoRerank._model_lock:
if not YoudaoRerank._model:

View File

@ -16,6 +16,7 @@
import logging
import sys
from api.utils.log_utils import initRootLogger
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
initRootLogger(f"task_executor_{CONSUMER_NO}")
for module in ["pdfminer"]:
@ -49,9 +50,10 @@ from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService
from api.db.services.file2document_service import File2DocumentService
from api.settings import retrievaler, docStoreConn
from api import settings
from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
knowledge_graph, email
from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
@ -88,6 +90,7 @@ PENDING_TASKS = 0
HEAD_CREATED_AT = ""
HEAD_DETAIL = ""
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
global PAYLOAD
if prog is not None and prog < 0:
@ -171,7 +174,8 @@ def build(row):
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
except TimeoutError:
callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
logging.exception(
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
return
except Exception as e:
if re.search("(No such file|not found)", str(e)):
@ -226,7 +230,8 @@ def build(row):
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
el += timer() - st
except Exception:
logging.exception("Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
logging.exception(
"Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
del d["image"]
@ -262,7 +267,7 @@ def build(row):
def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
def embedding(docs, mdl, parser_config=None, callback=None):
@ -313,7 +318,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
vector_size = len(vts[0])
vctr_nm = "q_%d_vec" % vector_size
chunks = []
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm]):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
raptor = Raptor(
@ -384,7 +390,8 @@ def main():
# TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ")
callback(
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), timer() - st)
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks),
timer() - st)
)
st = timer()
try:
@ -403,18 +410,18 @@ def main():
es_r = ""
es_bulk_size = 4
for b in range(0, len(cks), es_bulk_size):
es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
if b % 128 == 0:
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
if es_r:
callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
logging.error('Insert chunk error: ' + str(es_r))
else:
if TaskService.do_cancel(r["id"]):
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
continue
callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
callback(1., "Done!")