mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-21 21:50:02 +08:00
Move settings initialization after module init phase (#3438)
### What problem does this PR solve? 1. Module init won't connect database any more. 2. Config in settings need to be used with settings.CONFIG_NAME ### Type of change - [x] Refactoring Signed-off-by: jinhai <haijin.chn@gmail.com>
This commit is contained in:
parent
ac033b62cf
commit
1e90a1bf36
@ -19,7 +19,7 @@ import pandas as pd
|
|||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.services.dialog_service import message_fit_in
|
from api.db.services.dialog_service import message_fit_in
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.settings import retrievaler
|
from api import settings
|
||||||
from agent.component.base import ComponentBase, ComponentParamBase
|
from agent.component.base import ComponentBase, ComponentParamBase
|
||||||
|
|
||||||
|
|
||||||
@ -63,14 +63,16 @@ class Generate(ComponentBase):
|
|||||||
component_name = "Generate"
|
component_name = "Generate"
|
||||||
|
|
||||||
def get_dependent_components(self):
|
def get_dependent_components(self):
|
||||||
cpnts = [para["component_id"] for para in self._param.parameters if para.get("component_id") and para["component_id"].lower().find("answer") < 0]
|
cpnts = [para["component_id"] for para in self._param.parameters if
|
||||||
|
para.get("component_id") and para["component_id"].lower().find("answer") < 0]
|
||||||
return cpnts
|
return cpnts
|
||||||
|
|
||||||
def set_cite(self, retrieval_res, answer):
|
def set_cite(self, retrieval_res, answer):
|
||||||
retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
|
retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
|
||||||
if "empty_response" in retrieval_res.columns:
|
if "empty_response" in retrieval_res.columns:
|
||||||
retrieval_res["empty_response"].fillna("", inplace=True)
|
retrieval_res["empty_response"].fillna("", inplace=True)
|
||||||
answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
|
answer, idx = settings.retrievaler.insert_citations(answer,
|
||||||
|
[ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
|
||||||
[ck["vector"] for _, ck in retrieval_res.iterrows()],
|
[ck["vector"] for _, ck in retrieval_res.iterrows()],
|
||||||
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
|
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
|
||||||
self._canvas.get_embedding_model()), tkweight=0.7,
|
self._canvas.get_embedding_model()), tkweight=0.7,
|
||||||
@ -127,12 +129,14 @@ class Generate(ComponentBase):
|
|||||||
else:
|
else:
|
||||||
if cpn.component_name.lower() == "retrieval":
|
if cpn.component_name.lower() == "retrieval":
|
||||||
retrieval_res.append(out)
|
retrieval_res.append(out)
|
||||||
kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
|
kwargs[para["key"]] = " - " + "\n - ".join(
|
||||||
|
[o if isinstance(o, str) else str(o) for o in out["content"]])
|
||||||
self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
|
self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
|
||||||
|
|
||||||
if retrieval_res:
|
if retrieval_res:
|
||||||
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
|
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
|
||||||
else: retrieval_res = pd.DataFrame([])
|
else:
|
||||||
|
retrieval_res = pd.DataFrame([])
|
||||||
|
|
||||||
for n, v in kwargs.items():
|
for n, v in kwargs.items():
|
||||||
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
|
prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
|
||||||
|
@ -21,7 +21,7 @@ import pandas as pd
|
|||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.settings import retrievaler
|
from api import settings
|
||||||
from agent.component.base import ComponentBase, ComponentParamBase
|
from agent.component.base import ComponentBase, ComponentParamBase
|
||||||
|
|
||||||
|
|
||||||
@ -67,7 +67,7 @@ class Retrieval(ComponentBase, ABC):
|
|||||||
if self._param.rerank_id:
|
if self._param.rerank_id:
|
||||||
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
|
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
|
||||||
|
|
||||||
kbinfos = retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
|
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
|
||||||
1, self._param.top_n,
|
1, self._param.top_n,
|
||||||
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
|
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
|
||||||
aggs=False, rerank_mdl=rerank_mdl)
|
aggs=False, rerank_mdl=rerank_mdl)
|
||||||
|
@ -30,8 +30,7 @@ from api.utils import CustomJSONEncoder, commands
|
|||||||
|
|
||||||
from flask_session import Session
|
from flask_session import Session
|
||||||
from flask_login import LoginManager
|
from flask_login import LoginManager
|
||||||
from api.settings import SECRET_KEY
|
from api import settings
|
||||||
from api.settings import API_VERSION
|
|
||||||
from api.utils.api_utils import server_error_response
|
from api.utils.api_utils import server_error_response
|
||||||
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
|
||||||
|
|
||||||
@ -78,7 +77,6 @@ app.url_map.strict_slashes = False
|
|||||||
app.json_encoder = CustomJSONEncoder
|
app.json_encoder = CustomJSONEncoder
|
||||||
app.errorhandler(Exception)(server_error_response)
|
app.errorhandler(Exception)(server_error_response)
|
||||||
|
|
||||||
|
|
||||||
## convince for dev and debug
|
## convince for dev and debug
|
||||||
# app.config["LOGIN_DISABLED"] = True
|
# app.config["LOGIN_DISABLED"] = True
|
||||||
app.config["SESSION_PERMANENT"] = False
|
app.config["SESSION_PERMANENT"] = False
|
||||||
@ -121,7 +119,7 @@ def register_page(page_path):
|
|||||||
spec.loader.exec_module(page)
|
spec.loader.exec_module(page)
|
||||||
page_name = getattr(page, "page_name", page_name)
|
page_name = getattr(page, "page_name", page_name)
|
||||||
url_prefix = (
|
url_prefix = (
|
||||||
f"/api/{API_VERSION}" if "/sdk/" in path else f"/{API_VERSION}/{page_name}"
|
f"/api/{settings.API_VERSION}" if "/sdk/" in path else f"/{settings.API_VERSION}/{page_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
app.register_blueprint(page.manager, url_prefix=url_prefix)
|
app.register_blueprint(page.manager, url_prefix=url_prefix)
|
||||||
@ -141,7 +139,7 @@ client_urls_prefix = [
|
|||||||
|
|
||||||
@login_manager.request_loader
|
@login_manager.request_loader
|
||||||
def load_user(web_request):
|
def load_user(web_request):
|
||||||
jwt = Serializer(secret_key=SECRET_KEY)
|
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||||
authorization = web_request.headers.get("Authorization")
|
authorization = web_request.headers.get("Authorization")
|
||||||
if authorization:
|
if authorization:
|
||||||
try:
|
try:
|
||||||
|
@ -32,7 +32,7 @@ from api.db.services.file_service import FileService
|
|||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.task_service import queue_tasks, TaskService
|
from api.db.services.task_service import queue_tasks, TaskService
|
||||||
from api.db.services.user_service import UserTenantService
|
from api.db.services.user_service import UserTenantService
|
||||||
from api.settings import RetCode, retrievaler
|
from api import settings
|
||||||
from api.utils import get_uuid, current_timestamp, datetime_format
|
from api.utils import get_uuid, current_timestamp, datetime_format
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
|
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
|
||||||
generate_confirmation_token
|
generate_confirmation_token
|
||||||
@ -141,7 +141,7 @@ def set_conversation():
|
|||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
req = request.json
|
req = request.json
|
||||||
try:
|
try:
|
||||||
if objs[0].source == "agent":
|
if objs[0].source == "agent":
|
||||||
@ -183,7 +183,7 @@ def completion():
|
|||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
req = request.json
|
req = request.json
|
||||||
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
||||||
if not e:
|
if not e:
|
||||||
@ -347,7 +347,7 @@ def get(conversation_id):
|
|||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
e, conv = API4ConversationService.get_by_id(conversation_id)
|
e, conv = API4ConversationService.get_by_id(conversation_id)
|
||||||
@ -357,7 +357,7 @@ def get(conversation_id):
|
|||||||
conv = conv.to_dict()
|
conv = conv.to_dict()
|
||||||
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
|
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
|
||||||
return get_json_result(data=False, message='Token is not valid for this conversation_id!"',
|
return get_json_result(data=False, message='Token is not valid for this conversation_id!"',
|
||||||
code=RetCode.AUTHENTICATION_ERROR)
|
code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
for referenct_i in conv['reference']:
|
for referenct_i in conv['reference']:
|
||||||
if referenct_i is None or len(referenct_i) == 0:
|
if referenct_i is None or len(referenct_i) == 0:
|
||||||
@ -378,7 +378,7 @@ def upload():
|
|||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
kb_name = request.form.get("kb_name").strip()
|
kb_name = request.form.get("kb_name").strip()
|
||||||
tenant_id = objs[0].tenant_id
|
tenant_id = objs[0].tenant_id
|
||||||
@ -394,12 +394,12 @@ def upload():
|
|||||||
|
|
||||||
if 'file' not in request.files:
|
if 'file' not in request.files:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
file = request.files['file']
|
file = request.files['file']
|
||||||
if file.filename == '':
|
if file.filename == '':
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
root_folder = FileService.get_root_folder(tenant_id)
|
root_folder = FileService.get_root_folder(tenant_id)
|
||||||
pf_id = root_folder["id"]
|
pf_id = root_folder["id"]
|
||||||
@ -490,17 +490,17 @@ def upload_parse():
|
|||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
if 'file' not in request.files:
|
if 'file' not in request.files:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
file_objs = request.files.getlist('file')
|
file_objs = request.files.getlist('file')
|
||||||
for file_obj in file_objs:
|
for file_obj in file_objs:
|
||||||
if file_obj.filename == '':
|
if file_obj.filename == '':
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
|
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
|
||||||
return get_json_result(data=doc_ids)
|
return get_json_result(data=doc_ids)
|
||||||
@ -513,7 +513,7 @@ def list_chunks():
|
|||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
req = request.json
|
req = request.json
|
||||||
|
|
||||||
@ -531,7 +531,7 @@ def list_chunks():
|
|||||||
)
|
)
|
||||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||||
|
|
||||||
res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
|
res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
|
||||||
res = [
|
res = [
|
||||||
{
|
{
|
||||||
"content": res_item["content_with_weight"],
|
"content": res_item["content_with_weight"],
|
||||||
@ -553,7 +553,7 @@ def list_kb_docs():
|
|||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
req = request.json
|
req = request.json
|
||||||
tenant_id = objs[0].tenant_id
|
tenant_id = objs[0].tenant_id
|
||||||
@ -585,6 +585,7 @@ def list_kb_docs():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/document/infos', methods=['POST'])
|
@manager.route('/document/infos', methods=['POST'])
|
||||||
@validate_request("doc_ids")
|
@validate_request("doc_ids")
|
||||||
def docinfos():
|
def docinfos():
|
||||||
@ -592,7 +593,7 @@ def docinfos():
|
|||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
req = request.json
|
req = request.json
|
||||||
doc_ids = req["doc_ids"]
|
doc_ids = req["doc_ids"]
|
||||||
docs = DocumentService.get_by_ids(doc_ids)
|
docs = DocumentService.get_by_ids(doc_ids)
|
||||||
@ -606,7 +607,7 @@ def document_rm():
|
|||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
tenant_id = objs[0].tenant_id
|
tenant_id = objs[0].tenant_id
|
||||||
req = request.json
|
req = request.json
|
||||||
@ -653,7 +654,7 @@ def document_rm():
|
|||||||
errors += str(e)
|
errors += str(e)
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
|
||||||
|
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
|
|
||||||
@ -668,7 +669,7 @@ def completion_faq():
|
|||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
||||||
if not e:
|
if not e:
|
||||||
@ -805,7 +806,7 @@ def retrieval():
|
|||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
|
data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
req = request.json
|
req = request.json
|
||||||
kb_ids = req.get("kb_id", [])
|
kb_ids = req.get("kb_id", [])
|
||||||
@ -822,7 +823,8 @@ def retrieval():
|
|||||||
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
||||||
if len(embd_nms) != 1:
|
if len(embd_nms) != 1:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Knowledge bases use different embedding models or does not exist."', code=RetCode.AUTHENTICATION_ERROR)
|
data=False, message='Knowledge bases use different embedding models or does not exist."',
|
||||||
|
code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
embd_mdl = TenantLLMService.model_instance(
|
embd_mdl = TenantLLMService.model_instance(
|
||||||
kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
|
kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
|
||||||
@ -833,7 +835,7 @@ def retrieval():
|
|||||||
if req.get("keyword", False):
|
if req.get("keyword", False):
|
||||||
chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
|
chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
|
||||||
question += keyword_extraction(chat_mdl, question)
|
question += keyword_extraction(chat_mdl, question)
|
||||||
ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
||||||
similarity_threshold, vector_similarity_weight, top,
|
similarity_threshold, vector_similarity_weight, top,
|
||||||
doc_ids, rerank_mdl=rerank_mdl)
|
doc_ids, rerank_mdl=rerank_mdl)
|
||||||
for c in ranks["chunks"]:
|
for c in ranks["chunks"]:
|
||||||
@ -843,5 +845,5 @@ def retrieval():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("not_found") > 0:
|
if str(e).find("not_found") > 0:
|
||||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
||||||
code=RetCode.DATA_ERROR)
|
code=settings.RetCode.DATA_ERROR)
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
@ -19,7 +19,7 @@ from functools import partial
|
|||||||
from flask import request, Response
|
from flask import request, Response
|
||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
|
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
|
||||||
from api.settings import RetCode
|
from api import settings
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
|
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
|
||||||
from agent.canvas import Canvas
|
from agent.canvas import Canvas
|
||||||
@ -36,7 +36,8 @@ def templates():
|
|||||||
@login_required
|
@login_required
|
||||||
def canvas_list():
|
def canvas_list():
|
||||||
return get_json_result(data=sorted([c.to_dict() for c in \
|
return get_json_result(data=sorted([c.to_dict() for c in \
|
||||||
UserCanvasService.query(user_id=current_user.id)], key=lambda x: x["update_time"]*-1)
|
UserCanvasService.query(user_id=current_user.id)],
|
||||||
|
key=lambda x: x["update_time"] * -1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -48,7 +49,7 @@ def rm():
|
|||||||
if not UserCanvasService.query(user_id=current_user.id, id=i):
|
if not UserCanvasService.query(user_id=current_user.id, id=i):
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Only owner of canvas authorized for this operation.',
|
data=False, message='Only owner of canvas authorized for this operation.',
|
||||||
code=RetCode.OPERATING_ERROR)
|
code=settings.RetCode.OPERATING_ERROR)
|
||||||
UserCanvasService.delete_by_id(i)
|
UserCanvasService.delete_by_id(i)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
|
|
||||||
@ -72,7 +73,7 @@ def save():
|
|||||||
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
|
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Only owner of canvas authorized for this operation.',
|
data=False, message='Only owner of canvas authorized for this operation.',
|
||||||
code=RetCode.OPERATING_ERROR)
|
code=settings.RetCode.OPERATING_ERROR)
|
||||||
UserCanvasService.update_by_id(req["id"], req)
|
UserCanvasService.update_by_id(req["id"], req)
|
||||||
return get_json_result(data=req)
|
return get_json_result(data=req)
|
||||||
|
|
||||||
@ -98,7 +99,7 @@ def run():
|
|||||||
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
|
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Only owner of canvas authorized for this operation.',
|
data=False, message='Only owner of canvas authorized for this operation.',
|
||||||
code=RetCode.OPERATING_ERROR)
|
code=settings.RetCode.OPERATING_ERROR)
|
||||||
|
|
||||||
if not isinstance(cvs.dsl, str):
|
if not isinstance(cvs.dsl, str):
|
||||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||||
@ -122,7 +123,8 @@ def run():
|
|||||||
assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
|
assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
assert isinstance(answer, partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
|
assert isinstance(answer,
|
||||||
|
partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
|
||||||
|
|
||||||
def sse():
|
def sse():
|
||||||
nonlocal answer, cvs
|
nonlocal answer, cvs
|
||||||
@ -173,7 +175,7 @@ def reset():
|
|||||||
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
|
if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Only owner of canvas authorized for this operation.',
|
data=False, message='Only owner of canvas authorized for this operation.',
|
||||||
code=RetCode.OPERATING_ERROR)
|
code=settings.RetCode.OPERATING_ERROR)
|
||||||
|
|
||||||
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
|
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
|
||||||
canvas.reset()
|
canvas.reset()
|
||||||
|
@ -29,11 +29,12 @@ from api.db.services.llm_service import LLMBundle
|
|||||||
from api.db.services.user_service import UserTenantService
|
from api.db.services.user_service import UserTenantService
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn
|
from api import settings
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
import hashlib
|
import hashlib
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/list', methods=['POST'])
|
@manager.route('/list', methods=['POST'])
|
||||||
@login_required
|
@login_required
|
||||||
@validate_request("doc_id")
|
@validate_request("doc_id")
|
||||||
@ -56,7 +57,7 @@ def list_chunk():
|
|||||||
}
|
}
|
||||||
if "available_int" in req:
|
if "available_int" in req:
|
||||||
query["available_int"] = int(req["available_int"])
|
query["available_int"] = int(req["available_int"])
|
||||||
sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
|
sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
|
||||||
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
||||||
for id in sres.ids:
|
for id in sres.ids:
|
||||||
d = {
|
d = {
|
||||||
@ -78,7 +79,7 @@ def list_chunk():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("not_found") > 0:
|
if str(e).find("not_found") > 0:
|
||||||
return get_json_result(data=False, message='No chunk found!',
|
return get_json_result(data=False, message='No chunk found!',
|
||||||
code=RetCode.DATA_ERROR)
|
code=settings.RetCode.DATA_ERROR)
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@ -93,7 +94,7 @@ def get():
|
|||||||
tenant_id = tenants[0].tenant_id
|
tenant_id = tenants[0].tenant_id
|
||||||
|
|
||||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||||
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
|
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
|
||||||
if chunk is None:
|
if chunk is None:
|
||||||
return server_error_response("Chunk not found")
|
return server_error_response("Chunk not found")
|
||||||
k = []
|
k = []
|
||||||
@ -107,7 +108,7 @@ def get():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("NotFoundError") >= 0:
|
if str(e).find("NotFoundError") >= 0:
|
||||||
return get_json_result(data=False, message='Chunk not found!',
|
return get_json_result(data=False, message='Chunk not found!',
|
||||||
code=RetCode.DATA_ERROR)
|
code=settings.RetCode.DATA_ERROR)
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@ -154,7 +155,7 @@ def set():
|
|||||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
||||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||||
d["q_%d_vec" % len(v)] = v.tolist()
|
d["q_%d_vec" % len(v)] = v.tolist()
|
||||||
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
@ -169,7 +170,7 @@ def switch():
|
|||||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Document not found!")
|
return get_data_error_result(message="Document not found!")
|
||||||
if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
|
if not settings.docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
|
||||||
search.index_name(doc.tenant_id), doc.kb_id):
|
search.index_name(doc.tenant_id), doc.kb_id):
|
||||||
return get_data_error_result(message="Index updating failure")
|
return get_data_error_result(message="Index updating failure")
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
@ -186,7 +187,7 @@ def rm():
|
|||||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Document not found!")
|
return get_data_error_result(message="Document not found!")
|
||||||
if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
|
if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
|
||||||
return get_data_error_result(message="Index updating failure")
|
return get_data_error_result(message="Index updating failure")
|
||||||
deleted_chunk_ids = req["chunk_ids"]
|
deleted_chunk_ids = req["chunk_ids"]
|
||||||
chunk_number = len(deleted_chunk_ids)
|
chunk_number = len(deleted_chunk_ids)
|
||||||
@ -230,7 +231,7 @@ def create():
|
|||||||
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
|
||||||
v = 0.1 * v[0] + 0.9 * v[1]
|
v = 0.1 * v[0] + 0.9 * v[1]
|
||||||
d["q_%d_vec" % len(v)] = v.tolist()
|
d["q_%d_vec" % len(v)] = v.tolist()
|
||||||
docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
|
||||||
|
|
||||||
DocumentService.increment_chunk_num(
|
DocumentService.increment_chunk_num(
|
||||||
doc.id, doc.kb_id, c, 1, 0)
|
doc.id, doc.kb_id, c, 1, 0)
|
||||||
@ -265,7 +266,7 @@ def retrieval_test():
|
|||||||
else:
|
else:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||||
code=RetCode.OPERATING_ERROR)
|
code=settings.RetCode.OPERATING_ERROR)
|
||||||
|
|
||||||
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
||||||
if not e:
|
if not e:
|
||||||
@ -281,7 +282,7 @@ def retrieval_test():
|
|||||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||||
question += keyword_extraction(chat_mdl, question)
|
question += keyword_extraction(chat_mdl, question)
|
||||||
|
|
||||||
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
|
||||||
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
|
ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
|
||||||
similarity_threshold, vector_similarity_weight, top,
|
similarity_threshold, vector_similarity_weight, top,
|
||||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
|
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
|
||||||
@ -293,7 +294,7 @@ def retrieval_test():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("not_found") > 0:
|
if str(e).find("not_found") > 0:
|
||||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
||||||
code=RetCode.DATA_ERROR)
|
code=settings.RetCode.DATA_ERROR)
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@ -307,7 +308,7 @@ def knowledge_graph():
|
|||||||
"doc_ids": [doc_id],
|
"doc_ids": [doc_id],
|
||||||
"knowledge_graph_kwd": ["graph", "mind_map"]
|
"knowledge_graph_kwd": ["graph", "mind_map"]
|
||||||
}
|
}
|
||||||
sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids)
|
sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
|
||||||
obj = {"graph": {}, "mind_map": {}}
|
obj = {"graph": {}, "mind_map": {}}
|
||||||
for id in sres.ids[:2]:
|
for id in sres.ids[:2]:
|
||||||
ty = sres.field[id]["knowledge_graph_kwd"]
|
ty = sres.field[id]["knowledge_graph_kwd"]
|
||||||
@ -336,4 +337,3 @@ def knowledge_graph():
|
|||||||
obj[ty] = content_json
|
obj[ty] = content_json
|
||||||
|
|
||||||
return get_json_result(data=obj)
|
return get_json_result(data=obj)
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ from api.db import LLMType
|
|||||||
from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
|
from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
|
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
|
||||||
from api.settings import RetCode, retrievaler
|
from api import settings
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||||
from graphrag.mind_map_extractor import MindMapExtractor
|
from graphrag.mind_map_extractor import MindMapExtractor
|
||||||
@ -87,7 +87,7 @@ def get():
|
|||||||
else:
|
else:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Only owner of conversation authorized for this operation.',
|
data=False, message='Only owner of conversation authorized for this operation.',
|
||||||
code=RetCode.OPERATING_ERROR)
|
code=settings.RetCode.OPERATING_ERROR)
|
||||||
conv = conv.to_dict()
|
conv = conv.to_dict()
|
||||||
return get_json_result(data=conv)
|
return get_json_result(data=conv)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -110,7 +110,7 @@ def rm():
|
|||||||
else:
|
else:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Only owner of conversation authorized for this operation.',
|
data=False, message='Only owner of conversation authorized for this operation.',
|
||||||
code=RetCode.OPERATING_ERROR)
|
code=settings.RetCode.OPERATING_ERROR)
|
||||||
ConversationService.delete_by_id(cid)
|
ConversationService.delete_by_id(cid)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -125,7 +125,7 @@ def list_convsersation():
|
|||||||
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Only owner of dialog authorized for this operation.',
|
data=False, message='Only owner of dialog authorized for this operation.',
|
||||||
code=RetCode.OPERATING_ERROR)
|
code=settings.RetCode.OPERATING_ERROR)
|
||||||
convs = ConversationService.query(
|
convs = ConversationService.query(
|
||||||
dialog_id=dialog_id,
|
dialog_id=dialog_id,
|
||||||
order_by=ConversationService.model.create_time,
|
order_by=ConversationService.model.create_time,
|
||||||
@ -297,6 +297,7 @@ def thumbup():
|
|||||||
def ask_about():
|
def ask_about():
|
||||||
req = request.json
|
req = request.json
|
||||||
uid = current_user.id
|
uid = current_user.id
|
||||||
|
|
||||||
def stream():
|
def stream():
|
||||||
nonlocal req, uid
|
nonlocal req, uid
|
||||||
try:
|
try:
|
||||||
@ -329,7 +330,7 @@ def mindmap():
|
|||||||
embd_mdl = TenantLLMService.model_instance(
|
embd_mdl = TenantLLMService.model_instance(
|
||||||
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||||
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
|
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
|
||||||
ranks = retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
|
ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
|
||||||
0.3, 0.3, aggs=False)
|
0.3, 0.3, aggs=False)
|
||||||
mindmap = MindMapExtractor(chat_mdl)
|
mindmap = MindMapExtractor(chat_mdl)
|
||||||
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
|
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
|
||||||
|
@ -20,7 +20,7 @@ from api.db.services.dialog_service import DialogService
|
|||||||
from api.db import StatusEnum
|
from api.db import StatusEnum
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from api.settings import RetCode
|
from api import settings
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
@ -175,7 +175,7 @@ def rm():
|
|||||||
else:
|
else:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Only owner of dialog authorized for this operation.',
|
data=False, message='Only owner of dialog authorized for this operation.',
|
||||||
code=RetCode.OPERATING_ERROR)
|
code=settings.RetCode.OPERATING_ERROR)
|
||||||
dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
|
dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
|
||||||
DialogService.update_many_by_id(dialog_list)
|
DialogService.update_many_by_id(dialog_list)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
|
@ -34,7 +34,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
|
|||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.db import FileType, TaskStatus, ParserType, FileSource
|
from api.db import FileType, TaskStatus, ParserType, FileSource
|
||||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||||
from api.settings import RetCode, docStoreConn
|
from api import settings
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
from rag.utils.storage_factory import STORAGE_IMPL
|
from rag.utils.storage_factory import STORAGE_IMPL
|
||||||
from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
|
from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
|
||||||
@ -49,16 +49,16 @@ def upload():
|
|||||||
kb_id = request.form.get("kb_id")
|
kb_id = request.form.get("kb_id")
|
||||||
if not kb_id:
|
if not kb_id:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
if 'file' not in request.files:
|
if 'file' not in request.files:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
file_objs = request.files.getlist('file')
|
file_objs = request.files.getlist('file')
|
||||||
for file_obj in file_objs:
|
for file_obj in file_objs:
|
||||||
if file_obj.filename == '':
|
if file_obj.filename == '':
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
if not e:
|
if not e:
|
||||||
@ -67,7 +67,7 @@ def upload():
|
|||||||
err, _ = FileService.upload_document(kb, file_objs, current_user.id)
|
err, _ = FileService.upload_document(kb, file_objs, current_user.id)
|
||||||
if err:
|
if err:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message="\n".join(err), code=RetCode.SERVER_ERROR)
|
data=False, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
|
||||||
@ -78,12 +78,12 @@ def web_crawl():
|
|||||||
kb_id = request.form.get("kb_id")
|
kb_id = request.form.get("kb_id")
|
||||||
if not kb_id:
|
if not kb_id:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
name = request.form.get("name")
|
name = request.form.get("name")
|
||||||
url = request.form.get("url")
|
url = request.form.get("url")
|
||||||
if not is_valid_url(url):
|
if not is_valid_url(url):
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
if not e:
|
if not e:
|
||||||
raise LookupError("Can't find this knowledgebase!")
|
raise LookupError("Can't find this knowledgebase!")
|
||||||
@ -145,7 +145,7 @@ def create():
|
|||||||
kb_id = req["kb_id"]
|
kb_id = req["kb_id"]
|
||||||
if not kb_id:
|
if not kb_id:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
@ -179,7 +179,7 @@ def list_docs():
|
|||||||
kb_id = request.args.get("kb_id")
|
kb_id = request.args.get("kb_id")
|
||||||
if not kb_id:
|
if not kb_id:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
tenants = UserTenantService.query(user_id=current_user.id)
|
tenants = UserTenantService.query(user_id=current_user.id)
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
if KnowledgebaseService.query(
|
if KnowledgebaseService.query(
|
||||||
@ -188,7 +188,7 @@ def list_docs():
|
|||||||
else:
|
else:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||||
code=RetCode.OPERATING_ERROR)
|
code=settings.RetCode.OPERATING_ERROR)
|
||||||
keywords = request.args.get("keywords", "")
|
keywords = request.args.get("keywords", "")
|
||||||
|
|
||||||
page_number = int(request.args.get("page", 1))
|
page_number = int(request.args.get("page", 1))
|
||||||
@ -218,7 +218,7 @@ def docinfos():
|
|||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
message='No authorization.',
|
message='No authorization.',
|
||||||
code=RetCode.AUTHENTICATION_ERROR
|
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||||
)
|
)
|
||||||
docs = DocumentService.get_by_ids(doc_ids)
|
docs = DocumentService.get_by_ids(doc_ids)
|
||||||
return get_json_result(data=list(docs.dicts()))
|
return get_json_result(data=list(docs.dicts()))
|
||||||
@ -230,7 +230,7 @@ def thumbnails():
|
|||||||
doc_ids = request.args.get("doc_ids").split(",")
|
doc_ids = request.args.get("doc_ids").split(",")
|
||||||
if not doc_ids:
|
if not doc_ids:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
docs = DocumentService.get_thumbnails(doc_ids)
|
docs = DocumentService.get_thumbnails(doc_ids)
|
||||||
@ -253,13 +253,13 @@ def change_status():
|
|||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
message='"Status" must be either 0 or 1!',
|
message='"Status" must be either 0 or 1!',
|
||||||
code=RetCode.ARGUMENT_ERROR)
|
code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
message='No authorization.',
|
message='No authorization.',
|
||||||
code=RetCode.AUTHENTICATION_ERROR)
|
code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
@ -276,7 +276,8 @@ def change_status():
|
|||||||
message="Database error (Document update)!")
|
message="Database error (Document update)!")
|
||||||
|
|
||||||
status = int(req["status"])
|
status = int(req["status"])
|
||||||
docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
|
settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status},
|
||||||
|
search.index_name(kb.tenant_id), doc.kb_id)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
@ -295,7 +296,7 @@ def rm():
|
|||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
message='No authorization.',
|
message='No authorization.',
|
||||||
code=RetCode.AUTHENTICATION_ERROR
|
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
root_folder = FileService.get_root_folder(current_user.id)
|
root_folder = FileService.get_root_folder(current_user.id)
|
||||||
@ -326,7 +327,7 @@ def rm():
|
|||||||
errors += str(e)
|
errors += str(e)
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
|
||||||
|
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
|
|
||||||
@ -341,7 +342,7 @@ def run():
|
|||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
message='No authorization.',
|
message='No authorization.',
|
||||||
code=RetCode.AUTHENTICATION_ERROR
|
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
for id in req["doc_ids"]:
|
for id in req["doc_ids"]:
|
||||||
@ -358,8 +359,8 @@ def run():
|
|||||||
e, doc = DocumentService.get_by_id(id)
|
e, doc = DocumentService.get_by_id(id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(message="Document not found!")
|
return get_data_error_result(message="Document not found!")
|
||||||
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||||
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
||||||
|
|
||||||
if str(req["run"]) == TaskStatus.RUNNING.value:
|
if str(req["run"]) == TaskStatus.RUNNING.value:
|
||||||
TaskService.filter_delete([Task.doc_id == id])
|
TaskService.filter_delete([Task.doc_id == id])
|
||||||
@ -383,7 +384,7 @@ def rename():
|
|||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
message='No authorization.',
|
message='No authorization.',
|
||||||
code=RetCode.AUTHENTICATION_ERROR
|
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
@ -394,7 +395,7 @@ def rename():
|
|||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
message="The extension of file can't be changed",
|
message="The extension of file can't be changed",
|
||||||
code=RetCode.ARGUMENT_ERROR)
|
code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
||||||
if d.name == req["name"]:
|
if d.name == req["name"]:
|
||||||
return get_data_error_result(
|
return get_data_error_result(
|
||||||
@ -450,7 +451,7 @@ def change_parser():
|
|||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
message='No authorization.',
|
message='No authorization.',
|
||||||
code=RetCode.AUTHENTICATION_ERROR
|
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||||
@ -483,8 +484,8 @@ def change_parser():
|
|||||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
return get_data_error_result(message="Tenant not found!")
|
return get_data_error_result(message="Tenant not found!")
|
||||||
if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||||
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||||
|
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -509,13 +510,13 @@ def get_image(image_id):
|
|||||||
def upload_and_parse():
|
def upload_and_parse():
|
||||||
if 'file' not in request.files:
|
if 'file' not in request.files:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
file_objs = request.files.getlist('file')
|
file_objs = request.files.getlist('file')
|
||||||
for file_obj in file_objs:
|
for file_obj in file_objs:
|
||||||
if file_obj.filename == '':
|
if file_obj.filename == '':
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
|
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
|
||||||
|
|
||||||
@ -529,7 +530,7 @@ def parse():
|
|||||||
if url:
|
if url:
|
||||||
if not is_valid_url(url):
|
if not is_valid_url(url):
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
download_path = os.path.join(get_project_base_directory(), "logs/downloads")
|
download_path = os.path.join(get_project_base_directory(), "logs/downloads")
|
||||||
os.makedirs(download_path, exist_ok=True)
|
os.makedirs(download_path, exist_ok=True)
|
||||||
from selenium.webdriver import Chrome, ChromeOptions
|
from selenium.webdriver import Chrome, ChromeOptions
|
||||||
@ -553,7 +554,7 @@ def parse():
|
|||||||
|
|
||||||
if 'file' not in request.files:
|
if 'file' not in request.files:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
|
|
||||||
file_objs = request.files.getlist('file')
|
file_objs = request.files.getlist('file')
|
||||||
txt = FileService.parse_docs(file_objs, current_user.id)
|
txt = FileService.parse_docs(file_objs, current_user.id)
|
||||||
|
@ -24,7 +24,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
|
|||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.db import FileType
|
from api.db import FileType
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.settings import RetCode
|
from api import settings
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
|
|
||||||
|
|
||||||
@ -100,7 +100,7 @@ def rm():
|
|||||||
file_ids = req["file_ids"]
|
file_ids = req["file_ids"]
|
||||||
if not file_ids:
|
if not file_ids:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Lack of "Files ID"', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='Lack of "Files ID"', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
try:
|
try:
|
||||||
for file_id in file_ids:
|
for file_id in file_ids:
|
||||||
informs = File2DocumentService.get_by_file_id(file_id)
|
informs = File2DocumentService.get_by_file_id(file_id)
|
||||||
|
@ -28,7 +28,7 @@ from api.utils import get_uuid
|
|||||||
from api.db import FileType, FileSource
|
from api.db import FileType, FileSource
|
||||||
from api.db.services import duplicate_name
|
from api.db.services import duplicate_name
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
from api.settings import RetCode
|
from api import settings
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
from api.utils.file_utils import filename_type
|
from api.utils.file_utils import filename_type
|
||||||
from rag.utils.storage_factory import STORAGE_IMPL
|
from rag.utils.storage_factory import STORAGE_IMPL
|
||||||
@ -46,13 +46,13 @@ def upload():
|
|||||||
|
|
||||||
if 'file' not in request.files:
|
if 'file' not in request.files:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
file_objs = request.files.getlist('file')
|
file_objs = request.files.getlist('file')
|
||||||
|
|
||||||
for file_obj in file_objs:
|
for file_obj in file_objs:
|
||||||
if file_obj.filename == '':
|
if file_obj.filename == '':
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
file_res = []
|
file_res = []
|
||||||
try:
|
try:
|
||||||
for file_obj in file_objs:
|
for file_obj in file_objs:
|
||||||
@ -134,7 +134,7 @@ def create():
|
|||||||
try:
|
try:
|
||||||
if not FileService.is_parent_folder_exist(pf_id):
|
if not FileService.is_parent_folder_exist(pf_id):
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message="Parent Folder Doesn't Exist!", code=RetCode.OPERATING_ERROR)
|
data=False, message="Parent Folder Doesn't Exist!", code=settings.RetCode.OPERATING_ERROR)
|
||||||
if FileService.query(name=req["name"], parent_id=pf_id):
|
if FileService.query(name=req["name"], parent_id=pf_id):
|
||||||
return get_data_error_result(
|
return get_data_error_result(
|
||||||
message="Duplicated folder name in the same folder.")
|
message="Duplicated folder name in the same folder.")
|
||||||
@ -299,7 +299,7 @@ def rename():
|
|||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
message="The extension of file can't be changed",
|
message="The extension of file can't be changed",
|
||||||
code=RetCode.ARGUMENT_ERROR)
|
code=settings.RetCode.ARGUMENT_ERROR)
|
||||||
for file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
for file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
||||||
if file.name == req["name"]:
|
if file.name == req["name"]:
|
||||||
return get_data_error_result(
|
return get_data_error_result(
|
||||||
|
@ -26,9 +26,8 @@ from api.utils import get_uuid
|
|||||||
from api.db import StatusEnum, FileSource
|
from api.db import StatusEnum, FileSource
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.db_models import File
|
from api.db.db_models import File
|
||||||
from api.settings import RetCode
|
|
||||||
from api.utils.api_utils import get_json_result
|
from api.utils.api_utils import get_json_result
|
||||||
from api.settings import docStoreConn
|
from api import settings
|
||||||
from rag.nlp import search
|
from rag.nlp import search
|
||||||
|
|
||||||
|
|
||||||
@ -68,13 +67,13 @@ def update():
|
|||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
message='No authorization.',
|
message='No authorization.',
|
||||||
code=RetCode.AUTHENTICATION_ERROR
|
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
if not KnowledgebaseService.query(
|
if not KnowledgebaseService.query(
|
||||||
created_by=current_user.id, id=req["kb_id"]):
|
created_by=current_user.id, id=req["kb_id"]):
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
|
data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
|
||||||
|
|
||||||
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
|
e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
|
||||||
if not e:
|
if not e:
|
||||||
@ -113,7 +112,7 @@ def detail():
|
|||||||
else:
|
else:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
||||||
code=RetCode.OPERATING_ERROR)
|
code=settings.RetCode.OPERATING_ERROR)
|
||||||
kb = KnowledgebaseService.get_detail(kb_id)
|
kb = KnowledgebaseService.get_detail(kb_id)
|
||||||
if not kb:
|
if not kb:
|
||||||
return get_data_error_result(
|
return get_data_error_result(
|
||||||
@ -148,14 +147,14 @@ def rm():
|
|||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
message='No authorization.',
|
message='No authorization.',
|
||||||
code=RetCode.AUTHENTICATION_ERROR
|
code=settings.RetCode.AUTHENTICATION_ERROR
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
kbs = KnowledgebaseService.query(
|
kbs = KnowledgebaseService.query(
|
||||||
created_by=current_user.id, id=req["kb_id"])
|
created_by=current_user.id, id=req["kb_id"])
|
||||||
if not kbs:
|
if not kbs:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
|
data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
|
||||||
|
|
||||||
for doc in DocumentService.query(kb_id=req["kb_id"]):
|
for doc in DocumentService.query(kb_id=req["kb_id"]):
|
||||||
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
if not DocumentService.remove_document(doc, kbs[0].tenant_id):
|
||||||
@ -170,7 +169,7 @@ def rm():
|
|||||||
message="Database error (Knowledgebase removal)!")
|
message="Database error (Knowledgebase removal)!")
|
||||||
tenants = UserTenantService.query(user_id=current_user.id)
|
tenants = UserTenantService.query(user_id=current_user.id)
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
|
settings.docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
@ -19,7 +19,7 @@ import json
|
|||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
||||||
from api.settings import LIGHTEN
|
from api import settings
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||||
from api.db import StatusEnum, LLMType
|
from api.db import StatusEnum, LLMType
|
||||||
from api.db.db_models import TenantLLM
|
from api.db.db_models import TenantLLM
|
||||||
@ -333,7 +333,7 @@ def my_llms():
|
|||||||
@login_required
|
@login_required
|
||||||
def list_app():
|
def list_app():
|
||||||
self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
|
self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
|
||||||
weighted = ["Youdao","FastEmbed", "BAAI"] if LIGHTEN != 0 else []
|
weighted = ["Youdao","FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
|
||||||
model_type = request.args.get("model_type")
|
model_type = request.args.get("model_type")
|
||||||
try:
|
try:
|
||||||
objs = TenantLLMService.query(tenant_id=current_user.id)
|
objs = TenantLLMService.query(tenant_id=current_user.id)
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from flask import request
|
from flask import request
|
||||||
from api.settings import RetCode
|
from api import settings
|
||||||
from api.db import StatusEnum
|
from api.db import StatusEnum
|
||||||
from api.db.services.dialog_service import DialogService
|
from api.db.services.dialog_service import DialogService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
@ -44,7 +44,7 @@ def create(tenant_id):
|
|||||||
kbs = KnowledgebaseService.get_by_ids(ids)
|
kbs = KnowledgebaseService.get_by_ids(ids)
|
||||||
embd_count = list(set([kb.embd_id for kb in kbs]))
|
embd_count = list(set([kb.embd_id for kb in kbs]))
|
||||||
if len(embd_count) != 1:
|
if len(embd_count) != 1:
|
||||||
return get_result(message='Datasets use different embedding models."',code=RetCode.AUTHENTICATION_ERROR)
|
return get_result(message='Datasets use different embedding models."',code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
req["kb_ids"] = ids
|
req["kb_ids"] = ids
|
||||||
# llm
|
# llm
|
||||||
llm = req.get("llm")
|
llm = req.get("llm")
|
||||||
@ -173,7 +173,7 @@ def update(tenant_id,chat_id):
|
|||||||
if len(embd_count) != 1 :
|
if len(embd_count) != 1 :
|
||||||
return get_result(
|
return get_result(
|
||||||
message='Datasets use different embedding models."',
|
message='Datasets use different embedding models."',
|
||||||
code=RetCode.AUTHENTICATION_ERROR)
|
code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||||
req["kb_ids"] = ids
|
req["kb_ids"] = ids
|
||||||
llm = req.get("llm")
|
llm = req.get("llm")
|
||||||
if llm:
|
if llm:
|
||||||
|
@ -23,7 +23,7 @@ from api.db.services.file_service import FileService
|
|||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import TenantLLMService, LLMService
|
from api.db.services.llm_service import TenantLLMService, LLMService
|
||||||
from api.db.services.user_service import TenantService
|
from api.db.services.user_service import TenantService
|
||||||
from api.settings import RetCode
|
from api import settings
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.api_utils import (
|
from api.utils.api_utils import (
|
||||||
get_result,
|
get_result,
|
||||||
@ -255,7 +255,7 @@ def delete(tenant_id):
|
|||||||
File2DocumentService.delete_by_document_id(doc.id)
|
File2DocumentService.delete_by_document_id(doc.id)
|
||||||
if not KnowledgebaseService.delete_by_id(id):
|
if not KnowledgebaseService.delete_by_id(id):
|
||||||
return get_error_data_result(message="Delete dataset error.(Database error)")
|
return get_error_data_result(message="Delete dataset error.(Database error)")
|
||||||
return get_result(code=RetCode.SUCCESS)
|
return get_result(code=settings.RetCode.SUCCESS)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/datasets/<dataset_id>", methods=["PUT"])
|
@manager.route("/datasets/<dataset_id>", methods=["PUT"])
|
||||||
@ -424,7 +424,7 @@ def update(tenant_id, dataset_id):
|
|||||||
)
|
)
|
||||||
if not KnowledgebaseService.update_by_id(kb.id, req):
|
if not KnowledgebaseService.update_by_id(kb.id, req):
|
||||||
return get_error_data_result(message="Update dataset error.(Database error)")
|
return get_error_data_result(message="Update dataset error.(Database error)")
|
||||||
return get_result(code=RetCode.SUCCESS)
|
return get_result(code=settings.RetCode.SUCCESS)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/datasets", methods=["GET"])
|
@manager.route("/datasets", methods=["GET"])
|
||||||
|
@ -18,7 +18,7 @@ from flask import request, jsonify
|
|||||||
from api.db import LLMType, ParserType
|
from api.db import LLMType, ParserType
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.settings import retrievaler, kg_retrievaler, RetCode
|
from api import settings
|
||||||
from api.utils.api_utils import validate_request, build_error_result, apikey_required
|
from api.utils.api_utils import validate_request, build_error_result, apikey_required
|
||||||
|
|
||||||
|
|
||||||
@ -37,14 +37,14 @@ def retrieval(tenant_id):
|
|||||||
|
|
||||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
if not e:
|
if not e:
|
||||||
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
|
return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
|
||||||
|
|
||||||
if kb.tenant_id != tenant_id:
|
if kb.tenant_id != tenant_id:
|
||||||
return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
|
return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
|
||||||
|
|
||||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||||
|
|
||||||
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
|
||||||
ranks = retr.retrieval(
|
ranks = retr.retrieval(
|
||||||
question,
|
question,
|
||||||
embd_mdl,
|
embd_mdl,
|
||||||
@ -72,6 +72,6 @@ def retrieval(tenant_id):
|
|||||||
if str(e).find("not_found") > 0:
|
if str(e).find("not_found") > 0:
|
||||||
return build_error_result(
|
return build_error_result(
|
||||||
message='No chunk found! Check the chunk status please!',
|
message='No chunk found! Check the chunk status please!',
|
||||||
code=RetCode.NOT_FOUND
|
code=settings.RetCode.NOT_FOUND
|
||||||
)
|
)
|
||||||
return build_error_result(message=str(e), code=RetCode.SERVER_ERROR)
|
return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)
|
||||||
|
@ -21,7 +21,7 @@ from rag.app.qa import rmPrefix, beAdoc
|
|||||||
from rag.nlp import rag_tokenizer
|
from rag.nlp import rag_tokenizer
|
||||||
from api.db import LLMType, ParserType
|
from api.db import LLMType, ParserType
|
||||||
from api.db.services.llm_service import TenantLLMService
|
from api.db.services.llm_service import TenantLLMService
|
||||||
from api.settings import kg_retrievaler
|
from api import settings
|
||||||
import hashlib
|
import hashlib
|
||||||
import re
|
import re
|
||||||
from api.utils.api_utils import token_required
|
from api.utils.api_utils import token_required
|
||||||
@ -37,11 +37,10 @@ from api.db.services.document_service import DocumentService
|
|||||||
from api.db.services.file2document_service import File2DocumentService
|
from api.db.services.file2document_service import File2DocumentService
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.settings import RetCode, retrievaler
|
from api import settings
|
||||||
from api.utils.api_utils import construct_json_result, get_parser_config
|
from api.utils.api_utils import construct_json_result, get_parser_config
|
||||||
from rag.nlp import search
|
from rag.nlp import search
|
||||||
from rag.utils import rmSpace
|
from rag.utils import rmSpace
|
||||||
from api.settings import docStoreConn
|
|
||||||
from rag.utils.storage_factory import STORAGE_IMPL
|
from rag.utils.storage_factory import STORAGE_IMPL
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@ -109,13 +108,13 @@ def upload(dataset_id, tenant_id):
|
|||||||
"""
|
"""
|
||||||
if "file" not in request.files:
|
if "file" not in request.files:
|
||||||
return get_error_data_result(
|
return get_error_data_result(
|
||||||
message="No file part!", code=RetCode.ARGUMENT_ERROR
|
message="No file part!", code=settings.RetCode.ARGUMENT_ERROR
|
||||||
)
|
)
|
||||||
file_objs = request.files.getlist("file")
|
file_objs = request.files.getlist("file")
|
||||||
for file_obj in file_objs:
|
for file_obj in file_objs:
|
||||||
if file_obj.filename == "":
|
if file_obj.filename == "":
|
||||||
return get_result(
|
return get_result(
|
||||||
message="No file selected!", code=RetCode.ARGUMENT_ERROR
|
message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR
|
||||||
)
|
)
|
||||||
# total size
|
# total size
|
||||||
total_size = 0
|
total_size = 0
|
||||||
@ -127,14 +126,14 @@ def upload(dataset_id, tenant_id):
|
|||||||
if total_size > MAX_TOTAL_FILE_SIZE:
|
if total_size > MAX_TOTAL_FILE_SIZE:
|
||||||
return get_result(
|
return get_result(
|
||||||
message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)",
|
message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)",
|
||||||
code=RetCode.ARGUMENT_ERROR,
|
code=settings.RetCode.ARGUMENT_ERROR,
|
||||||
)
|
)
|
||||||
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||||
if not e:
|
if not e:
|
||||||
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
|
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
|
||||||
err, files = FileService.upload_document(kb, file_objs, tenant_id)
|
err, files = FileService.upload_document(kb, file_objs, tenant_id)
|
||||||
if err:
|
if err:
|
||||||
return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
|
return get_result(message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
|
||||||
# rename key's name
|
# rename key's name
|
||||||
renamed_doc_list = []
|
renamed_doc_list = []
|
||||||
for file in files:
|
for file in files:
|
||||||
@ -226,7 +225,7 @@ def update_doc(tenant_id, dataset_id, document_id):
|
|||||||
):
|
):
|
||||||
return get_result(
|
return get_result(
|
||||||
message="The extension of file can't be changed",
|
message="The extension of file can't be changed",
|
||||||
code=RetCode.ARGUMENT_ERROR,
|
code=settings.RetCode.ARGUMENT_ERROR,
|
||||||
)
|
)
|
||||||
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
|
||||||
if d.name == req["name"]:
|
if d.name == req["name"]:
|
||||||
@ -292,7 +291,7 @@ def update_doc(tenant_id, dataset_id, document_id):
|
|||||||
)
|
)
|
||||||
if not e:
|
if not e:
|
||||||
return get_error_data_result(message="Document not found!")
|
return get_error_data_result(message="Document not found!")
|
||||||
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
||||||
|
|
||||||
return get_result()
|
return get_result()
|
||||||
|
|
||||||
@ -349,7 +348,7 @@ def download(tenant_id, dataset_id, document_id):
|
|||||||
file_stream = STORAGE_IMPL.get(doc_id, doc_location)
|
file_stream = STORAGE_IMPL.get(doc_id, doc_location)
|
||||||
if not file_stream:
|
if not file_stream:
|
||||||
return construct_json_result(
|
return construct_json_result(
|
||||||
message="This file is empty.", code=RetCode.DATA_ERROR
|
message="This file is empty.", code=settings.RetCode.DATA_ERROR
|
||||||
)
|
)
|
||||||
file = BytesIO(file_stream)
|
file = BytesIO(file_stream)
|
||||||
# Use send_file with a proper filename and MIME type
|
# Use send_file with a proper filename and MIME type
|
||||||
@ -582,7 +581,7 @@ def delete(tenant_id, dataset_id):
|
|||||||
errors += str(e)
|
errors += str(e)
|
||||||
|
|
||||||
if errors:
|
if errors:
|
||||||
return get_result(message=errors, code=RetCode.SERVER_ERROR)
|
return get_result(message=errors, code=settings.RetCode.SERVER_ERROR)
|
||||||
|
|
||||||
return get_result()
|
return get_result()
|
||||||
|
|
||||||
@ -644,7 +643,7 @@ def parse(tenant_id, dataset_id):
|
|||||||
info["chunk_num"] = 0
|
info["chunk_num"] = 0
|
||||||
info["token_num"] = 0
|
info["token_num"] = 0
|
||||||
DocumentService.update_by_id(id, info)
|
DocumentService.update_by_id(id, info)
|
||||||
docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
|
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
|
||||||
TaskService.filter_delete([Task.doc_id == id])
|
TaskService.filter_delete([Task.doc_id == id])
|
||||||
e, doc = DocumentService.get_by_id(id)
|
e, doc = DocumentService.get_by_id(id)
|
||||||
doc = doc.to_dict()
|
doc = doc.to_dict()
|
||||||
@ -708,7 +707,7 @@ def stop_parsing(tenant_id, dataset_id):
|
|||||||
)
|
)
|
||||||
info = {"run": "2", "progress": 0, "chunk_num": 0}
|
info = {"run": "2", "progress": 0, "chunk_num": 0}
|
||||||
DocumentService.update_by_id(id, info)
|
DocumentService.update_by_id(id, info)
|
||||||
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
|
||||||
return get_result()
|
return get_result()
|
||||||
|
|
||||||
|
|
||||||
@ -828,8 +827,9 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
|||||||
|
|
||||||
res = {"total": 0, "chunks": [], "doc": renamed_doc}
|
res = {"total": 0, "chunks": [], "doc": renamed_doc}
|
||||||
origin_chunks = []
|
origin_chunks = []
|
||||||
if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
|
if settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
|
||||||
sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
|
sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None,
|
||||||
|
highlight=True)
|
||||||
res["total"] = sres.total
|
res["total"] = sres.total
|
||||||
sign = 0
|
sign = 0
|
||||||
for id in sres.ids:
|
for id in sres.ids:
|
||||||
@ -1003,7 +1003,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
|
|||||||
v, c = embd_mdl.encode([doc.name, req["content"]])
|
v, c = embd_mdl.encode([doc.name, req["content"]])
|
||||||
v = 0.1 * v[0] + 0.9 * v[1]
|
v = 0.1 * v[0] + 0.9 * v[1]
|
||||||
d["q_%d_vec" % len(v)] = v.tolist()
|
d["q_%d_vec" % len(v)] = v.tolist()
|
||||||
docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
|
settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
|
||||||
|
|
||||||
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
|
DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
|
||||||
# rename keys
|
# rename keys
|
||||||
@ -1078,7 +1078,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
|
|||||||
condition = {"doc_id": document_id}
|
condition = {"doc_id": document_id}
|
||||||
if "chunk_ids" in req:
|
if "chunk_ids" in req:
|
||||||
condition["id"] = req["chunk_ids"]
|
condition["id"] = req["chunk_ids"]
|
||||||
chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
|
chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
|
||||||
if chunk_number != 0:
|
if chunk_number != 0:
|
||||||
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
|
DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
|
||||||
if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
|
if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
|
||||||
@ -1143,7 +1143,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
|||||||
schema:
|
schema:
|
||||||
type: object
|
type: object
|
||||||
"""
|
"""
|
||||||
chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
|
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
|
||||||
if chunk is None:
|
if chunk is None:
|
||||||
return get_error_data_result(f"Can't find this chunk {chunk_id}")
|
return get_error_data_result(f"Can't find this chunk {chunk_id}")
|
||||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||||
@ -1187,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
|
|||||||
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
|
v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
|
||||||
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
|
||||||
d["q_%d_vec" % len(v)] = v.tolist()
|
d["q_%d_vec" % len(v)] = v.tolist()
|
||||||
docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
|
settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
|
||||||
return get_result()
|
return get_result()
|
||||||
|
|
||||||
|
|
||||||
@ -1285,7 +1285,7 @@ def retrieval_test(tenant_id):
|
|||||||
if len(embd_nms) != 1:
|
if len(embd_nms) != 1:
|
||||||
return get_result(
|
return get_result(
|
||||||
message='Datasets use different embedding models."',
|
message='Datasets use different embedding models."',
|
||||||
code=RetCode.AUTHENTICATION_ERROR,
|
code=settings.RetCode.AUTHENTICATION_ERROR,
|
||||||
)
|
)
|
||||||
if "question" not in req:
|
if "question" not in req:
|
||||||
return get_error_data_result("`question` is required.")
|
return get_error_data_result("`question` is required.")
|
||||||
@ -1326,7 +1326,7 @@ def retrieval_test(tenant_id):
|
|||||||
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
|
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
|
||||||
question += keyword_extraction(chat_mdl, question)
|
question += keyword_extraction(chat_mdl, question)
|
||||||
|
|
||||||
retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
|
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
|
||||||
ranks = retr.retrieval(
|
ranks = retr.retrieval(
|
||||||
question,
|
question,
|
||||||
embd_mdl,
|
embd_mdl,
|
||||||
@ -1366,6 +1366,6 @@ def retrieval_test(tenant_id):
|
|||||||
if str(e).find("not_found") > 0:
|
if str(e).find("not_found") > 0:
|
||||||
return get_result(
|
return get_result(
|
||||||
message="No chunk found! Check the chunk status please!",
|
message="No chunk found! Check the chunk status please!",
|
||||||
code=RetCode.DATA_ERROR,
|
code=settings.RetCode.DATA_ERROR,
|
||||||
)
|
)
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
@ -22,7 +22,7 @@ from api.db.db_models import APIToken
|
|||||||
from api.db.services.api_service import APITokenService
|
from api.db.services.api_service import APITokenService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.user_service import UserTenantService
|
from api.db.services.user_service import UserTenantService
|
||||||
from api.settings import DATABASE_TYPE
|
from api import settings
|
||||||
from api.utils import current_timestamp, datetime_format
|
from api.utils import current_timestamp, datetime_format
|
||||||
from api.utils.api_utils import (
|
from api.utils.api_utils import (
|
||||||
get_json_result,
|
get_json_result,
|
||||||
@ -31,7 +31,6 @@ from api.utils.api_utils import (
|
|||||||
generate_confirmation_token,
|
generate_confirmation_token,
|
||||||
)
|
)
|
||||||
from api.versions import get_ragflow_version
|
from api.versions import get_ragflow_version
|
||||||
from api.settings import docStoreConn
|
|
||||||
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
|
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
|
|
||||||
@ -98,7 +97,7 @@ def status():
|
|||||||
res = {}
|
res = {}
|
||||||
st = timer()
|
st = timer()
|
||||||
try:
|
try:
|
||||||
res["doc_store"] = docStoreConn.health()
|
res["doc_store"] = settings.docStoreConn.health()
|
||||||
res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
|
res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
res["doc_store"] = {
|
res["doc_store"] = {
|
||||||
@ -128,13 +127,13 @@ def status():
|
|||||||
try:
|
try:
|
||||||
KnowledgebaseService.get_by_id("x")
|
KnowledgebaseService.get_by_id("x")
|
||||||
res["database"] = {
|
res["database"] = {
|
||||||
"database": DATABASE_TYPE.lower(),
|
"database": settings.DATABASE_TYPE.lower(),
|
||||||
"status": "green",
|
"status": "green",
|
||||||
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
res["database"] = {
|
res["database"] = {
|
||||||
"database": DATABASE_TYPE.lower(),
|
"database": settings.DATABASE_TYPE.lower(),
|
||||||
"status": "red",
|
"status": "red",
|
||||||
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
@ -38,20 +38,7 @@ from api.utils import (
|
|||||||
datetime_format,
|
datetime_format,
|
||||||
)
|
)
|
||||||
from api.db import UserTenantRole, FileType
|
from api.db import UserTenantRole, FileType
|
||||||
from api.settings import (
|
from api import settings
|
||||||
RetCode,
|
|
||||||
GITHUB_OAUTH,
|
|
||||||
FEISHU_OAUTH,
|
|
||||||
CHAT_MDL,
|
|
||||||
EMBEDDING_MDL,
|
|
||||||
ASR_MDL,
|
|
||||||
IMAGE2TEXT_MDL,
|
|
||||||
PARSERS,
|
|
||||||
API_KEY,
|
|
||||||
LLM_FACTORY,
|
|
||||||
LLM_BASE_URL,
|
|
||||||
RERANK_MDL,
|
|
||||||
)
|
|
||||||
from api.db.services.user_service import UserService, TenantService, UserTenantService
|
from api.db.services.user_service import UserService, TenantService, UserTenantService
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
from api.utils.api_utils import get_json_result, construct_response
|
from api.utils.api_utils import get_json_result, construct_response
|
||||||
@ -90,7 +77,7 @@ def login():
|
|||||||
"""
|
"""
|
||||||
if not request.json:
|
if not request.json:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
|
data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
|
||||||
)
|
)
|
||||||
|
|
||||||
email = request.json.get("email", "")
|
email = request.json.get("email", "")
|
||||||
@ -98,7 +85,7 @@ def login():
|
|||||||
if not users:
|
if not users:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
code=RetCode.AUTHENTICATION_ERROR,
|
code=settings.RetCode.AUTHENTICATION_ERROR,
|
||||||
message=f"Email: {email} is not registered!",
|
message=f"Email: {email} is not registered!",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -107,7 +94,7 @@ def login():
|
|||||||
password = decrypt(password)
|
password = decrypt(password)
|
||||||
except BaseException:
|
except BaseException:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, code=RetCode.SERVER_ERROR, message="Fail to crypt password"
|
data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password"
|
||||||
)
|
)
|
||||||
|
|
||||||
user = UserService.query_user(email, password)
|
user = UserService.query_user(email, password)
|
||||||
@ -123,7 +110,7 @@ def login():
|
|||||||
else:
|
else:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
code=RetCode.AUTHENTICATION_ERROR,
|
code=settings.RetCode.AUTHENTICATION_ERROR,
|
||||||
message="Email and password do not match!",
|
message="Email and password do not match!",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -150,10 +137,10 @@ def github_callback():
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
res = requests.post(
|
res = requests.post(
|
||||||
GITHUB_OAUTH.get("url"),
|
settings.GITHUB_OAUTH.get("url"),
|
||||||
data={
|
data={
|
||||||
"client_id": GITHUB_OAUTH.get("client_id"),
|
"client_id": settings.GITHUB_OAUTH.get("client_id"),
|
||||||
"client_secret": GITHUB_OAUTH.get("secret_key"),
|
"client_secret": settings.GITHUB_OAUTH.get("secret_key"),
|
||||||
"code": request.args.get("code"),
|
"code": request.args.get("code"),
|
||||||
},
|
},
|
||||||
headers={"Accept": "application/json"},
|
headers={"Accept": "application/json"},
|
||||||
@ -235,11 +222,11 @@ def feishu_callback():
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
app_access_token_res = requests.post(
|
app_access_token_res = requests.post(
|
||||||
FEISHU_OAUTH.get("app_access_token_url"),
|
settings.FEISHU_OAUTH.get("app_access_token_url"),
|
||||||
data=json.dumps(
|
data=json.dumps(
|
||||||
{
|
{
|
||||||
"app_id": FEISHU_OAUTH.get("app_id"),
|
"app_id": settings.FEISHU_OAUTH.get("app_id"),
|
||||||
"app_secret": FEISHU_OAUTH.get("app_secret"),
|
"app_secret": settings.FEISHU_OAUTH.get("app_secret"),
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
headers={"Content-Type": "application/json; charset=utf-8"},
|
headers={"Content-Type": "application/json; charset=utf-8"},
|
||||||
@ -249,10 +236,10 @@ def feishu_callback():
|
|||||||
return redirect("/?error=%s" % app_access_token_res)
|
return redirect("/?error=%s" % app_access_token_res)
|
||||||
|
|
||||||
res = requests.post(
|
res = requests.post(
|
||||||
FEISHU_OAUTH.get("user_access_token_url"),
|
settings.FEISHU_OAUTH.get("user_access_token_url"),
|
||||||
data=json.dumps(
|
data=json.dumps(
|
||||||
{
|
{
|
||||||
"grant_type": FEISHU_OAUTH.get("grant_type"),
|
"grant_type": settings.FEISHU_OAUTH.get("grant_type"),
|
||||||
"code": request.args.get("code"),
|
"code": request.args.get("code"),
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
@ -409,7 +396,7 @@ def setting_user():
|
|||||||
):
|
):
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
code=RetCode.AUTHENTICATION_ERROR,
|
code=settings.RetCode.AUTHENTICATION_ERROR,
|
||||||
message="Password error!",
|
message="Password error!",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -438,7 +425,7 @@ def setting_user():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(e)
|
logging.exception(e)
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR
|
data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -497,12 +484,12 @@ def user_register(user_id, user):
|
|||||||
tenant = {
|
tenant = {
|
||||||
"id": user_id,
|
"id": user_id,
|
||||||
"name": user["nickname"] + "‘s Kingdom",
|
"name": user["nickname"] + "‘s Kingdom",
|
||||||
"llm_id": CHAT_MDL,
|
"llm_id": settings.CHAT_MDL,
|
||||||
"embd_id": EMBEDDING_MDL,
|
"embd_id": settings.EMBEDDING_MDL,
|
||||||
"asr_id": ASR_MDL,
|
"asr_id": settings.ASR_MDL,
|
||||||
"parser_ids": PARSERS,
|
"parser_ids": settings.PARSERS,
|
||||||
"img2txt_id": IMAGE2TEXT_MDL,
|
"img2txt_id": settings.IMAGE2TEXT_MDL,
|
||||||
"rerank_id": RERANK_MDL,
|
"rerank_id": settings.RERANK_MDL,
|
||||||
}
|
}
|
||||||
usr_tenant = {
|
usr_tenant = {
|
||||||
"tenant_id": user_id,
|
"tenant_id": user_id,
|
||||||
@ -522,15 +509,15 @@ def user_register(user_id, user):
|
|||||||
"location": "",
|
"location": "",
|
||||||
}
|
}
|
||||||
tenant_llm = []
|
tenant_llm = []
|
||||||
for llm in LLMService.query(fid=LLM_FACTORY):
|
for llm in LLMService.query(fid=settings.LLM_FACTORY):
|
||||||
tenant_llm.append(
|
tenant_llm.append(
|
||||||
{
|
{
|
||||||
"tenant_id": user_id,
|
"tenant_id": user_id,
|
||||||
"llm_factory": LLM_FACTORY,
|
"llm_factory": settings.LLM_FACTORY,
|
||||||
"llm_name": llm.llm_name,
|
"llm_name": llm.llm_name,
|
||||||
"model_type": llm.model_type,
|
"model_type": llm.model_type,
|
||||||
"api_key": API_KEY,
|
"api_key": settings.API_KEY,
|
||||||
"api_base": LLM_BASE_URL,
|
"api_base": settings.LLM_BASE_URL,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -582,7 +569,7 @@ def user_add():
|
|||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
message=f"Invalid email address: {email_address}!",
|
message=f"Invalid email address: {email_address}!",
|
||||||
code=RetCode.OPERATING_ERROR,
|
code=settings.RetCode.OPERATING_ERROR,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if the email address is already used
|
# Check if the email address is already used
|
||||||
@ -590,7 +577,7 @@ def user_add():
|
|||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
message=f"Email: {email_address} has already registered!",
|
message=f"Email: {email_address} has already registered!",
|
||||||
code=RetCode.OPERATING_ERROR,
|
code=settings.RetCode.OPERATING_ERROR,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Construct user info data
|
# Construct user info data
|
||||||
@ -625,7 +612,7 @@ def user_add():
|
|||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False,
|
data=False,
|
||||||
message=f"User registration failure, error: {str(e)}",
|
message=f"User registration failure, error: {str(e)}",
|
||||||
code=RetCode.EXCEPTION_ERROR,
|
code=settings.RetCode.EXCEPTION_ERROR,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ from peewee import (
|
|||||||
)
|
)
|
||||||
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
|
||||||
from api.db import SerializedType, ParserType
|
from api.db import SerializedType, ParserType
|
||||||
from api.settings import DATABASE, SECRET_KEY, DATABASE_TYPE
|
from api import settings
|
||||||
from api import utils
|
from api import utils
|
||||||
|
|
||||||
def singleton(cls, *args, **kw):
|
def singleton(cls, *args, **kw):
|
||||||
@ -62,7 +62,7 @@ class TextFieldType(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class LongTextField(TextField):
|
class LongTextField(TextField):
|
||||||
field_type = TextFieldType[DATABASE_TYPE.upper()].value
|
field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value
|
||||||
|
|
||||||
|
|
||||||
class JSONField(LongTextField):
|
class JSONField(LongTextField):
|
||||||
@ -282,9 +282,9 @@ class DatabaseMigrator(Enum):
|
|||||||
@singleton
|
@singleton
|
||||||
class BaseDataBase:
|
class BaseDataBase:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
database_config = DATABASE.copy()
|
database_config = settings.DATABASE.copy()
|
||||||
db_name = database_config.pop("name")
|
db_name = database_config.pop("name")
|
||||||
self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config)
|
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
|
||||||
logging.info('init database on cluster mode successfully')
|
logging.info('init database on cluster mode successfully')
|
||||||
|
|
||||||
class PostgresDatabaseLock:
|
class PostgresDatabaseLock:
|
||||||
@ -385,7 +385,7 @@ class DatabaseLock(Enum):
|
|||||||
|
|
||||||
|
|
||||||
DB = BaseDataBase().database_connection
|
DB = BaseDataBase().database_connection
|
||||||
DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value
|
DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value
|
||||||
|
|
||||||
|
|
||||||
def close_connection():
|
def close_connection():
|
||||||
@ -476,7 +476,7 @@ class User(DataBaseModel, UserMixin):
|
|||||||
return self.email
|
return self.email
|
||||||
|
|
||||||
def get_id(self):
|
def get_id(self):
|
||||||
jwt = Serializer(secret_key=SECRET_KEY)
|
jwt = Serializer(secret_key=settings.SECRET_KEY)
|
||||||
return jwt.dumps(str(self.access_token))
|
return jwt.dumps(str(self.access_token))
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
@ -977,7 +977,7 @@ class CanvasTemplate(DataBaseModel):
|
|||||||
|
|
||||||
def migrate_db():
|
def migrate_db():
|
||||||
with DB.transaction():
|
with DB.transaction():
|
||||||
migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB)
|
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
||||||
try:
|
try:
|
||||||
migrate(
|
migrate(
|
||||||
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
|
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
|
||||||
|
@ -29,7 +29,7 @@ from api.db.services.document_service import DocumentService
|
|||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
|
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
|
||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY, LLM_BASE_URL
|
from api import settings
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils.file_utils import get_project_base_directory
|
||||||
|
|
||||||
|
|
||||||
@ -51,11 +51,11 @@ def init_superuser():
|
|||||||
tenant = {
|
tenant = {
|
||||||
"id": user_info["id"],
|
"id": user_info["id"],
|
||||||
"name": user_info["nickname"] + "‘s Kingdom",
|
"name": user_info["nickname"] + "‘s Kingdom",
|
||||||
"llm_id": CHAT_MDL,
|
"llm_id": settings.CHAT_MDL,
|
||||||
"embd_id": EMBEDDING_MDL,
|
"embd_id": settings.EMBEDDING_MDL,
|
||||||
"asr_id": ASR_MDL,
|
"asr_id": settings.ASR_MDL,
|
||||||
"parser_ids": PARSERS,
|
"parser_ids": settings.PARSERS,
|
||||||
"img2txt_id": IMAGE2TEXT_MDL
|
"img2txt_id": settings.IMAGE2TEXT_MDL
|
||||||
}
|
}
|
||||||
usr_tenant = {
|
usr_tenant = {
|
||||||
"tenant_id": user_info["id"],
|
"tenant_id": user_info["id"],
|
||||||
@ -64,10 +64,11 @@ def init_superuser():
|
|||||||
"role": UserTenantRole.OWNER
|
"role": UserTenantRole.OWNER
|
||||||
}
|
}
|
||||||
tenant_llm = []
|
tenant_llm = []
|
||||||
for llm in LLMService.query(fid=LLM_FACTORY):
|
for llm in LLMService.query(fid=settings.LLM_FACTORY):
|
||||||
tenant_llm.append(
|
tenant_llm.append(
|
||||||
{"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
|
{"tenant_id": user_info["id"], "llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name,
|
||||||
"api_key": API_KEY, "api_base": LLM_BASE_URL})
|
"model_type": llm.model_type,
|
||||||
|
"api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL})
|
||||||
|
|
||||||
if not UserService.save(**user_info):
|
if not UserService.save(**user_info):
|
||||||
logging.error("can't init admin.")
|
logging.error("can't init admin.")
|
||||||
|
@ -27,7 +27,7 @@ from api.db.db_models import Dialog, Conversation,DB
|
|||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
||||||
from api.settings import retrievaler, kg_retrievaler
|
from api import settings
|
||||||
from rag.app.resume import forbidden_select_fields4resume
|
from rag.app.resume import forbidden_select_fields4resume
|
||||||
from rag.nlp.search import index_name
|
from rag.nlp.search import index_name
|
||||||
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
||||||
@ -152,7 +152,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
|
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
|
||||||
|
|
||||||
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
||||||
retr = retrievaler if not is_kg else kg_retrievaler
|
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
|
||||||
|
|
||||||
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
||||||
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
|
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
|
||||||
@ -342,7 +342,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|||||||
|
|
||||||
logging.debug(f"{question} get SQL(refined): {sql}")
|
logging.debug(f"{question} get SQL(refined): {sql}")
|
||||||
tried_times += 1
|
tried_times += 1
|
||||||
return retrievaler.sql_retrieval(sql, format="json"), sql
|
return settings.retrievaler.sql_retrieval(sql, format="json"), sql
|
||||||
|
|
||||||
tbl, sql = get_table()
|
tbl, sql = get_table()
|
||||||
if tbl is None:
|
if tbl is None:
|
||||||
@ -596,7 +596,7 @@ def ask(question, kb_ids, tenant_id):
|
|||||||
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
||||||
|
|
||||||
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
|
||||||
retr = retrievaler if not is_kg else kg_retrievaler
|
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler
|
||||||
|
|
||||||
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
||||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
||||||
|
@ -26,7 +26,7 @@ from io import BytesIO
|
|||||||
from peewee import fn
|
from peewee import fn
|
||||||
|
|
||||||
from api.db.db_utils import bulk_insert_into_db
|
from api.db.db_utils import bulk_insert_into_db
|
||||||
from api.settings import docStoreConn
|
from api import settings
|
||||||
from api.utils import current_timestamp, get_format_time, get_uuid
|
from api.utils import current_timestamp, get_format_time, get_uuid
|
||||||
from graphrag.mind_map_extractor import MindMapExtractor
|
from graphrag.mind_map_extractor import MindMapExtractor
|
||||||
from rag.settings import SVR_QUEUE_NAME
|
from rag.settings import SVR_QUEUE_NAME
|
||||||
@ -108,7 +108,7 @@ class DocumentService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def remove_document(cls, doc, tenant_id):
|
def remove_document(cls, doc, tenant_id):
|
||||||
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||||
cls.clear_chunk_num(doc.id)
|
cls.clear_chunk_num(doc.id)
|
||||||
return cls.delete_by_id(doc.id)
|
return cls.delete_by_id(doc.id)
|
||||||
|
|
||||||
@ -553,10 +553,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||||||
d["q_%d_vec" % len(v)] = v
|
d["q_%d_vec" % len(v)] = v
|
||||||
for b in range(0, len(cks), es_bulk_size):
|
for b in range(0, len(cks), es_bulk_size):
|
||||||
if try_create_idx:
|
if try_create_idx:
|
||||||
if not docStoreConn.indexExist(idxnm, kb_id):
|
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
||||||
docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
||||||
try_create_idx = False
|
try_create_idx = False
|
||||||
docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||||
|
|
||||||
DocumentService.increment_chunk_num(
|
DocumentService.increment_chunk_num(
|
||||||
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)
|
||||||
|
@ -33,12 +33,10 @@ import traceback
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
from werkzeug.serving import run_simple
|
from werkzeug.serving import run_simple
|
||||||
|
from api import settings
|
||||||
from api.apps import app
|
from api.apps import app
|
||||||
from api.db.runtime_config import RuntimeConfig
|
from api.db.runtime_config import RuntimeConfig
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.settings import (
|
|
||||||
HOST, HTTP_PORT
|
|
||||||
)
|
|
||||||
from api import utils
|
from api import utils
|
||||||
|
|
||||||
from api.db.db_models import init_database_tables as init_web_db
|
from api.db.db_models import init_database_tables as init_web_db
|
||||||
@ -72,6 +70,7 @@ if __name__ == '__main__':
|
|||||||
f'project base: {utils.file_utils.get_project_base_directory()}'
|
f'project base: {utils.file_utils.get_project_base_directory()}'
|
||||||
)
|
)
|
||||||
show_configs()
|
show_configs()
|
||||||
|
settings.init_settings()
|
||||||
|
|
||||||
# init db
|
# init db
|
||||||
init_web_db()
|
init_web_db()
|
||||||
@ -96,7 +95,7 @@ if __name__ == '__main__':
|
|||||||
logging.info("run on debug mode")
|
logging.info("run on debug mode")
|
||||||
|
|
||||||
RuntimeConfig.init_env()
|
RuntimeConfig.init_env()
|
||||||
RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
|
RuntimeConfig.init_config(JOB_SERVER_HOST=settings.HOST_IP, HTTP_PORT=settings.HOST_PORT)
|
||||||
|
|
||||||
thread = ThreadPoolExecutor(max_workers=1)
|
thread = ThreadPoolExecutor(max_workers=1)
|
||||||
thread.submit(update_progress)
|
thread.submit(update_progress)
|
||||||
@ -105,8 +104,8 @@ if __name__ == '__main__':
|
|||||||
try:
|
try:
|
||||||
logging.info("RAGFlow HTTP server start...")
|
logging.info("RAGFlow HTTP server start...")
|
||||||
run_simple(
|
run_simple(
|
||||||
hostname=HOST,
|
hostname=settings.HOST_IP,
|
||||||
port=HTTP_PORT,
|
port=settings.HOST_PORT,
|
||||||
application=app,
|
application=app,
|
||||||
threaded=True,
|
threaded=True,
|
||||||
use_reloader=RuntimeConfig.DEBUG,
|
use_reloader=RuntimeConfig.DEBUG,
|
||||||
|
@ -30,12 +30,46 @@ LIGHTEN = int(os.environ.get('LIGHTEN', "0"))
|
|||||||
|
|
||||||
REQUEST_WAIT_SEC = 2
|
REQUEST_WAIT_SEC = 2
|
||||||
REQUEST_MAX_WAIT_SEC = 300
|
REQUEST_MAX_WAIT_SEC = 300
|
||||||
|
LLM = None
|
||||||
|
LLM_FACTORY = None
|
||||||
|
LLM_BASE_URL = None
|
||||||
|
CHAT_MDL = ""
|
||||||
|
EMBEDDING_MDL = ""
|
||||||
|
RERANK_MDL = ""
|
||||||
|
ASR_MDL = ""
|
||||||
|
IMAGE2TEXT_MDL = ""
|
||||||
|
API_KEY = None
|
||||||
|
PARSERS = None
|
||||||
|
HOST_IP = None
|
||||||
|
HOST_PORT = None
|
||||||
|
SECRET_KEY = None
|
||||||
|
|
||||||
|
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
|
||||||
|
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
|
||||||
|
|
||||||
|
# authentication
|
||||||
|
AUTHENTICATION_CONF = None
|
||||||
|
|
||||||
|
# client
|
||||||
|
CLIENT_AUTHENTICATION = None
|
||||||
|
HTTP_APP_KEY = None
|
||||||
|
GITHUB_OAUTH = None
|
||||||
|
FEISHU_OAUTH = None
|
||||||
|
|
||||||
|
DOC_ENGINE = None
|
||||||
|
docStoreConn = None
|
||||||
|
|
||||||
|
retrievaler = None
|
||||||
|
kg_retrievaler = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_settings():
|
||||||
|
global LLM, LLM_FACTORY, LLM_BASE_URL
|
||||||
LLM = get_base_config("user_default_llm", {})
|
LLM = get_base_config("user_default_llm", {})
|
||||||
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
|
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
|
||||||
LLM_BASE_URL = LLM.get("base_url")
|
LLM_BASE_URL = LLM.get("base_url")
|
||||||
|
|
||||||
CHAT_MDL = EMBEDDING_MDL = RERANK_MDL = ASR_MDL = IMAGE2TEXT_MDL = ""
|
global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
|
||||||
if not LIGHTEN:
|
if not LIGHTEN:
|
||||||
default_llm = {
|
default_llm = {
|
||||||
"Tongyi-Qianwen": {
|
"Tongyi-Qianwen": {
|
||||||
@ -102,21 +136,20 @@ if not LIGHTEN:
|
|||||||
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI"
|
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI"
|
||||||
RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI"
|
RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI"
|
||||||
|
|
||||||
|
global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
|
||||||
API_KEY = LLM.get("api_key", "")
|
API_KEY = LLM.get("api_key", "")
|
||||||
PARSERS = LLM.get(
|
PARSERS = LLM.get(
|
||||||
"parsers",
|
"parsers",
|
||||||
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
|
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
|
||||||
|
|
||||||
HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
|
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
|
||||||
HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
|
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
|
||||||
|
|
||||||
SECRET_KEY = get_base_config(
|
SECRET_KEY = get_base_config(
|
||||||
RAG_FLOW_SERVICE_NAME,
|
RAG_FLOW_SERVICE_NAME,
|
||||||
{}).get("secret_key", str(date.today()))
|
{}).get("secret_key", str(date.today()))
|
||||||
|
|
||||||
DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
|
global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH
|
||||||
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
|
|
||||||
|
|
||||||
# authentication
|
# authentication
|
||||||
AUTHENTICATION_CONF = get_base_config("authentication", {})
|
AUTHENTICATION_CONF = get_base_config("authentication", {})
|
||||||
|
|
||||||
@ -128,6 +161,7 @@ HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
|
|||||||
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
|
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
|
||||||
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
|
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
|
||||||
|
|
||||||
|
global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler
|
||||||
DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
|
DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
|
||||||
if DOC_ENGINE == "elasticsearch":
|
if DOC_ENGINE == "elasticsearch":
|
||||||
docStoreConn = rag.utils.es_conn.ESConnection()
|
docStoreConn = rag.utils.es_conn.ESConnection()
|
||||||
@ -139,6 +173,15 @@ else:
|
|||||||
retrievaler = search.Dealer(docStoreConn)
|
retrievaler = search.Dealer(docStoreConn)
|
||||||
kg_retrievaler = kg_search.KGSearch(docStoreConn)
|
kg_retrievaler = kg_search.KGSearch(docStoreConn)
|
||||||
|
|
||||||
|
def get_host_ip():
|
||||||
|
global HOST_IP
|
||||||
|
return HOST_IP
|
||||||
|
|
||||||
|
|
||||||
|
def get_host_port():
|
||||||
|
global HOST_PORT
|
||||||
|
return HOST_PORT
|
||||||
|
|
||||||
|
|
||||||
class CustomEnum(Enum):
|
class CustomEnum(Enum):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -34,11 +34,9 @@ from itsdangerous import URLSafeTimedSerializer
|
|||||||
from werkzeug.http import HTTP_STATUS_CODES
|
from werkzeug.http import HTTP_STATUS_CODES
|
||||||
|
|
||||||
from api.db.db_models import APIToken
|
from api.db.db_models import APIToken
|
||||||
from api.settings import (
|
from api import settings
|
||||||
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
|
|
||||||
CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
|
from api import settings
|
||||||
)
|
|
||||||
from api.settings import RetCode
|
|
||||||
from api.utils import CustomJSONEncoder, get_uuid
|
from api.utils import CustomJSONEncoder, get_uuid
|
||||||
from api.utils import json_dumps
|
from api.utils import json_dumps
|
||||||
|
|
||||||
@ -59,13 +57,13 @@ def request(**kwargs):
|
|||||||
{}).items()}
|
{}).items()}
|
||||||
prepped = requests.Request(**kwargs).prepare()
|
prepped = requests.Request(**kwargs).prepare()
|
||||||
|
|
||||||
if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
|
if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
|
||||||
timestamp = str(round(time() * 1000))
|
timestamp = str(round(time() * 1000))
|
||||||
nonce = str(uuid1())
|
nonce = str(uuid1())
|
||||||
signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([
|
signature = b64encode(HMAC(settings.SECRET_KEY.encode('ascii'), b'\n'.join([
|
||||||
timestamp.encode('ascii'),
|
timestamp.encode('ascii'),
|
||||||
nonce.encode('ascii'),
|
nonce.encode('ascii'),
|
||||||
HTTP_APP_KEY.encode('ascii'),
|
settings.HTTP_APP_KEY.encode('ascii'),
|
||||||
prepped.path_url.encode('ascii'),
|
prepped.path_url.encode('ascii'),
|
||||||
prepped.body if kwargs.get('json') else b'',
|
prepped.body if kwargs.get('json') else b'',
|
||||||
urlencode(
|
urlencode(
|
||||||
@ -79,7 +77,7 @@ def request(**kwargs):
|
|||||||
prepped.headers.update({
|
prepped.headers.update({
|
||||||
'TIMESTAMP': timestamp,
|
'TIMESTAMP': timestamp,
|
||||||
'NONCE': nonce,
|
'NONCE': nonce,
|
||||||
'APP-KEY': HTTP_APP_KEY,
|
'APP-KEY': settings.HTTP_APP_KEY,
|
||||||
'SIGNATURE': signature,
|
'SIGNATURE': signature,
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -89,7 +87,7 @@ def request(**kwargs):
|
|||||||
def get_exponential_backoff_interval(retries, full_jitter=False):
|
def get_exponential_backoff_interval(retries, full_jitter=False):
|
||||||
"""Calculate the exponential backoff wait time."""
|
"""Calculate the exponential backoff wait time."""
|
||||||
# Will be zero if factor equals 0
|
# Will be zero if factor equals 0
|
||||||
countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
|
countdown = min(settings.REQUEST_MAX_WAIT_SEC, settings.REQUEST_WAIT_SEC * (2 ** retries))
|
||||||
# Full jitter according to
|
# Full jitter according to
|
||||||
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
|
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
|
||||||
if full_jitter:
|
if full_jitter:
|
||||||
@ -98,7 +96,7 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
|
|||||||
return max(0, countdown)
|
return max(0, countdown)
|
||||||
|
|
||||||
|
|
||||||
def get_data_error_result(code=RetCode.DATA_ERROR,
|
def get_data_error_result(code=settings.RetCode.DATA_ERROR,
|
||||||
message='Sorry! Data missing!'):
|
message='Sorry! Data missing!'):
|
||||||
import re
|
import re
|
||||||
result_dict = {
|
result_dict = {
|
||||||
@ -126,8 +124,8 @@ def server_error_response(e):
|
|||||||
pass
|
pass
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
||||||
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
|
||||||
|
|
||||||
|
|
||||||
def error_response(response_code, message=None):
|
def error_response(response_code, message=None):
|
||||||
@ -168,7 +166,7 @@ def validate_request(*args, **kwargs):
|
|||||||
error_string += "required argument values: {}".format(
|
error_string += "required argument values: {}".format(
|
||||||
",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
code=RetCode.ARGUMENT_ERROR, message=error_string)
|
code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
|
||||||
return func(*_args, **_kwargs)
|
return func(*_args, **_kwargs)
|
||||||
|
|
||||||
return decorated_function
|
return decorated_function
|
||||||
@ -193,7 +191,7 @@ def send_file_in_mem(data, filename):
|
|||||||
return send_file(f, as_attachment=True, attachment_filename=filename)
|
return send_file(f, as_attachment=True, attachment_filename=filename)
|
||||||
|
|
||||||
|
|
||||||
def get_json_result(code=RetCode.SUCCESS, message='success', data=None):
|
def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
|
||||||
response = {"code": code, "message": message, "data": data}
|
response = {"code": code, "message": message, "data": data}
|
||||||
return jsonify(response)
|
return jsonify(response)
|
||||||
|
|
||||||
@ -204,7 +202,7 @@ def apikey_required(func):
|
|||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return build_error_result(
|
return build_error_result(
|
||||||
message='API-KEY is invalid!', code=RetCode.FORBIDDEN
|
message='API-KEY is invalid!', code=settings.RetCode.FORBIDDEN
|
||||||
)
|
)
|
||||||
kwargs['tenant_id'] = objs[0].tenant_id
|
kwargs['tenant_id'] = objs[0].tenant_id
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
@ -212,14 +210,14 @@ def apikey_required(func):
|
|||||||
return decorated_function
|
return decorated_function
|
||||||
|
|
||||||
|
|
||||||
def build_error_result(code=RetCode.FORBIDDEN, message='success'):
|
def build_error_result(code=settings.RetCode.FORBIDDEN, message='success'):
|
||||||
response = {"code": code, "message": message}
|
response = {"code": code, "message": message}
|
||||||
response = jsonify(response)
|
response = jsonify(response)
|
||||||
response.status_code = code
|
response.status_code = code
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def construct_response(code=RetCode.SUCCESS,
|
def construct_response(code=settings.RetCode.SUCCESS,
|
||||||
message='success', data=None, auth=None):
|
message='success', data=None, auth=None):
|
||||||
result_dict = {"code": code, "message": message, "data": data}
|
result_dict = {"code": code, "message": message, "data": data}
|
||||||
response_dict = {}
|
response_dict = {}
|
||||||
@ -239,7 +237,7 @@ def construct_response(code=RetCode.SUCCESS,
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
|
def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'):
|
||||||
import re
|
import re
|
||||||
result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
|
result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
|
||||||
response = {}
|
response = {}
|
||||||
@ -251,7 +249,7 @@ def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
|
|||||||
return jsonify(response)
|
return jsonify(response)
|
||||||
|
|
||||||
|
|
||||||
def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
|
def construct_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
|
||||||
if data is None:
|
if data is None:
|
||||||
return jsonify({"code": code, "message": message})
|
return jsonify({"code": code, "message": message})
|
||||||
else:
|
else:
|
||||||
@ -262,12 +260,12 @@ def construct_error_response(e):
|
|||||||
logging.exception(e)
|
logging.exception(e)
|
||||||
try:
|
try:
|
||||||
if e.code == 401:
|
if e.code == 401:
|
||||||
return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
|
return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e))
|
||||||
except BaseException:
|
except BaseException:
|
||||||
pass
|
pass
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
|
||||||
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
|
return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))
|
||||||
|
|
||||||
|
|
||||||
def token_required(func):
|
def token_required(func):
|
||||||
@ -280,7 +278,7 @@ def token_required(func):
|
|||||||
objs = APIToken.query(token=token)
|
objs = APIToken.query(token=token)
|
||||||
if not objs:
|
if not objs:
|
||||||
return get_json_result(
|
return get_json_result(
|
||||||
data=False, message='Token is not valid!', code=RetCode.AUTHENTICATION_ERROR
|
data=False, message='Token is not valid!', code=settings.RetCode.AUTHENTICATION_ERROR
|
||||||
)
|
)
|
||||||
kwargs['tenant_id'] = objs[0].tenant_id
|
kwargs['tenant_id'] = objs[0].tenant_id
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
@ -288,7 +286,7 @@ def token_required(func):
|
|||||||
return decorated_function
|
return decorated_function
|
||||||
|
|
||||||
|
|
||||||
def get_result(code=RetCode.SUCCESS, message="", data=None):
|
def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
|
||||||
if code == 0:
|
if code == 0:
|
||||||
if data is not None:
|
if data is not None:
|
||||||
response = {"code": code, "data": data}
|
response = {"code": code, "data": data}
|
||||||
@ -299,7 +297,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None):
|
|||||||
return jsonify(response)
|
return jsonify(response)
|
||||||
|
|
||||||
|
|
||||||
def get_error_data_result(message='Sorry! Data missing!', code=RetCode.DATA_ERROR,
|
def get_error_data_result(message='Sorry! Data missing!', code=settings.RetCode.DATA_ERROR,
|
||||||
):
|
):
|
||||||
import re
|
import re
|
||||||
result_dict = {
|
result_dict = {
|
||||||
|
@ -24,7 +24,7 @@ import numpy as np
|
|||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
from pypdf import PdfReader as pdf2_read
|
from pypdf import PdfReader as pdf2_read
|
||||||
|
|
||||||
from api.settings import LIGHTEN
|
from api import settings
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils.file_utils import get_project_base_directory
|
||||||
from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
|
from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
|
||||||
from rag.nlp import rag_tokenizer
|
from rag.nlp import rag_tokenizer
|
||||||
@ -41,7 +41,7 @@ class RAGFlowPdfParser:
|
|||||||
self.tbl_det = TableStructureRecognizer()
|
self.tbl_det = TableStructureRecognizer()
|
||||||
|
|
||||||
self.updown_cnt_mdl = xgb.Booster()
|
self.updown_cnt_mdl = xgb.Booster()
|
||||||
if not LIGHTEN:
|
if not settings.LIGHTEN:
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
@ -252,13 +252,13 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.settings import retrievaler
|
from api import settings
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
|
|
||||||
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
||||||
|
|
||||||
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
||||||
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
|
docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
|
||||||
info = {
|
info = {
|
||||||
"input_text": docs,
|
"input_text": docs,
|
||||||
"entity_specs": "organization, person",
|
"entity_specs": "organization, person",
|
||||||
|
@ -30,14 +30,14 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.settings import retrievaler
|
from api import settings
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
|
|
||||||
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
||||||
|
|
||||||
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
||||||
docs = [d["content_with_weight"] for d in
|
docs = [d["content_with_weight"] for d in
|
||||||
retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
|
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
|
||||||
graph = ex(docs)
|
graph = ex(docs)
|
||||||
|
|
||||||
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
|
er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))
|
||||||
|
@ -23,7 +23,7 @@ from collections import defaultdict
|
|||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.settings import retrievaler, docStoreConn
|
from api import settings
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from rag.nlp import tokenize, search
|
from rag.nlp import tokenize, search
|
||||||
from ranx import evaluate
|
from ranx import evaluate
|
||||||
@ -52,7 +52,7 @@ class Benchmark:
|
|||||||
run = defaultdict(dict)
|
run = defaultdict(dict)
|
||||||
query_list = list(qrels.keys())
|
query_list = list(qrels.keys())
|
||||||
for query in query_list:
|
for query in query_list:
|
||||||
ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
|
ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
|
||||||
0.0, self.vector_similarity_weight)
|
0.0, self.vector_similarity_weight)
|
||||||
if len(ranks["chunks"]) == 0:
|
if len(ranks["chunks"]) == 0:
|
||||||
print(f"deleted query: {query}")
|
print(f"deleted query: {query}")
|
||||||
@ -81,9 +81,9 @@ class Benchmark:
|
|||||||
def init_index(self, vector_size: int):
|
def init_index(self, vector_size: int):
|
||||||
if self.initialized_index:
|
if self.initialized_index:
|
||||||
return
|
return
|
||||||
if docStoreConn.indexExist(self.index_name, self.kb_id):
|
if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
|
||||||
docStoreConn.deleteIdx(self.index_name, self.kb_id)
|
settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
|
||||||
docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
|
settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
|
||||||
self.initialized_index = True
|
self.initialized_index = True
|
||||||
|
|
||||||
def ms_marco_index(self, file_path, index_name):
|
def ms_marco_index(self, file_path, index_name):
|
||||||
@ -118,13 +118,13 @@ class Benchmark:
|
|||||||
docs_count += len(docs)
|
docs_count += len(docs)
|
||||||
docs, vector_size = self.embedding(docs)
|
docs, vector_size = self.embedding(docs)
|
||||||
self.init_index(vector_size)
|
self.init_index(vector_size)
|
||||||
docStoreConn.insert(docs, self.index_name, self.kb_id)
|
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
|
||||||
docs = []
|
docs = []
|
||||||
|
|
||||||
if docs:
|
if docs:
|
||||||
docs, vector_size = self.embedding(docs)
|
docs, vector_size = self.embedding(docs)
|
||||||
self.init_index(vector_size)
|
self.init_index(vector_size)
|
||||||
docStoreConn.insert(docs, self.index_name, self.kb_id)
|
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
|
||||||
return qrels, texts
|
return qrels, texts
|
||||||
|
|
||||||
def trivia_qa_index(self, file_path, index_name):
|
def trivia_qa_index(self, file_path, index_name):
|
||||||
@ -159,12 +159,12 @@ class Benchmark:
|
|||||||
docs_count += len(docs)
|
docs_count += len(docs)
|
||||||
docs, vector_size = self.embedding(docs)
|
docs, vector_size = self.embedding(docs)
|
||||||
self.init_index(vector_size)
|
self.init_index(vector_size)
|
||||||
docStoreConn.insert(docs,self.index_name)
|
settings.docStoreConn.insert(docs,self.index_name)
|
||||||
docs = []
|
docs = []
|
||||||
|
|
||||||
docs, vector_size = self.embedding(docs)
|
docs, vector_size = self.embedding(docs)
|
||||||
self.init_index(vector_size)
|
self.init_index(vector_size)
|
||||||
docStoreConn.insert(docs, self.index_name)
|
settings.docStoreConn.insert(docs, self.index_name)
|
||||||
return qrels, texts
|
return qrels, texts
|
||||||
|
|
||||||
def miracl_index(self, file_path, corpus_path, index_name):
|
def miracl_index(self, file_path, corpus_path, index_name):
|
||||||
@ -214,12 +214,12 @@ class Benchmark:
|
|||||||
docs_count += len(docs)
|
docs_count += len(docs)
|
||||||
docs, vector_size = self.embedding(docs)
|
docs, vector_size = self.embedding(docs)
|
||||||
self.init_index(vector_size)
|
self.init_index(vector_size)
|
||||||
docStoreConn.insert(docs, self.index_name)
|
settings.docStoreConn.insert(docs, self.index_name)
|
||||||
docs = []
|
docs = []
|
||||||
|
|
||||||
docs, vector_size = self.embedding(docs)
|
docs, vector_size = self.embedding(docs)
|
||||||
self.init_index(vector_size)
|
self.init_index(vector_size)
|
||||||
docStoreConn.insert(docs, self.index_name)
|
settings.docStoreConn.insert(docs, self.index_name)
|
||||||
return qrels, texts
|
return qrels, texts
|
||||||
|
|
||||||
def save_results(self, qrels, run, texts, dataset, file_path):
|
def save_results(self, qrels, run, texts, dataset, file_path):
|
||||||
|
@ -28,7 +28,7 @@ from openai import OpenAI
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from api.settings import LIGHTEN
|
from api import settings
|
||||||
from api.utils.file_utils import get_home_cache_dir
|
from api.utils.file_utils import get_home_cache_dir
|
||||||
from rag.utils import num_tokens_from_string, truncate
|
from rag.utils import num_tokens_from_string, truncate
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
@ -60,7 +60,7 @@ class DefaultEmbedding(Base):
|
|||||||
^_-
|
^_-
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not LIGHTEN and not DefaultEmbedding._model:
|
if not settings.LIGHTEN and not DefaultEmbedding._model:
|
||||||
with DefaultEmbedding._model_lock:
|
with DefaultEmbedding._model_lock:
|
||||||
from FlagEmbedding import FlagModel
|
from FlagEmbedding import FlagModel
|
||||||
import torch
|
import torch
|
||||||
@ -248,7 +248,7 @@ class FastEmbed(Base):
|
|||||||
threads: Optional[int] = None,
|
threads: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if not LIGHTEN and not FastEmbed._model:
|
if not settings.LIGHTEN and not FastEmbed._model:
|
||||||
from fastembed import TextEmbedding
|
from fastembed import TextEmbedding
|
||||||
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
|
||||||
|
|
||||||
@ -294,7 +294,7 @@ class YoudaoEmbed(Base):
|
|||||||
_client = None
|
_client = None
|
||||||
|
|
||||||
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
|
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
|
||||||
if not LIGHTEN and not YoudaoEmbed._client:
|
if not settings.LIGHTEN and not YoudaoEmbed._client:
|
||||||
from BCEmbedding import EmbeddingModel as qanthing
|
from BCEmbedding import EmbeddingModel as qanthing
|
||||||
try:
|
try:
|
||||||
logging.info("LOADING BCE...")
|
logging.info("LOADING BCE...")
|
||||||
|
@ -23,7 +23,7 @@ import os
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from api.settings import LIGHTEN
|
from api import settings
|
||||||
from api.utils.file_utils import get_home_cache_dir
|
from api.utils.file_utils import get_home_cache_dir
|
||||||
from rag.utils import num_tokens_from_string, truncate
|
from rag.utils import num_tokens_from_string, truncate
|
||||||
import json
|
import json
|
||||||
@ -57,7 +57,7 @@ class DefaultRerank(Base):
|
|||||||
^_-
|
^_-
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not LIGHTEN and not DefaultRerank._model:
|
if not settings.LIGHTEN and not DefaultRerank._model:
|
||||||
import torch
|
import torch
|
||||||
from FlagEmbedding import FlagReranker
|
from FlagEmbedding import FlagReranker
|
||||||
with DefaultRerank._model_lock:
|
with DefaultRerank._model_lock:
|
||||||
@ -121,7 +121,7 @@ class YoudaoRerank(DefaultRerank):
|
|||||||
_model_lock = threading.Lock()
|
_model_lock = threading.Lock()
|
||||||
|
|
||||||
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
|
||||||
if not LIGHTEN and not YoudaoRerank._model:
|
if not settings.LIGHTEN and not YoudaoRerank._model:
|
||||||
from BCEmbedding import RerankerModel
|
from BCEmbedding import RerankerModel
|
||||||
with YoudaoRerank._model_lock:
|
with YoudaoRerank._model_lock:
|
||||||
if not YoudaoRerank._model:
|
if not YoudaoRerank._model:
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from api.utils.log_utils import initRootLogger
|
from api.utils.log_utils import initRootLogger
|
||||||
|
|
||||||
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
|
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
|
||||||
initRootLogger(f"task_executor_{CONSUMER_NO}")
|
initRootLogger(f"task_executor_{CONSUMER_NO}")
|
||||||
for module in ["pdfminer"]:
|
for module in ["pdfminer"]:
|
||||||
@ -49,9 +50,10 @@ from api.db.services.document_service import DocumentService
|
|||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.task_service import TaskService
|
from api.db.services.task_service import TaskService
|
||||||
from api.db.services.file2document_service import File2DocumentService
|
from api.db.services.file2document_service import File2DocumentService
|
||||||
from api.settings import retrievaler, docStoreConn
|
from api import settings
|
||||||
from api.db.db_models import close_connection
|
from api.db.db_models import close_connection
|
||||||
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
|
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
|
||||||
|
knowledge_graph, email
|
||||||
from rag.nlp import search, rag_tokenizer
|
from rag.nlp import search, rag_tokenizer
|
||||||
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
||||||
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
|
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
|
||||||
@ -88,6 +90,7 @@ PENDING_TASKS = 0
|
|||||||
HEAD_CREATED_AT = ""
|
HEAD_CREATED_AT = ""
|
||||||
HEAD_DETAIL = ""
|
HEAD_DETAIL = ""
|
||||||
|
|
||||||
|
|
||||||
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
|
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
|
||||||
global PAYLOAD
|
global PAYLOAD
|
||||||
if prog is not None and prog < 0:
|
if prog is not None and prog < 0:
|
||||||
@ -171,7 +174,8 @@ def build(row):
|
|||||||
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
|
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
|
callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
|
||||||
logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
|
logging.exception(
|
||||||
|
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if re.search("(No such file|not found)", str(e)):
|
if re.search("(No such file|not found)", str(e)):
|
||||||
@ -226,7 +230,8 @@ def build(row):
|
|||||||
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
|
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
|
||||||
el += timer() - st
|
el += timer() - st
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.exception("Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
|
logging.exception(
|
||||||
|
"Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
|
||||||
|
|
||||||
d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
|
d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
|
||||||
del d["image"]
|
del d["image"]
|
||||||
@ -262,7 +267,7 @@ def build(row):
|
|||||||
|
|
||||||
def init_kb(row, vector_size: int):
|
def init_kb(row, vector_size: int):
|
||||||
idxnm = search.index_name(row["tenant_id"])
|
idxnm = search.index_name(row["tenant_id"])
|
||||||
return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
|
return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
|
||||||
|
|
||||||
|
|
||||||
def embedding(docs, mdl, parser_config=None, callback=None):
|
def embedding(docs, mdl, parser_config=None, callback=None):
|
||||||
@ -313,7 +318,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
|||||||
vector_size = len(vts[0])
|
vector_size = len(vts[0])
|
||||||
vctr_nm = "q_%d_vec" % vector_size
|
vctr_nm = "q_%d_vec" % vector_size
|
||||||
chunks = []
|
chunks = []
|
||||||
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
|
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
|
||||||
|
fields=["content_with_weight", vctr_nm]):
|
||||||
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
|
||||||
|
|
||||||
raptor = Raptor(
|
raptor = Raptor(
|
||||||
@ -384,7 +390,8 @@ def main():
|
|||||||
# TODO: exception handler
|
# TODO: exception handler
|
||||||
## set_progress(r["did"], -1, "ERROR: ")
|
## set_progress(r["did"], -1, "ERROR: ")
|
||||||
callback(
|
callback(
|
||||||
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), timer() - st)
|
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks),
|
||||||
|
timer() - st)
|
||||||
)
|
)
|
||||||
st = timer()
|
st = timer()
|
||||||
try:
|
try:
|
||||||
@ -403,18 +410,18 @@ def main():
|
|||||||
es_r = ""
|
es_r = ""
|
||||||
es_bulk_size = 4
|
es_bulk_size = 4
|
||||||
for b in range(0, len(cks), es_bulk_size):
|
for b in range(0, len(cks), es_bulk_size):
|
||||||
es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
|
es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
|
||||||
if b % 128 == 0:
|
if b % 128 == 0:
|
||||||
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
|
||||||
|
|
||||||
logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
|
||||||
if es_r:
|
if es_r:
|
||||||
callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
|
callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
|
||||||
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
||||||
logging.error('Insert chunk error: ' + str(es_r))
|
logging.error('Insert chunk error: ' + str(es_r))
|
||||||
else:
|
else:
|
||||||
if TaskService.do_cancel(r["id"]):
|
if TaskService.do_cancel(r["id"]):
|
||||||
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
|
||||||
continue
|
continue
|
||||||
callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
|
callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
|
||||||
callback(1., "Done!")
|
callback(1., "Done!")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user