Move settings initialization after module init phase (#3438)
### What problem does this PR solve?

1. Module init no longer connects to the database.
2. Config values in settings must now be accessed as settings.CONFIG_NAME.

### Type of change

- [x] Refactoring

Signed-off-by: jinhai <haijin.chn@gmail.com>
parent ac033b62cf
commit 1e90a1bf36
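The entire diff below is the same mechanical migration applied across the API apps: stop importing individual config objects out of `api.settings` (which pulled settings in, database connection and all, at module-import time) and instead import the module itself so attribute lookups happen at call time, after settings has been initialized. A minimal before/after sketch of the pattern — the names `retrievaler`, `RetCode`, and the call arguments are taken from the hunks below; this is an illustration and is not runnable outside the ragflow codebase:

```python
# Before: binds the objects at import time, so api.settings must be
# fully initialized (including its DB connection) before any module loads.
from api.settings import retrievaler, RetCode

# After: bind only the module; settings.retrievaler and settings.RetCode
# are resolved when called, after the module init phase has completed.
from api import settings

def list_chunks_example(doc_id, tenant_id, kb_ids):
    # names and arguments as they appear in the hunks below
    res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
    return res, settings.RetCode.SUCCESS
```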
@@ -19,7 +19,7 @@ import pandas as pd
 from api.db import LLMType
 from api.db.services.dialog_service import message_fit_in
 from api.db.services.llm_service import LLMBundle
-from api.settings import retrievaler
+from api import settings
 from agent.component.base import ComponentBase, ComponentParamBase
 
 
@@ -63,18 +63,20 @@ class Generate(ComponentBase):
     component_name = "Generate"
 
     def get_dependent_components(self):
-        cpnts = [para["component_id"] for para in self._param.parameters if para.get("component_id") and para["component_id"].lower().find("answer") < 0]
+        cpnts = [para["component_id"] for para in self._param.parameters if
+                 para.get("component_id") and para["component_id"].lower().find("answer") < 0]
         return cpnts
 
     def set_cite(self, retrieval_res, answer):
         retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
         if "empty_response" in retrieval_res.columns:
             retrieval_res["empty_response"].fillna("", inplace=True)
-        answer, idx = retrievaler.insert_citations(answer, [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
-                                                   [ck["vector"] for _, ck in retrieval_res.iterrows()],
-                                                   LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
-                                                             self._canvas.get_embedding_model()), tkweight=0.7,
-                                                   vtweight=0.3)
+        answer, idx = settings.retrievaler.insert_citations(answer,
                                                            [ck["content_ltks"] for _, ck in retrieval_res.iterrows()],
                                                            [ck["vector"] for _, ck in retrieval_res.iterrows()],
                                                            LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
                                                                      self._canvas.get_embedding_model()), tkweight=0.7,
                                                            vtweight=0.3)
         doc_ids = set([])
         recall_docs = []
         for i in idx:
@@ -127,12 +129,14 @@ class Generate(ComponentBase):
             else:
                 if cpn.component_name.lower() == "retrieval":
                     retrieval_res.append(out)
-                kwargs[para["key"]] = " - "+"\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
+                kwargs[para["key"]] = " - " + "\n - ".join(
+                    [o if isinstance(o, str) else str(o) for o in out["content"]])
                 self._param.inputs.append({"component_id": para["component_id"], "content": kwargs[para["key"]]})
 
         if retrieval_res:
             retrieval_res = pd.concat(retrieval_res, ignore_index=True)
-        else: retrieval_res = pd.DataFrame([])
+        else:
+            retrieval_res = pd.DataFrame([])
 
         for n, v in kwargs.items():
             prompt = re.sub(r"\{%s\}" % re.escape(n), re.escape(str(v)), prompt)
@@ -21,7 +21,7 @@ import pandas as pd
 from api.db import LLMType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
-from api.settings import retrievaler
+from api import settings
 from agent.component.base import ComponentBase, ComponentParamBase
 
 
@@ -67,7 +67,7 @@ class Retrieval(ComponentBase, ABC):
         if self._param.rerank_id:
             rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
 
-        kbinfos = retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
+        kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
                                         1, self._param.top_n,
                                         self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
                                         aggs=False, rerank_mdl=rerank_mdl)
@@ -30,8 +30,7 @@ from api.utils import CustomJSONEncoder, commands
 
 from flask_session import Session
 from flask_login import LoginManager
-from api.settings import SECRET_KEY
-from api.settings import API_VERSION
+from api import settings
 from api.utils.api_utils import server_error_response
 from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
 
@@ -78,7 +77,6 @@ app.url_map.strict_slashes = False
 app.json_encoder = CustomJSONEncoder
 app.errorhandler(Exception)(server_error_response)
 
-
 ## convince for dev and debug
 # app.config["LOGIN_DISABLED"] = True
 app.config["SESSION_PERMANENT"] = False
@@ -110,7 +108,7 @@ def register_page(page_path):
 
     page_name = page_path.stem.rstrip("_app")
     module_name = ".".join(
-        page_path.parts[page_path.parts.index("api") : -1] + (page_name,)
+        page_path.parts[page_path.parts.index("api"): -1] + (page_name,)
     )
 
     spec = spec_from_file_location(module_name, page_path)
@@ -121,7 +119,7 @@ def register_page(page_path):
     spec.loader.exec_module(page)
     page_name = getattr(page, "page_name", page_name)
     url_prefix = (
-        f"/api/{API_VERSION}" if "/sdk/" in path else f"/{API_VERSION}/{page_name}"
+        f"/api/{settings.API_VERSION}" if "/sdk/" in path else f"/{settings.API_VERSION}/{page_name}"
     )
 
     app.register_blueprint(page.manager, url_prefix=url_prefix)
@@ -141,7 +139,7 @@ client_urls_prefix = [
 
 @login_manager.request_loader
 def load_user(web_request):
-    jwt = Serializer(secret_key=SECRET_KEY)
+    jwt = Serializer(secret_key=settings.SECRET_KEY)
     authorization = web_request.headers.get("Authorization")
     if authorization:
         try:
@@ -32,7 +32,7 @@ from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.task_service import queue_tasks, TaskService
 from api.db.services.user_service import UserTenantService
-from api.settings import RetCode, retrievaler
+from api import settings
 from api.utils import get_uuid, current_timestamp, datetime_format
 from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
     generate_confirmation_token
@@ -141,7 +141,7 @@ def set_conversation():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
     req = request.json
     try:
         if objs[0].source == "agent":
@@ -183,7 +183,7 @@ def completion():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
     req = request.json
     e, conv = API4ConversationService.get_by_id(req["conversation_id"])
     if not e:
@@ -290,8 +290,8 @@ def completion():
             API4ConversationService.append_message(conv.id, conv.to_dict())
             rename_field(result)
             return get_json_result(data=result)
 
-        #******************For dialog******************
+        # ******************For dialog******************
         conv.message.append(msg[-1])
         e, dia = DialogService.get_by_id(conv.dialog_id)
         if not e:
@@ -326,7 +326,7 @@ def completion():
             resp.headers.add_header("X-Accel-Buffering", "no")
             resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
             return resp
 
         answer = None
         for ans in chat(dia, msg, **req):
             answer = ans
@@ -347,8 +347,8 @@ def get(conversation_id):
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     try:
         e, conv = API4ConversationService.get_by_id(conversation_id)
         if not e:
@@ -357,8 +357,8 @@ def get(conversation_id):
         conv = conv.to_dict()
         if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
             return get_json_result(data=False, message='Token is not valid for this conversation_id!"',
-                                   code=RetCode.AUTHENTICATION_ERROR)
+                                   code=settings.RetCode.AUTHENTICATION_ERROR)
 
         for referenct_i in conv['reference']:
             if referenct_i is None or len(referenct_i) == 0:
                 continue
@@ -378,7 +378,7 @@ def upload():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     kb_name = request.form.get("kb_name").strip()
     tenant_id = objs[0].tenant_id
@@ -394,12 +394,12 @@ def upload():
 
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
 
     file = request.files['file']
     if file.filename == '':
         return get_json_result(
-            data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
 
     root_folder = FileService.get_root_folder(tenant_id)
     pf_id = root_folder["id"]
@@ -490,17 +490,17 @@ def upload_parse():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
 
     file_objs = request.files.getlist('file')
     for file_obj in file_objs:
         if file_obj.filename == '':
             return get_json_result(
-                data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
 
     doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
     return get_json_result(data=doc_ids)
@@ -513,7 +513,7 @@ def list_chunks():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     req = request.json
 
@@ -531,7 +531,7 @@ def list_chunks():
         )
         kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
 
-        res = retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
+        res = settings.retrievaler.chunk_list(doc_id, tenant_id, kb_ids)
         res = [
             {
                 "content": res_item["content_with_weight"],
@@ -553,7 +553,7 @@ def list_kb_docs():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     req = request.json
     tenant_id = objs[0].tenant_id
@@ -585,6 +585,7 @@ def list_kb_docs():
     except Exception as e:
         return server_error_response(e)
 
+
 @manager.route('/document/infos', methods=['POST'])
 @validate_request("doc_ids")
 def docinfos():
@@ -592,7 +593,7 @@ def docinfos():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
     req = request.json
     doc_ids = req["doc_ids"]
     docs = DocumentService.get_by_ids(doc_ids)
@@ -606,7 +607,7 @@ def document_rm():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     tenant_id = objs[0].tenant_id
     req = request.json
@@ -653,7 +654,7 @@ def document_rm():
             errors += str(e)
 
     if errors:
-        return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
+        return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
 
     return get_json_result(data=True)
 
@@ -668,7 +669,7 @@ def completion_faq():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     e, conv = API4ConversationService.get_by_id(req["conversation_id"])
     if not e:
@@ -805,10 +806,10 @@ def retrieval():
     objs = APIToken.query(token=token)
     if not objs:
         return get_json_result(
-            data=False, message='Token is not valid!"', code=RetCode.AUTHENTICATION_ERROR)
+            data=False, message='Token is not valid!"', code=settings.RetCode.AUTHENTICATION_ERROR)
 
     req = request.json
-    kb_ids = req.get("kb_id",[])
+    kb_ids = req.get("kb_id", [])
     doc_ids = req.get("doc_ids", [])
     question = req.get("question")
     page = int(req.get("page", 1))
@@ -822,20 +823,21 @@ def retrieval():
         embd_nms = list(set([kb.embd_id for kb in kbs]))
         if len(embd_nms) != 1:
             return get_json_result(
-                data=False, message='Knowledge bases use different embedding models or does not exist."', code=RetCode.AUTHENTICATION_ERROR)
+                data=False, message='Knowledge bases use different embedding models or does not exist."',
+                code=settings.RetCode.AUTHENTICATION_ERROR)
 
         embd_mdl = TenantLLMService.model_instance(
             kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
         rerank_mdl = None
         if req.get("rerank_id"):
             rerank_mdl = TenantLLMService.model_instance(
-            kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
+                kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
         if req.get("keyword", False):
             chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
-        ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
-                                      similarity_threshold, vector_similarity_weight, top,
-                                      doc_ids, rerank_mdl=rerank_mdl)
+        ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
                                               similarity_threshold, vector_similarity_weight, top,
                                               doc_ids, rerank_mdl=rerank_mdl)
         for c in ranks["chunks"]:
             if "vector" in c:
                 del c["vector"]
@@ -843,5 +845,5 @@ def retrieval():
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)
@@ -19,7 +19,7 @@ from functools import partial
 from flask import request, Response
 from flask_login import login_required, current_user
 from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
-from api.settings import RetCode
+from api import settings
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
 from agent.canvas import Canvas
@@ -36,7 +36,8 @@ def templates():
 @login_required
 def canvas_list():
     return get_json_result(data=sorted([c.to_dict() for c in \
-        UserCanvasService.query(user_id=current_user.id)], key=lambda x: x["update_time"]*-1)
+        UserCanvasService.query(user_id=current_user.id)],
+                                key=lambda x: x["update_time"] * -1)
     )
 
 
@@ -45,10 +46,10 @@ def canvas_list():
 @login_required
 def rm():
     for i in request.json["canvas_ids"]:
-        if not UserCanvasService.query(user_id=current_user.id,id=i):
+        if not UserCanvasService.query(user_id=current_user.id, id=i):
             return get_json_result(
                 data=False, message='Only owner of canvas authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
         UserCanvasService.delete_by_id(i)
     return get_json_result(data=True)
 
@@ -72,7 +73,7 @@ def save():
     if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
         return get_json_result(
             data=False, message='Only owner of canvas authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)
     UserCanvasService.update_by_id(req["id"], req)
     return get_json_result(data=req)
 
@@ -98,7 +99,7 @@ def run():
     if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
         return get_json_result(
             data=False, message='Only owner of canvas authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)
 
     if not isinstance(cvs.dsl, str):
         cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
@@ -110,8 +111,8 @@ def run():
     if "message" in req:
         canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
         if len([m for m in canvas.messages if m["role"] == "user"]) > 1:
-            #ten = TenantService.get_info_by(current_user.id)[0]
-            #req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
+            # ten = TenantService.get_info_by(current_user.id)[0]
+            # req["message"] = full_question(ten["tenant_id"], ten["llm_id"], canvas.messages)
             pass
         canvas.add_user_input(req["message"])
     answer = canvas.run(stream=stream)
@@ -122,7 +123,8 @@ def run():
     assert answer is not None, "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
 
     if stream:
-        assert isinstance(answer, partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
+        assert isinstance(answer,
                          partial), "The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow."
 
         def sse():
             nonlocal answer, cvs
@@ -173,7 +175,7 @@ def reset():
     if not UserCanvasService.query(user_id=current_user.id, id=req["id"]):
         return get_json_result(
             data=False, message='Only owner of canvas authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)
 
     canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
     canvas.reset()
@@ -29,11 +29,12 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db.services.document_service import DocumentService
-from api.settings import RetCode, retrievaler, kg_retrievaler, docStoreConn
+from api import settings
 from api.utils.api_utils import get_json_result
 import hashlib
 import re
 
+
 @manager.route('/list', methods=['POST'])
 @login_required
 @validate_request("doc_id")
@@ -56,7 +57,7 @@ def list_chunk():
         }
         if "available_int" in req:
             query["available_int"] = int(req["available_int"])
-        sres = retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
+        sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True)
         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
         for id in sres.ids:
             d = {
@@ -72,13 +73,13 @@ def list_chunk():
                 "positions": json.loads(sres.field[id].get("position_list", "[]")),
             }
             assert isinstance(d["positions"], list)
-            assert len(d["positions"])==0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
+            assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
             res["chunks"].append(d)
         return get_json_result(data=res)
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)
 
 
@@ -93,7 +94,7 @@ def get():
         tenant_id = tenants[0].tenant_id
 
         kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
-        chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
+        chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
         if chunk is None:
             return server_error_response("Chunk not found")
         k = []
@@ -107,7 +108,7 @@ def get():
     except Exception as e:
         if str(e).find("NotFoundError") >= 0:
             return get_json_result(data=False, message='Chunk not found!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)
 
 
@@ -154,7 +155,7 @@ def set():
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-        docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
+        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
@@ -169,8 +170,8 @@ def switch():
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(message="Document not found!")
-        if not docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
-                                   search.index_name(doc.tenant_id), doc.kb_id):
+        if not settings.docStoreConn.update({"id": req["chunk_ids"]}, {"available_int": int(req["available_int"])},
                                            search.index_name(doc.tenant_id), doc.kb_id):
             return get_data_error_result(message="Index updating failure")
         return get_json_result(data=True)
     except Exception as e:
@@ -186,7 +187,7 @@ def rm():
         e, doc = DocumentService.get_by_id(req["doc_id"])
         if not e:
             return get_data_error_result(message="Document not found!")
-        if not docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
+        if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id):
             return get_data_error_result(message="Index updating failure")
         deleted_chunk_ids = req["chunk_ids"]
         chunk_number = len(deleted_chunk_ids)
@@ -230,7 +231,7 @@ def create():
         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
-        docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
+        settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id)
 
         DocumentService.increment_chunk_num(
             doc.id, doc.kb_id, c, 1, 0)
@@ -265,7 +266,7 @@ def retrieval_test():
     else:
         return get_json_result(
             data=False, message='Only owner of knowledgebase authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)
 
     e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
     if not e:
@@ -281,7 +282,7 @@ def retrieval_test():
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
 
-        retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
+        retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
         ranks = retr.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, page, size,
                                similarity_threshold, vector_similarity_weight, top,
                                doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
@@ -293,7 +294,7 @@ def retrieval_test():
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
-                                   code=RetCode.DATA_ERROR)
+                                   code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)
 
 
@@ -304,10 +305,10 @@ def knowledge_graph():
     tenant_id = DocumentService.get_tenant_id(doc_id)
     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
    req = {
-        "doc_ids":[doc_id],
+        "doc_ids": [doc_id],
         "knowledge_graph_kwd": ["graph", "mind_map"]
     }
-    sres = retrievaler.search(req, search.index_name(tenant_id), kb_ids)
+    sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids)
     obj = {"graph": {}, "mind_map": {}}
     for id in sres.ids[:2]:
         ty = sres.field[id]["knowledge_graph_kwd"]
@@ -336,4 +337,3 @@ def knowledge_graph():
         obj[ty] = content_json
 
     return get_json_result(data=obj)
-
@@ -25,7 +25,7 @@ from api.db import LLMType
 from api.db.services.dialog_service import DialogService, ConversationService, chat, ask
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
-from api.settings import RetCode, retrievaler
+from api import settings
 from api.utils.api_utils import get_json_result
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from graphrag.mind_map_extractor import MindMapExtractor
@@ -87,7 +87,7 @@ def get():
         else:
             return get_json_result(
                 data=False, message='Only owner of conversation authorized for this operation.',
-                code=RetCode.OPERATING_ERROR)
+                code=settings.RetCode.OPERATING_ERROR)
         conv = conv.to_dict()
         return get_json_result(data=conv)
     except Exception as e:
@@ -110,7 +110,7 @@ def rm():
             else:
                 return get_json_result(
                     data=False, message='Only owner of conversation authorized for this operation.',
-                    code=RetCode.OPERATING_ERROR)
+                    code=settings.RetCode.OPERATING_ERROR)
             ConversationService.delete_by_id(cid)
         return get_json_result(data=True)
     except Exception as e:
@@ -125,7 +125,7 @@ def list_convsersation():
     if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
         return get_json_result(
             data=False, message='Only owner of dialog authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)
     convs = ConversationService.query(
         dialog_id=dialog_id,
         order_by=ConversationService.model.create_time,
@@ -297,6 +297,7 @@ def thumbup():
 def ask_about():
     req = request.json
     uid = current_user.id
+
     def stream():
         nonlocal req, uid
         try:
@@ -329,8 +330,8 @@ def mindmap():
     embd_mdl = TenantLLMService.model_instance(
         kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
     chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
-    ranks = retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
-                                  0.3, 0.3, aggs=False)
+    ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
                                           0.3, 0.3, aggs=False)
     mindmap = MindMapExtractor(chat_mdl)
     mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
     if "error" in mind_map:
@@ -20,7 +20,7 @@ from api.db.services.dialog_service import DialogService
 from api.db import StatusEnum
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.user_service import TenantService, UserTenantService
-from api.settings import RetCode
+from api import settings
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
@@ -175,7 +175,7 @@ def rm():
             else:
                 return get_json_result(
                     data=False, message='Only owner of dialog authorized for this operation.',
-                    code=RetCode.OPERATING_ERROR)
+                    code=settings.RetCode.OPERATING_ERROR)
             dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
         DialogService.update_many_by_id(dialog_list)
         return get_json_result(data=True)
@@ -34,7 +34,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
 from api.utils import get_uuid
 from api.db import FileType, TaskStatus, ParserType, FileSource
 from api.db.services.document_service import DocumentService, doc_upload_and_parse
-from api.settings import RetCode, docStoreConn
+from api import settings
 from api.utils.api_utils import get_json_result
 from rag.utils.storage_factory import STORAGE_IMPL
 from api.utils.file_utils import filename_type, thumbnail, get_project_base_directory
@@ -49,16 +49,16 @@ def upload():
     kb_id = request.form.get("kb_id")
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
 
     file_objs = request.files.getlist('file')
     for file_obj in file_objs:
         if file_obj.filename == '':
             return get_json_result(
-                data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
 
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:
@@ -67,7 +67,7 @@ def upload():
     err, _ = FileService.upload_document(kb, file_objs, current_user.id)
     if err:
         return get_json_result(
-            data=False, message="\n".join(err), code=RetCode.SERVER_ERROR)
+            data=False, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
     return get_json_result(data=True)
 
 
@@ -78,12 +78,12 @@ def web_crawl():
     kb_id = request.form.get("kb_id")
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
     name = request.form.get("name")
     url = request.form.get("url")
     if not is_valid_url(url):
         return get_json_result(
-            data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:
         raise LookupError("Can't find this knowledgebase!")
@@ -145,7 +145,7 @@ def create():
     kb_id = req["kb_id"]
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
 
     try:
         e, kb = KnowledgebaseService.get_by_id(kb_id)
@@ -179,7 +179,7 @@ def list_docs():
     kb_id = request.args.get("kb_id")
     if not kb_id:
         return get_json_result(
-            data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "KB ID"', code=settings.RetCode.ARGUMENT_ERROR)
     tenants = UserTenantService.query(user_id=current_user.id)
     for tenant in tenants:
         if KnowledgebaseService.query(
@@ -188,7 +188,7 @@ def list_docs():
     else:
         return get_json_result(
             data=False, message='Only owner of knowledgebase authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)
     keywords = request.args.get("keywords", "")
 
     page_number = int(request.args.get("page", 1))
@@ -218,19 +218,19 @@ def docinfos():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     docs = DocumentService.get_by_ids(doc_ids)
     return get_json_result(data=list(docs.dicts()))
 
 
 @manager.route('/thumbnails', methods=['GET'])
-#@login_required
+# @login_required
 def thumbnails():
     doc_ids = request.args.get("doc_ids").split(",")
     if not doc_ids:
         return get_json_result(
-            data=False, message='Lack of "Document ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "Document ID"', code=settings.RetCode.ARGUMENT_ERROR)
 
     try:
         docs = DocumentService.get_thumbnails(doc_ids)
@@ -253,13 +253,13 @@ def change_status():
         return get_json_result(
             data=False,
             message='"Status" must be either 0 or 1!',
-            code=RetCode.ARGUMENT_ERROR)
+            code=settings.RetCode.ARGUMENT_ERROR)
 
     if not DocumentService.accessible(req["doc_id"], current_user.id):
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR)
+            code=settings.RetCode.AUTHENTICATION_ERROR)
 
     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])
@@ -276,7 +276,8 @@ def change_status():
                 message="Database error (Document update)!")
 
         status = int(req["status"])
-        docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id)
+        settings.docStoreConn.update({"doc_id": req["doc_id"]}, {"available_int": status},
                                     search.index_name(kb.tenant_id), doc.kb_id)
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
@@ -295,7 +296,7 @@ def rm():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
 
     root_folder = FileService.get_root_folder(current_user.id)
@@ -326,7 +327,7 @@ def rm():
             errors += str(e)
 
     if errors:
-        return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
+        return get_json_result(data=False, message=errors, code=settings.RetCode.SERVER_ERROR)
 
     return get_json_result(data=True)
 
@@ -341,7 +342,7 @@ def run():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         for id in req["doc_ids"]:
@@ -358,8 +359,8 @@ def run():
             e, doc = DocumentService.get_by_id(id)
             if not e:
                 return get_data_error_result(message="Document not found!")
-            if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
+            if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
 
             if str(req["run"]) == TaskStatus.RUNNING.value:
                 TaskService.filter_delete([Task.doc_id == id])
@@ -383,7 +384,7 @@ def rename():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])
@@ -394,7 +395,7 @@ def rename():
             return get_json_result(
                 data=False,
                 message="The extension of file can't be changed",
-                code=RetCode.ARGUMENT_ERROR)
+                code=settings.RetCode.ARGUMENT_ERROR)
         for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
             if d.name == req["name"]:
                 return get_data_error_result(
@@ -450,7 +451,7 @@ def change_parser():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         e, doc = DocumentService.get_by_id(req["doc_id"])
@@ -483,8 +484,8 @@ def change_parser():
             tenant_id = DocumentService.get_tenant_id(req["doc_id"])
             if not tenant_id:
                 return get_data_error_result(message="Tenant not found!")
-            if docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
-                docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
+            if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
 
         return get_json_result(data=True)
     except Exception as e:
@@ -509,13 +510,13 @@ def get_image(image_id):
 def upload_and_parse():
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
 
     file_objs = request.files.getlist('file')
     for file_obj in file_objs:
         if file_obj.filename == '':
             return get_json_result(
-                data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
 
     doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id)
 
@@ -529,7 +530,7 @@ def parse():
     if url:
         if not is_valid_url(url):
             return get_json_result(
-                data=False, message='The URL format is invalid', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='The URL format is invalid', code=settings.RetCode.ARGUMENT_ERROR)
         download_path = os.path.join(get_project_base_directory(), "logs/downloads")
         os.makedirs(download_path, exist_ok=True)
         from selenium.webdriver import Chrome, ChromeOptions
@@ -553,7 +554,7 @@ def parse():
 
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
 
     file_objs = request.files.getlist('file')
     txt = FileService.parse_docs(file_objs, current_user.id)
@@ -24,7 +24,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
 from api.utils import get_uuid
 from api.db import FileType
 from api.db.services.document_service import DocumentService
-from api.settings import RetCode
+from api import settings
 from api.utils.api_utils import get_json_result
 
 
@@ -100,7 +100,7 @@ def rm():
     file_ids = req["file_ids"]
     if not file_ids:
         return get_json_result(
-            data=False, message='Lack of "Files ID"', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='Lack of "Files ID"', code=settings.RetCode.ARGUMENT_ERROR)
     try:
         for file_id in file_ids:
             informs = File2DocumentService.get_by_file_id(file_id)
@@ -28,7 +28,7 @@ from api.utils import get_uuid
 from api.db import FileType, FileSource
 from api.db.services import duplicate_name
 from api.db.services.file_service import FileService
-from api.settings import RetCode
+from api import settings
 from api.utils.api_utils import get_json_result
 from api.utils.file_utils import filename_type
 from rag.utils.storage_factory import STORAGE_IMPL
@@ -46,13 +46,13 @@ def upload():
 
     if 'file' not in request.files:
         return get_json_result(
-            data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
+            data=False, message='No file part!', code=settings.RetCode.ARGUMENT_ERROR)
     file_objs = request.files.getlist('file')
 
     for file_obj in file_objs:
         if file_obj.filename == '':
             return get_json_result(
-                data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
+                data=False, message='No file selected!', code=settings.RetCode.ARGUMENT_ERROR)
     file_res = []
     try:
         for file_obj in file_objs:
@@ -134,7 +134,7 @@ def create():
     try:
         if not FileService.is_parent_folder_exist(pf_id):
             return get_json_result(
-                data=False, message="Parent Folder Doesn't Exist!", code=RetCode.OPERATING_ERROR)
+                data=False, message="Parent Folder Doesn't Exist!", code=settings.RetCode.OPERATING_ERROR)
         if FileService.query(name=req["name"], parent_id=pf_id):
             return get_data_error_result(
                 message="Duplicated folder name in the same folder.")
@@ -299,7 +299,7 @@ def rename():
         return get_json_result(
             data=False,
             message="The extension of file can't be changed",
-            code=RetCode.ARGUMENT_ERROR)
+            code=settings.RetCode.ARGUMENT_ERROR)
     for file in FileService.query(name=req["name"], pf_id=file.parent_id):
         if file.name == req["name"]:
             return get_data_error_result(
@@ -26,9 +26,8 @@ from api.utils import get_uuid
 from api.db import StatusEnum, FileSource
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.db_models import File
-from api.settings import RetCode
 from api.utils.api_utils import get_json_result
-from api.settings import docStoreConn
+from api import settings
 from rag.nlp import search
 
 
@@ -68,13 +67,13 @@ def update():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         if not KnowledgebaseService.query(
                 created_by=current_user.id, id=req["kb_id"]):
             return get_json_result(
-                data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
+                data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
 
         e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
         if not e:
@@ -113,7 +112,7 @@ def detail():
     else:
         return get_json_result(
             data=False, message='Only owner of knowledgebase authorized for this operation.',
-            code=RetCode.OPERATING_ERROR)
+            code=settings.RetCode.OPERATING_ERROR)
     kb = KnowledgebaseService.get_detail(kb_id)
     if not kb:
         return get_data_error_result(
@@ -148,14 +147,14 @@ def rm():
         return get_json_result(
             data=False,
             message='No authorization.',
-            code=RetCode.AUTHENTICATION_ERROR
+            code=settings.RetCode.AUTHENTICATION_ERROR
         )
     try:
         kbs = KnowledgebaseService.query(
             created_by=current_user.id, id=req["kb_id"])
         if not kbs:
             return get_json_result(
-                data=False, message='Only owner of knowledgebase authorized for this operation.', code=RetCode.OPERATING_ERROR)
+                data=False, message='Only owner of knowledgebase authorized for this operation.', code=settings.RetCode.OPERATING_ERROR)
 
         for doc in DocumentService.query(kb_id=req["kb_id"]):
             if not DocumentService.remove_document(doc, kbs[0].tenant_id):
@@ -170,7 +169,7 @@ def rm():
                     message="Database error (Knowledgebase removal)!")
         tenants = UserTenantService.query(user_id=current_user.id)
         for tenant in tenants:
-            docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
+            settings.docStoreConn.deleteIdx(search.index_name(tenant.tenant_id), req["kb_id"])
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
@@ -19,7 +19,7 @@ import json
 from flask import request
 from flask_login import login_required, current_user
 from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
-from api.settings import LIGHTEN
+from api import settings
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
@@ -333,7 +333,7 @@ def my_llms():
 @login_required
 def list_app():
     self_deploied = ["Youdao","FastEmbed", "BAAI", "Ollama", "Xinference", "LocalAI", "LM-Studio"]
-    weighted = ["Youdao","FastEmbed", "BAAI"] if LIGHTEN != 0 else []
+    weighted = ["Youdao","FastEmbed", "BAAI"] if settings.LIGHTEN != 0 else []
     model_type = request.args.get("model_type")
     try:
         objs = TenantLLMService.query(tenant_id=current_user.id)
@@ -14,7 +14,7 @@
 # limitations under the License.
 #
 from flask import request
-from api.settings import RetCode
+from api import settings
 from api.db import StatusEnum
 from api.db.services.dialog_service import DialogService
 from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -44,7 +44,7 @@ def create(tenant_id):
     kbs = KnowledgebaseService.get_by_ids(ids)
     embd_count = list(set([kb.embd_id for kb in kbs]))
     if len(embd_count) != 1:
-        return get_result(message='Datasets use different embedding models."',code=RetCode.AUTHENTICATION_ERROR)
+        return get_result(message='Datasets use different embedding models."',code=settings.RetCode.AUTHENTICATION_ERROR)
     req["kb_ids"] = ids
     # llm
     llm = req.get("llm")
@@ -173,7 +173,7 @@ def update(tenant_id,chat_id):
         if len(embd_count) != 1 :
             return get_result(
                 message='Datasets use different embedding models."',
-                code=RetCode.AUTHENTICATION_ERROR)
+                code=settings.RetCode.AUTHENTICATION_ERROR)
         req["kb_ids"] = ids
         llm = req.get("llm")
         if llm:
@@ -23,7 +23,7 @@ from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import TenantLLMService, LLMService
 from api.db.services.user_service import TenantService
-from api.settings import RetCode
+from api import settings
 from api.utils import get_uuid
 from api.utils.api_utils import (
     get_result,
@@ -255,7 +255,7 @@ def delete(tenant_id):
             File2DocumentService.delete_by_document_id(doc.id)
         if not KnowledgebaseService.delete_by_id(id):
             return get_error_data_result(message="Delete dataset error.(Database error)")
-    return get_result(code=RetCode.SUCCESS)
+    return get_result(code=settings.RetCode.SUCCESS)
 
 
 @manager.route("/datasets/<dataset_id>", methods=["PUT"])
@@ -424,7 +424,7 @@ def update(tenant_id, dataset_id):
         )
     if not KnowledgebaseService.update_by_id(kb.id, req):
         return get_error_data_result(message="Update dataset error.(Database error)")
-    return get_result(code=RetCode.SUCCESS)
+    return get_result(code=settings.RetCode.SUCCESS)
 
 
 @manager.route("/datasets", methods=["GET"])
@@ -18,7 +18,7 @@ from flask import request, jsonify
 from api.db import LLMType, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
-from api.settings import retrievaler, kg_retrievaler, RetCode
+from api import settings
 from api.utils.api_utils import validate_request, build_error_result, apikey_required
 
 
@@ -37,14 +37,14 @@ def retrieval(tenant_id):
 
     e, kb = KnowledgebaseService.get_by_id(kb_id)
     if not e:
-        return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
+        return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
 
     if kb.tenant_id != tenant_id:
-        return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND)
+        return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND)
 
     embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
 
-    retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
+    retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
     ranks = retr.retrieval(
         question,
         embd_mdl,
@@ -72,6 +72,6 @@ def retrieval(tenant_id):
         if str(e).find("not_found") > 0:
             return build_error_result(
                 message='No chunk found! Check the chunk status please!',
-                code=RetCode.NOT_FOUND
+                code=settings.RetCode.NOT_FOUND
             )
-        return build_error_result(message=str(e), code=RetCode.SERVER_ERROR)
+        return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR)
@@ -21,7 +21,7 @@ from rag.app.qa import rmPrefix, beAdoc
 from rag.nlp import rag_tokenizer
 from api.db import LLMType, ParserType
 from api.db.services.llm_service import TenantLLMService
-from api.settings import kg_retrievaler
+from api import settings
 import hashlib
 import re
 from api.utils.api_utils import token_required
@@ -37,11 +37,10 @@ from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.settings import RetCode, retrievaler
+from api import settings
 from api.utils.api_utils import construct_json_result, get_parser_config
 from rag.nlp import search
 from rag.utils import rmSpace
-from api.settings import docStoreConn
 from rag.utils.storage_factory import STORAGE_IMPL
 import os
 
@@ -109,13 +108,13 @@ def upload(dataset_id, tenant_id):
     """
     if "file" not in request.files:
         return get_error_data_result(
-            message="No file part!", code=RetCode.ARGUMENT_ERROR
+            message="No file part!", code=settings.RetCode.ARGUMENT_ERROR
         )
     file_objs = request.files.getlist("file")
     for file_obj in file_objs:
         if file_obj.filename == "":
             return get_result(
-                message="No file selected!", code=RetCode.ARGUMENT_ERROR
+                message="No file selected!", code=settings.RetCode.ARGUMENT_ERROR
             )
     # total size
     total_size = 0
@@ -127,14 +126,14 @@ def upload(dataset_id, tenant_id):
     if total_size > MAX_TOTAL_FILE_SIZE:
         return get_result(
             message=f"Total file size exceeds 10MB limit! ({total_size / (1024 * 1024):.2f} MB)",
-            code=RetCode.ARGUMENT_ERROR,
+            code=settings.RetCode.ARGUMENT_ERROR,
         )
     e, kb = KnowledgebaseService.get_by_id(dataset_id)
     if not e:
         raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
     err, files = FileService.upload_document(kb, file_objs, tenant_id)
     if err:
         return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
         return get_result(message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
     # rename key's name
     renamed_doc_list = []
     for file in files:
@@ -221,12 +220,12 @@ def update_doc(tenant_id, dataset_id, document_id):
 
     if "name" in req and req["name"] != doc.name:
         if (
-            pathlib.Path(req["name"].lower()).suffix
-            != pathlib.Path(doc.name.lower()).suffix
+                pathlib.Path(req["name"].lower()).suffix
+                != pathlib.Path(doc.name.lower()).suffix
         ):
             return get_result(
                 message="The extension of file can't be changed",
-                code=RetCode.ARGUMENT_ERROR,
+                code=settings.RetCode.ARGUMENT_ERROR,
             )
         for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id):
             if d.name == req["name"]:
@@ -292,7 +291,7 @@ def update_doc(tenant_id, dataset_id, document_id):
         )
         if not e:
             return get_error_data_result(message="Document not found!")
-        docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
+        settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
 
     return get_result()
 
@@ -349,7 +348,7 @@ def download(tenant_id, dataset_id, document_id):
     file_stream = STORAGE_IMPL.get(doc_id, doc_location)
     if not file_stream:
         return construct_json_result(
-            message="This file is empty.", code=RetCode.DATA_ERROR
+            message="This file is empty.", code=settings.RetCode.DATA_ERROR
         )
     file = BytesIO(file_stream)
     # Use send_file with a proper filename and MIME type
@@ -582,7 +581,7 @@ def delete(tenant_id, dataset_id):
             errors += str(e)
 
     if errors:
-        return get_result(message=errors, code=RetCode.SERVER_ERROR)
+        return get_result(message=errors, code=settings.RetCode.SERVER_ERROR)
 
     return get_result()
 
@@ -644,7 +643,7 @@ def parse(tenant_id, dataset_id):
         info["chunk_num"] = 0
         info["token_num"] = 0
         DocumentService.update_by_id(id, info)
-        docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
+        settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), dataset_id)
         TaskService.filter_delete([Task.doc_id == id])
         e, doc = DocumentService.get_by_id(id)
         doc = doc.to_dict()
@@ -708,7 +707,7 @@ def stop_parsing(tenant_id, dataset_id):
             )
         info = {"run": "2", "progress": 0, "chunk_num": 0}
         DocumentService.update_by_id(id, info)
-        docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
+        settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), dataset_id)
     return get_result()
 
 
@@ -828,8 +827,9 @@ def list_chunks(tenant_id, dataset_id, document_id):
 
     res = {"total": 0, "chunks": [], "doc": renamed_doc}
     origin_chunks = []
-    if docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
-        sres = retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
+    if settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
+        sres = settings.retrievaler.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None,
                                           highlight=True)
         res["total"] = sres.total
         sign = 0
         for id in sres.ids:
@@ -1003,7 +1003,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
     v, c = embd_mdl.encode([doc.name, req["content"]])
     v = 0.1 * v[0] + 0.9 * v[1]
     d["q_%d_vec" % len(v)] = v.tolist()
-    docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
+    settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id)
 
     DocumentService.increment_chunk_num(doc.id, doc.kb_id, c, 1, 0)
     # rename keys
@@ -1078,7 +1078,7 @@ def rm_chunk(tenant_id, dataset_id, document_id):
     condition = {"doc_id": document_id}
     if "chunk_ids" in req:
         condition["id"] = req["chunk_ids"]
-    chunk_number = docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
+    chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id)
    if chunk_number != 0:
         DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0)
     if "chunk_ids" in req and chunk_number != len(req["chunk_ids"]):
@@ -1143,7 +1143,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
         schema:
           type: object
     """
-    chunk = docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
+    chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id])
     if chunk is None:
         return get_error_data_result(f"Can't find this chunk {chunk_id}")
     if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
@@ -1187,7 +1187,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id):
     v, c = embd_mdl.encode([doc.name, d["content_with_weight"]])
     v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
     d["q_%d_vec" % len(v)] = v.tolist()
-    docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
+    settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id)
     return get_result()
 
 
@@ -1285,7 +1285,7 @@ def retrieval_test(tenant_id):
     if len(embd_nms) != 1:
         return get_result(
             message='Datasets use different embedding models."',
-            code=RetCode.AUTHENTICATION_ERROR,
+            code=settings.RetCode.AUTHENTICATION_ERROR,
         )
     if "question" not in req:
         return get_error_data_result("`question` is required.")
@@ -1326,7 +1326,7 @@ def retrieval_test(tenant_id):
         chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
         question += keyword_extraction(chat_mdl, question)
 
-    retr = retrievaler if kb.parser_id != ParserType.KG else kg_retrievaler
+    retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
     ranks = retr.retrieval(
         question,
         embd_mdl,
@@ -1366,6 +1366,6 @@ def retrieval_test(tenant_id):
         if str(e).find("not_found") > 0:
             return get_result(
                 message="No chunk found! Check the chunk status please!",
-                code=RetCode.DATA_ERROR,
+                code=settings.RetCode.DATA_ERROR,
)
|
||||
return server_error_response(e)
|
||||
return server_error_response(e)
|
||||
|
@ -22,7 +22,7 @@ from api.db.db_models import APIToken
from api.db.services.api_service import APITokenService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import UserTenantService
from api.settings import DATABASE_TYPE
from api import settings
from api.utils import current_timestamp, datetime_format
from api.utils.api_utils import (
get_json_result,
@ -31,7 +31,6 @@ from api.utils.api_utils import (
generate_confirmation_token,
)
from api.versions import get_ragflow_version
from api.settings import docStoreConn
from rag.utils.storage_factory import STORAGE_IMPL, STORAGE_IMPL_TYPE
from timeit import default_timer as timer

@ -98,7 +97,7 @@ def status():
res = {}
st = timer()
try:
res["doc_store"] = docStoreConn.health()
res["doc_store"] = settings.docStoreConn.health()
res["doc_store"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.0)
except Exception as e:
res["doc_store"] = {
@ -128,13 +127,13 @@ def status():
try:
KnowledgebaseService.get_by_id("x")
res["database"] = {
"database": DATABASE_TYPE.lower(),
"database": settings.DATABASE_TYPE.lower(),
"status": "green",
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
}
except Exception as e:
res["database"] = {
"database": DATABASE_TYPE.lower(),
"database": settings.DATABASE_TYPE.lower(),
"status": "red",
"elapsed": "{:.1f}".format((timer() - st) * 1000.0),
"error": str(e),

@ -38,20 +38,7 @@ from api.utils import (
datetime_format,
)
from api.db import UserTenantRole, FileType
from api.settings import (
RetCode,
GITHUB_OAUTH,
FEISHU_OAUTH,
CHAT_MDL,
EMBEDDING_MDL,
ASR_MDL,
IMAGE2TEXT_MDL,
PARSERS,
API_KEY,
LLM_FACTORY,
LLM_BASE_URL,
RERANK_MDL,
)
from api import settings
from api.db.services.user_service import UserService, TenantService, UserTenantService
from api.db.services.file_service import FileService
from api.utils.api_utils import get_json_result, construct_response
@ -90,7 +77,7 @@ def login():
"""
if not request.json:
return get_json_result(
data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
data=False, code=settings.RetCode.AUTHENTICATION_ERROR, message="Unauthorized!"
)

email = request.json.get("email", "")
@ -98,7 +85,7 @@ def login():
if not users:
return get_json_result(
data=False,
code=RetCode.AUTHENTICATION_ERROR,
code=settings.RetCode.AUTHENTICATION_ERROR,
message=f"Email: {email} is not registered!",
)

@ -107,7 +94,7 @@ def login():
password = decrypt(password)
except BaseException:
return get_json_result(
data=False, code=RetCode.SERVER_ERROR, message="Fail to crypt password"
data=False, code=settings.RetCode.SERVER_ERROR, message="Fail to crypt password"
)

user = UserService.query_user(email, password)
@ -123,7 +110,7 @@ def login():
else:
return get_json_result(
data=False,
code=RetCode.AUTHENTICATION_ERROR,
code=settings.RetCode.AUTHENTICATION_ERROR,
message="Email and password do not match!",
)

@ -150,10 +137,10 @@ def github_callback():
import requests

res = requests.post(
GITHUB_OAUTH.get("url"),
settings.GITHUB_OAUTH.get("url"),
data={
"client_id": GITHUB_OAUTH.get("client_id"),
"client_secret": GITHUB_OAUTH.get("secret_key"),
"client_id": settings.GITHUB_OAUTH.get("client_id"),
"client_secret": settings.GITHUB_OAUTH.get("secret_key"),
"code": request.args.get("code"),
},
headers={"Accept": "application/json"},
@ -235,11 +222,11 @@ def feishu_callback():
import requests

app_access_token_res = requests.post(
FEISHU_OAUTH.get("app_access_token_url"),
settings.FEISHU_OAUTH.get("app_access_token_url"),
data=json.dumps(
{
"app_id": FEISHU_OAUTH.get("app_id"),
"app_secret": FEISHU_OAUTH.get("app_secret"),
"app_id": settings.FEISHU_OAUTH.get("app_id"),
"app_secret": settings.FEISHU_OAUTH.get("app_secret"),
}
),
headers={"Content-Type": "application/json; charset=utf-8"},
@ -249,10 +236,10 @@ def feishu_callback():
return redirect("/?error=%s" % app_access_token_res)

res = requests.post(
FEISHU_OAUTH.get("user_access_token_url"),
settings.FEISHU_OAUTH.get("user_access_token_url"),
data=json.dumps(
{
"grant_type": FEISHU_OAUTH.get("grant_type"),
"grant_type": settings.FEISHU_OAUTH.get("grant_type"),
"code": request.args.get("code"),
}
),
@ -405,11 +392,11 @@ def setting_user():
if request_data.get("password"):
new_password = request_data.get("new_password")
if not check_password_hash(
current_user.password, decrypt(request_data["password"])
current_user.password, decrypt(request_data["password"])
):
return get_json_result(
data=False,
code=RetCode.AUTHENTICATION_ERROR,
code=settings.RetCode.AUTHENTICATION_ERROR,
message="Password error!",
)

@ -438,7 +425,7 @@ def setting_user():
except Exception as e:
logging.exception(e)
return get_json_result(
data=False, message="Update failure!", code=RetCode.EXCEPTION_ERROR
data=False, message="Update failure!", code=settings.RetCode.EXCEPTION_ERROR
)


@ -497,12 +484,12 @@ def user_register(user_id, user):
tenant = {
"id": user_id,
"name": user["nickname"] + "‘s Kingdom",
"llm_id": CHAT_MDL,
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL,
"rerank_id": RERANK_MDL,
"llm_id": settings.CHAT_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL,
"rerank_id": settings.RERANK_MDL,
}
usr_tenant = {
"tenant_id": user_id,
@ -522,15 +509,15 @@ def user_register(user_id, user):
"location": "",
}
tenant_llm = []
for llm in LLMService.query(fid=LLM_FACTORY):
for llm in LLMService.query(fid=settings.LLM_FACTORY):
tenant_llm.append(
{
"tenant_id": user_id,
"llm_factory": LLM_FACTORY,
"llm_factory": settings.LLM_FACTORY,
"llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": API_KEY,
"api_base": LLM_BASE_URL,
"api_key": settings.API_KEY,
"api_base": settings.LLM_BASE_URL,
}
)

@ -582,7 +569,7 @@ def user_add():
return get_json_result(
data=False,
message=f"Invalid email address: {email_address}!",
code=RetCode.OPERATING_ERROR,
code=settings.RetCode.OPERATING_ERROR,
)

# Check if the email address is already used
@ -590,7 +577,7 @@ def user_add():
return get_json_result(
data=False,
message=f"Email: {email_address} has already registered!",
code=RetCode.OPERATING_ERROR,
code=settings.RetCode.OPERATING_ERROR,
)

# Construct user info data
@ -625,7 +612,7 @@ def user_add():
return get_json_result(
data=False,
message=f"User registration failure, error: {str(e)}",
code=RetCode.EXCEPTION_ERROR,
code=settings.RetCode.EXCEPTION_ERROR,
)


@ -31,7 +31,7 @@ from peewee import (
)
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
from api.db import SerializedType, ParserType
from api.settings import DATABASE, SECRET_KEY, DATABASE_TYPE
from api import settings
from api import utils

def singleton(cls, *args, **kw):
@ -62,7 +62,7 @@ class TextFieldType(Enum):


class LongTextField(TextField):
field_type = TextFieldType[DATABASE_TYPE.upper()].value
field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value


class JSONField(LongTextField):
@ -282,9 +282,9 @@ class DatabaseMigrator(Enum):
@singleton
class BaseDataBase:
def __init__(self):
database_config = DATABASE.copy()
database_config = settings.DATABASE.copy()
db_name = database_config.pop("name")
self.database_connection = PooledDatabase[DATABASE_TYPE.upper()].value(db_name, **database_config)
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
logging.info('init database on cluster mode successfully')

class PostgresDatabaseLock:
@ -385,7 +385,7 @@ class DatabaseLock(Enum):


DB = BaseDataBase().database_connection
DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value
DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value


def close_connection():
@ -476,7 +476,7 @@ class User(DataBaseModel, UserMixin):
return self.email

def get_id(self):
jwt = Serializer(secret_key=SECRET_KEY)
jwt = Serializer(secret_key=settings.SECRET_KEY)
return jwt.dumps(str(self.access_token))

class Meta:
@ -977,7 +977,7 @@ class CanvasTemplate(DataBaseModel):

def migrate_db():
with DB.transaction():
migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB)
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
try:
migrate(
migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",

@ -29,7 +29,7 @@ from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY, LLM_BASE_URL
from api import settings
from api.utils.file_utils import get_project_base_directory


@ -51,11 +51,11 @@ def init_superuser():
tenant = {
"id": user_info["id"],
"name": user_info["nickname"] + "‘s Kingdom",
"llm_id": CHAT_MDL,
"embd_id": EMBEDDING_MDL,
"asr_id": ASR_MDL,
"parser_ids": PARSERS,
"img2txt_id": IMAGE2TEXT_MDL
"llm_id": settings.CHAT_MDL,
"embd_id": settings.EMBEDDING_MDL,
"asr_id": settings.ASR_MDL,
"parser_ids": settings.PARSERS,
"img2txt_id": settings.IMAGE2TEXT_MDL
}
usr_tenant = {
"tenant_id": user_info["id"],
@ -64,10 +64,11 @@ def init_superuser():
"role": UserTenantRole.OWNER
}
tenant_llm = []
for llm in LLMService.query(fid=LLM_FACTORY):
for llm in LLMService.query(fid=settings.LLM_FACTORY):
tenant_llm.append(
{"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
"api_key": API_KEY, "api_base": LLM_BASE_URL})
{"tenant_id": user_info["id"], "llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL})

if not UserService.save(**user_info):
logging.error("can't init admin.")
@ -80,7 +81,7 @@ def init_superuser():

chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
msg = chat_mdl.chat(system="", history=[
{"role": "user", "content": "Hello!"}], gen_conf={})
{"role": "user", "content": "Hello!"}], gen_conf={})
if msg.find("ERROR: ") == 0:
logging.error(
"'{}' dosen't work. {}".format(
@ -179,7 +180,7 @@ def init_web_data():
start_time = time.time()

init_llm_factory()
#if not UserService.get_all().count():
# if not UserService.get_all().count():
# init_superuser()

add_graph_templates()

@ -27,7 +27,7 @@ from api.db.db_models import Dialog, Conversation,DB
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.settings import retrievaler, kg_retrievaler
from api import settings
from rag.app.resume import forbidden_select_fields4resume
from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder
@ -152,7 +152,7 @@ def chat(dialog, messages, stream=True, **kwargs):
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}

is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
retr = retrievaler if not is_kg else kg_retrievaler
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler

questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
@ -342,7 +342,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):

logging.debug(f"{question} get SQL(refined): {sql}")
tried_times += 1
return retrievaler.sql_retrieval(sql, format="json"), sql
return settings.retrievaler.sql_retrieval(sql, format="json"), sql

tbl, sql = get_table()
if tbl is None:
@ -596,7 +596,7 @@ def ask(question, kb_ids, tenant_id):
embd_nms = list(set([kb.embd_id for kb in kbs]))

is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
retr = retrievaler if not is_kg else kg_retrievaler
retr = settings.retrievaler if not is_kg else settings.kg_retrievaler

embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)

@ -26,7 +26,7 @@ from io import BytesIO
from peewee import fn

from api.db.db_utils import bulk_insert_into_db
from api.settings import docStoreConn
from api import settings
from api.utils import current_timestamp, get_format_time, get_uuid
from graphrag.mind_map_extractor import MindMapExtractor
from rag.settings import SVR_QUEUE_NAME
@ -108,7 +108,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def remove_document(cls, doc, tenant_id):
docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
cls.clear_chunk_num(doc.id)
return cls.delete_by_id(doc.id)

@ -553,10 +553,10 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not docStoreConn.indexExist(idxnm, kb_id):
docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
if not settings.docStoreConn.indexExist(idxnm, kb_id):
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
try_create_idx = False
docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)

DocumentService.increment_chunk_num(
doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

@ -33,12 +33,10 @@ import traceback
from concurrent.futures import ThreadPoolExecutor

from werkzeug.serving import run_simple
from api import settings
from api.apps import app
from api.db.runtime_config import RuntimeConfig
from api.db.services.document_service import DocumentService
from api.settings import (
HOST, HTTP_PORT
)
from api import utils

from api.db.db_models import init_database_tables as init_web_db
@ -72,6 +70,7 @@ if __name__ == '__main__':
f'project base: {utils.file_utils.get_project_base_directory()}'
)
show_configs()
settings.init_settings()

# init db
init_web_db()
@ -96,7 +95,7 @@ if __name__ == '__main__':
logging.info("run on debug mode")

RuntimeConfig.init_env()
RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
RuntimeConfig.init_config(JOB_SERVER_HOST=settings.HOST_IP, HTTP_PORT=settings.HOST_PORT)

thread = ThreadPoolExecutor(max_workers=1)
thread.submit(update_progress)
@ -105,8 +104,8 @@ if __name__ == '__main__':
try:
logging.info("RAGFlow HTTP server start...")
run_simple(
hostname=HOST,
port=HTTP_PORT,
hostname=settings.HOST_IP,
port=settings.HOST_PORT,
application=app,
threaded=True,
use_reloader=RuntimeConfig.DEBUG,

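The call-site edits above all follow from how Python binds imports: `from api.settings import X` snapshots the value of X at import time, so anything that init_settings() assigns later never reaches the caller, while `from api import settings` plus `settings.X` looks the attribute up at call time. A minimal, self-contained sketch of the difference (the module and values here are illustrative, not from this PR):

import types

settings = types.ModuleType("settings")  # stands in for api.settings
settings.HOST_IP = None                  # module-level placeholder

def init_settings():
    settings.HOST_IP = "127.0.0.1"       # filled in after module init

snapshot = settings.HOST_IP              # like `from api.settings import HOST_IP`
init_settings()
print(snapshot)                          # None -- the stale import-time copy
print(settings.HOST_IP)                  # 127.0.0.1 -- late-bound attribute access
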
api/settings.py
@ -30,114 +30,157 @@ LIGHTEN = int(os.environ.get('LIGHTEN', "0"))

REQUEST_WAIT_SEC = 2
REQUEST_MAX_WAIT_SEC = 300

LLM = get_base_config("user_default_llm", {})
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
LLM_BASE_URL = LLM.get("base_url")

CHAT_MDL = EMBEDDING_MDL = RERANK_MDL = ASR_MDL = IMAGE2TEXT_MDL = ""
if not LIGHTEN:
default_llm = {
"Tongyi-Qianwen": {
"chat_model": "qwen-plus",
"embedding_model": "text-embedding-v2",
"image2text_model": "qwen-vl-max",
"asr_model": "paraformer-realtime-8k-v1",
},
"OpenAI": {
"chat_model": "gpt-3.5-turbo",
"embedding_model": "text-embedding-ada-002",
"image2text_model": "gpt-4-vision-preview",
"asr_model": "whisper-1",
},
"Azure-OpenAI": {
"chat_model": "gpt-35-turbo",
"embedding_model": "text-embedding-ada-002",
"image2text_model": "gpt-4-vision-preview",
"asr_model": "whisper-1",
},
"ZHIPU-AI": {
"chat_model": "glm-3-turbo",
"embedding_model": "embedding-2",
"image2text_model": "glm-4v",
"asr_model": "",
},
"Ollama": {
"chat_model": "qwen-14B-chat",
"embedding_model": "flag-embedding",
"image2text_model": "",
"asr_model": "",
},
"Moonshot": {
"chat_model": "moonshot-v1-8k",
"embedding_model": "",
"image2text_model": "",
"asr_model": "",
},
"DeepSeek": {
"chat_model": "deepseek-chat",
"embedding_model": "",
"image2text_model": "",
"asr_model": "",
},
"VolcEngine": {
"chat_model": "",
"embedding_model": "",
"image2text_model": "",
"asr_model": "",
},
"BAAI": {
"chat_model": "",
"embedding_model": "BAAI/bge-large-zh-v1.5",
"image2text_model": "",
"asr_model": "",
"rerank_model": "BAAI/bge-reranker-v2-m3",
}
}

if LLM_FACTORY:
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}"
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}"
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}"
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI"
RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI"

API_KEY = LLM.get("api_key", "")
PARSERS = LLM.get(
"parsers",
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")

HOST = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
HTTP_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")

SECRET_KEY = get_base_config(
RAG_FLOW_SERVICE_NAME,
{}).get("secret_key", str(date.today()))
LLM = None
LLM_FACTORY = None
LLM_BASE_URL = None
CHAT_MDL = ""
EMBEDDING_MDL = ""
RERANK_MDL = ""
ASR_MDL = ""
IMAGE2TEXT_MDL = ""
API_KEY = None
PARSERS = None
HOST_IP = None
HOST_PORT = None
SECRET_KEY = None

DATABASE_TYPE = os.getenv("DB_TYPE", 'mysql')
DATABASE = decrypt_database_config(name=DATABASE_TYPE)

# authentication
AUTHENTICATION_CONF = get_base_config("authentication", {})
AUTHENTICATION_CONF = None

# client
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
"client", {}).get(
"switch", False)
HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")
CLIENT_AUTHENTICATION = None
HTTP_APP_KEY = None
GITHUB_OAUTH = None
FEISHU_OAUTH = None

DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
if DOC_ENGINE == "elasticsearch":
docStoreConn = rag.utils.es_conn.ESConnection()
elif DOC_ENGINE == "infinity":
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
else:
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
DOC_ENGINE = None
docStoreConn = None

retrievaler = search.Dealer(docStoreConn)
kg_retrievaler = kg_search.KGSearch(docStoreConn)
retrievaler = None
kg_retrievaler = None


def init_settings():
global LLM, LLM_FACTORY, LLM_BASE_URL
LLM = get_base_config("user_default_llm", {})
LLM_FACTORY = LLM.get("factory", "Tongyi-Qianwen")
LLM_BASE_URL = LLM.get("base_url")

global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
if not LIGHTEN:
default_llm = {
"Tongyi-Qianwen": {
"chat_model": "qwen-plus",
"embedding_model": "text-embedding-v2",
"image2text_model": "qwen-vl-max",
"asr_model": "paraformer-realtime-8k-v1",
},
"OpenAI": {
"chat_model": "gpt-3.5-turbo",
"embedding_model": "text-embedding-ada-002",
"image2text_model": "gpt-4-vision-preview",
"asr_model": "whisper-1",
},
"Azure-OpenAI": {
"chat_model": "gpt-35-turbo",
"embedding_model": "text-embedding-ada-002",
"image2text_model": "gpt-4-vision-preview",
"asr_model": "whisper-1",
},
"ZHIPU-AI": {
"chat_model": "glm-3-turbo",
"embedding_model": "embedding-2",
"image2text_model": "glm-4v",
"asr_model": "",
},
"Ollama": {
"chat_model": "qwen-14B-chat",
"embedding_model": "flag-embedding",
"image2text_model": "",
"asr_model": "",
},
"Moonshot": {
"chat_model": "moonshot-v1-8k",
"embedding_model": "",
"image2text_model": "",
"asr_model": "",
},
"DeepSeek": {
"chat_model": "deepseek-chat",
"embedding_model": "",
"image2text_model": "",
"asr_model": "",
},
"VolcEngine": {
"chat_model": "",
"embedding_model": "",
"image2text_model": "",
"asr_model": "",
},
"BAAI": {
"chat_model": "",
"embedding_model": "BAAI/bge-large-zh-v1.5",
"image2text_model": "",
"asr_model": "",
"rerank_model": "BAAI/bge-reranker-v2-m3",
}
}

if LLM_FACTORY:
CHAT_MDL = default_llm[LLM_FACTORY]["chat_model"] + f"@{LLM_FACTORY}"
ASR_MDL = default_llm[LLM_FACTORY]["asr_model"] + f"@{LLM_FACTORY}"
IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"] + f"@{LLM_FACTORY}"
EMBEDDING_MDL = default_llm["BAAI"]["embedding_model"] + "@BAAI"
RERANK_MDL = default_llm["BAAI"]["rerank_model"] + "@BAAI"

global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
API_KEY = LLM.get("api_key", "")
PARSERS = LLM.get(
"parsers",
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")

HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")

SECRET_KEY = get_base_config(
RAG_FLOW_SERVICE_NAME,
{}).get("secret_key", str(date.today()))

global AUTHENTICATION_CONF, CLIENT_AUTHENTICATION, HTTP_APP_KEY, GITHUB_OAUTH, FEISHU_OAUTH
# authentication
AUTHENTICATION_CONF = get_base_config("authentication", {})

# client
CLIENT_AUTHENTICATION = AUTHENTICATION_CONF.get(
"client", {}).get(
"switch", False)
HTTP_APP_KEY = AUTHENTICATION_CONF.get("client", {}).get("http_app_key")
GITHUB_OAUTH = get_base_config("oauth", {}).get("github")
FEISHU_OAUTH = get_base_config("oauth", {}).get("feishu")

global DOC_ENGINE, docStoreConn, retrievaler, kg_retrievaler
DOC_ENGINE = os.environ.get('DOC_ENGINE', "elasticsearch")
if DOC_ENGINE == "elasticsearch":
docStoreConn = rag.utils.es_conn.ESConnection()
elif DOC_ENGINE == "infinity":
docStoreConn = rag.utils.infinity_conn.InfinityConnection()
else:
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")

retrievaler = search.Dealer(docStoreConn)
kg_retrievaler = kg_search.KGSearch(docStoreConn)

def get_host_ip():
global HOST_IP
return HOST_IP


def get_host_port():
global HOST_PORT
return HOST_PORT


class CustomEnum(Enum):

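The pattern in the rewritten api/settings.py is consistent: every deferred config name becomes a module-level placeholder (None or ""), and init_settings() repopulates them through `global` statements, so importing the module no longer opens a database or doc-engine connection as a side effect. A stripped-down sketch of the same shape (illustrative names, not the real module):

# config.py -- deferred-initialization sketch
DOC_ENGINE = None
docStoreConn = None

def init_settings():
    global DOC_ENGINE, docStoreConn
    DOC_ENGINE = "elasticsearch"   # the real code reads this from the environment
    docStoreConn = object()        # the real code builds an ES or Infinity connection

Callers therefore read these names as attributes (settings.docStoreConn) only after the server entry point has called settings.init_settings().
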
@ -34,11 +34,9 @@ from itsdangerous import URLSafeTimedSerializer
from werkzeug.http import HTTP_STATUS_CODES

from api.db.db_models import APIToken
from api.settings import (
REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
CLIENT_AUTHENTICATION, HTTP_APP_KEY, SECRET_KEY
)
from api.settings import RetCode
from api import settings

from api import settings
from api.utils import CustomJSONEncoder, get_uuid
from api.utils import json_dumps

@ -59,13 +57,13 @@ def request(**kwargs):
{}).items()}
prepped = requests.Request(**kwargs).prepare()

if CLIENT_AUTHENTICATION and HTTP_APP_KEY and SECRET_KEY:
if settings.CLIENT_AUTHENTICATION and settings.HTTP_APP_KEY and settings.SECRET_KEY:
timestamp = str(round(time() * 1000))
nonce = str(uuid1())
signature = b64encode(HMAC(SECRET_KEY.encode('ascii'), b'\n'.join([
signature = b64encode(HMAC(settings.SECRET_KEY.encode('ascii'), b'\n'.join([
timestamp.encode('ascii'),
nonce.encode('ascii'),
HTTP_APP_KEY.encode('ascii'),
settings.HTTP_APP_KEY.encode('ascii'),
prepped.path_url.encode('ascii'),
prepped.body if kwargs.get('json') else b'',
urlencode(
@ -79,7 +77,7 @@ def request(**kwargs):
prepped.headers.update({
'TIMESTAMP': timestamp,
'NONCE': nonce,
'APP-KEY': HTTP_APP_KEY,
'APP-KEY': settings.HTTP_APP_KEY,
'SIGNATURE': signature,
})

@ -89,7 +87,7 @@ def request(**kwargs):
def get_exponential_backoff_interval(retries, full_jitter=False):
"""Calculate the exponential backoff wait time."""
# Will be zero if factor equals 0
countdown = min(REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC * (2 ** retries))
countdown = min(settings.REQUEST_MAX_WAIT_SEC, settings.REQUEST_WAIT_SEC * (2 ** retries))
# Full jitter according to
# https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
if full_jitter:
@ -98,7 +96,7 @@ def get_exponential_backoff_interval(retries, full_jitter=False):
return max(0, countdown)


def get_data_error_result(code=RetCode.DATA_ERROR,
def get_data_error_result(code=settings.RetCode.DATA_ERROR,
message='Sorry! Data missing!'):
import re
result_dict = {
@ -126,8 +124,8 @@ def server_error_response(e):
pass
if len(e.args) > 1:
return get_json_result(
code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
return get_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
return get_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))


def error_response(response_code, message=None):
@ -168,7 +166,7 @@ def validate_request(*args, **kwargs):
error_string += "required argument values: {}".format(
",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
return get_json_result(
code=RetCode.ARGUMENT_ERROR, message=error_string)
code=settings.RetCode.ARGUMENT_ERROR, message=error_string)
return func(*_args, **_kwargs)

return decorated_function
@ -193,7 +191,7 @@ def send_file_in_mem(data, filename):
return send_file(f, as_attachment=True, attachment_filename=filename)


def get_json_result(code=RetCode.SUCCESS, message='success', data=None):
def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
response = {"code": code, "message": message, "data": data}
return jsonify(response)

@ -204,7 +202,7 @@ def apikey_required(func):
objs = APIToken.query(token=token)
if not objs:
return build_error_result(
message='API-KEY is invalid!', code=RetCode.FORBIDDEN
message='API-KEY is invalid!', code=settings.RetCode.FORBIDDEN
)
kwargs['tenant_id'] = objs[0].tenant_id
return func(*args, **kwargs)
@ -212,14 +210,14 @@ def apikey_required(func):
return decorated_function


def build_error_result(code=RetCode.FORBIDDEN, message='success'):
def build_error_result(code=settings.RetCode.FORBIDDEN, message='success'):
response = {"code": code, "message": message}
response = jsonify(response)
response.status_code = code
return response


def construct_response(code=RetCode.SUCCESS,
def construct_response(code=settings.RetCode.SUCCESS,
message='success', data=None, auth=None):
result_dict = {"code": code, "message": message, "data": data}
response_dict = {}
@ -239,7 +237,7 @@ def construct_response(code=RetCode.SUCCESS,
return response


def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
def construct_result(code=settings.RetCode.DATA_ERROR, message='data is missing'):
import re
result_dict = {"code": code, "message": re.sub(r"rag", "seceum", message, flags=re.IGNORECASE)}
response = {}
@ -251,7 +249,7 @@ def construct_result(code=RetCode.DATA_ERROR, message='data is missing'):
return jsonify(response)


def construct_json_result(code=RetCode.SUCCESS, message='success', data=None):
def construct_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
if data is None:
return jsonify({"code": code, "message": message})
else:
@ -262,12 +260,12 @@ def construct_error_response(e):
logging.exception(e)
try:
if e.code == 401:
return construct_json_result(code=RetCode.UNAUTHORIZED, message=repr(e))
return construct_json_result(code=settings.RetCode.UNAUTHORIZED, message=repr(e))
except BaseException:
pass
if len(e.args) > 1:
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
return construct_json_result(code=RetCode.EXCEPTION_ERROR, message=repr(e))
return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e.args[0]), data=e.args[1])
return construct_json_result(code=settings.RetCode.EXCEPTION_ERROR, message=repr(e))


def token_required(func):
@ -280,7 +278,7 @@ def token_required(func):
objs = APIToken.query(token=token)
if not objs:
return get_json_result(
data=False, message='Token is not valid!', code=RetCode.AUTHENTICATION_ERROR
data=False, message='Token is not valid!', code=settings.RetCode.AUTHENTICATION_ERROR
)
kwargs['tenant_id'] = objs[0].tenant_id
return func(*args, **kwargs)
@ -288,7 +286,7 @@ def token_required(func):
return decorated_function


def get_result(code=RetCode.SUCCESS, message="", data=None):
def get_result(code=settings.RetCode.SUCCESS, message="", data=None):
if code == 0:
if data is not None:
response = {"code": code, "data": data}
@ -299,7 +297,7 @@ def get_result(code=RetCode.SUCCESS, message="", data=None):
return jsonify(response)


def get_error_data_result(message='Sorry! Data missing!', code=RetCode.DATA_ERROR,
def get_error_data_result(message='Sorry! Data missing!', code=settings.RetCode.DATA_ERROR,
):
import re
result_dict = {

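One subtlety in the api_utils.py changes: defaults such as code=settings.RetCode.SUCCESS are evaluated once, when the function is defined at import time. This appears safe only because RetCode remains a plain class defined at module level in api/settings.py, not one of the values deferred to init_settings(). A two-line illustration of the rule (illustrative names):

class RetCode:                          # exists at import time, unlike the deferred settings
    SUCCESS = 0

def get_result(code=RetCode.SUCCESS):   # default captured at definition time
    return {"code": code}

print(get_result())                     # {'code': 0}
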
@ -24,7 +24,7 @@ import numpy as np
from timeit import default_timer as timer
from pypdf import PdfReader as pdf2_read

from api.settings import LIGHTEN
from api import settings
from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
from rag.nlp import rag_tokenizer
@ -41,7 +41,7 @@ class RAGFlowPdfParser:
self.tbl_det = TableStructureRecognizer()

self.updown_cnt_mdl = xgb.Booster()
if not LIGHTEN:
if not settings.LIGHTEN:
try:
import torch
if torch.cuda.is_available():

@ -252,13 +252,13 @@ if __name__ == "__main__":

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from api import settings
from api.db.services.knowledgebase_service import KnowledgebaseService

kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)

ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
info = {
"input_text": docs,
"entity_specs": "organization, person",

@ -30,14 +30,14 @@ if __name__ == "__main__":

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.settings import retrievaler
from api import settings
from api.db.services.knowledgebase_service import KnowledgebaseService

kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)

ex = GraphExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
docs = [d["content_with_weight"] for d in
retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=6, fields=["content_with_weight"])]
graph = ex(docs)

er = EntityResolution(LLMBundle(args.tenant_id, LLMType.CHAT))

@ -23,7 +23,7 @@ from collections import defaultdict
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import retrievaler, docStoreConn
from api import settings
from api.utils import get_uuid
from rag.nlp import tokenize, search
from ranx import evaluate
@ -52,7 +52,7 @@ class Benchmark:
run = defaultdict(dict)
query_list = list(qrels.keys())
for query in query_list:
ranks = retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
ranks = settings.retrievaler.retrieval(query, self.embd_mdl, self.tenant_id, [self.kb.id], 1, 30,
0.0, self.vector_similarity_weight)
if len(ranks["chunks"]) == 0:
print(f"deleted query: {query}")
@ -81,9 +81,9 @@ class Benchmark:
def init_index(self, vector_size: int):
if self.initialized_index:
return
if docStoreConn.indexExist(self.index_name, self.kb_id):
docStoreConn.deleteIdx(self.index_name, self.kb_id)
docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
self.initialized_index = True

def ms_marco_index(self, file_path, index_name):
@ -118,13 +118,13 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name, self.kb_id)
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
docs = []

if docs:
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name, self.kb_id)
settings.docStoreConn.insert(docs, self.index_name, self.kb_id)
return qrels, texts

def trivia_qa_index(self, file_path, index_name):
@ -159,12 +159,12 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs,self.index_name)
settings.docStoreConn.insert(docs,self.index_name)
docs = []

docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name)
settings.docStoreConn.insert(docs, self.index_name)
return qrels, texts

def miracl_index(self, file_path, corpus_path, index_name):
@ -214,12 +214,12 @@ class Benchmark:
docs_count += len(docs)
docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name)
settings.docStoreConn.insert(docs, self.index_name)
docs = []

docs, vector_size = self.embedding(docs)
self.init_index(vector_size)
docStoreConn.insert(docs, self.index_name)
settings.docStoreConn.insert(docs, self.index_name)
return qrels, texts

def save_results(self, qrels, run, texts, dataset, file_path):

@ -28,7 +28,7 @@ from openai import OpenAI
import numpy as np
import asyncio

from api.settings import LIGHTEN
from api import settings
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate
import google.generativeai as genai
@ -60,7 +60,7 @@ class DefaultEmbedding(Base):
^_-

"""
if not LIGHTEN and not DefaultEmbedding._model:
if not settings.LIGHTEN and not DefaultEmbedding._model:
with DefaultEmbedding._model_lock:
from FlagEmbedding import FlagModel
import torch
@ -248,7 +248,7 @@ class FastEmbed(Base):
threads: Optional[int] = None,
**kwargs,
):
if not LIGHTEN and not FastEmbed._model:
if not settings.LIGHTEN and not FastEmbed._model:
from fastembed import TextEmbedding
self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)

@ -294,7 +294,7 @@ class YoudaoEmbed(Base):
_client = None

def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
if not LIGHTEN and not YoudaoEmbed._client:
if not settings.LIGHTEN and not YoudaoEmbed._client:
from BCEmbedding import EmbeddingModel as qanthing
try:
logging.info("LOADING BCE...")

@ -23,7 +23,7 @@ import os
from abc import ABC
import numpy as np

from api.settings import LIGHTEN
from api import settings
from api.utils.file_utils import get_home_cache_dir
from rag.utils import num_tokens_from_string, truncate
import json
@ -57,7 +57,7 @@ class DefaultRerank(Base):
^_-

"""
if not LIGHTEN and not DefaultRerank._model:
if not settings.LIGHTEN and not DefaultRerank._model:
import torch
from FlagEmbedding import FlagReranker
with DefaultRerank._model_lock:
@ -121,7 +121,7 @@ class YoudaoRerank(DefaultRerank):
_model_lock = threading.Lock()

def __init__(self, key=None, model_name="maidalun1020/bce-reranker-base_v1", **kwargs):
if not LIGHTEN and not YoudaoRerank._model:
if not settings.LIGHTEN and not YoudaoRerank._model:
from BCEmbedding import RerankerModel
with YoudaoRerank._model_lock:
if not YoudaoRerank._model:

@ -16,6 +16,7 @@
import logging
import sys
from api.utils.log_utils import initRootLogger

CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
initRootLogger(f"task_executor_{CONSUMER_NO}")
for module in ["pdfminer"]:
@ -49,9 +50,10 @@ from api.db.services.document_service import DocumentService
from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService
from api.db.services.file2document_service import File2DocumentService
from api.settings import retrievaler, docStoreConn
from api import settings
from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
knowledge_graph, email
from rag.nlp import search, rag_tokenizer
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME
@ -88,6 +90,7 @@ PENDING_TASKS = 0
HEAD_CREATED_AT = ""
HEAD_DETAIL = ""


def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
global PAYLOAD
if prog is not None and prog < 0:
@ -171,7 +174,8 @@ def build(row):
"From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
except TimeoutError:
callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
logging.exception(
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(row["location"], row["name"]))
return
except Exception as e:
if re.search("(No such file|not found)", str(e)):
@ -188,7 +192,7 @@ def build(row):
logging.info("Chunking({}) {}/{} done".format(timer() - st, row["location"], row["name"]))
except Exception as e:
callback(-1, "Internal server error while chunking: %s" %
str(e).replace("'", ""))
str(e).replace("'", ""))
logging.exception("Chunking {}/{} got exception".format(row["location"], row["name"]))
return

@ -226,7 +230,8 @@ def build(row):
STORAGE_IMPL.put(row["kb_id"], d["id"], output_buffer.getvalue())
el += timer() - st
except Exception:
logging.exception("Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))
logging.exception(
"Saving image of chunk {}/{}/{} got exception".format(row["location"], row["name"], d["_id"]))

d["img_id"] = "{}-{}".format(row["kb_id"], d["id"])
del d["image"]
@ -241,7 +246,7 @@ def build(row):
d["important_kwd"] = keyword_extraction(chat_mdl, d["content_with_weight"],
row["parser_config"]["auto_keywords"]).split(",")
d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"]))
callback(msg="Keywords generation completed in {:.2f}s".format(timer()-st))
callback(msg="Keywords generation completed in {:.2f}s".format(timer() - st))

if row["parser_config"].get("auto_questions", 0):
st = timer()
@ -255,14 +260,14 @@ def build(row):
d["content_ltks"] += " " + qst
if "content_sm_ltks" in d:
d["content_sm_ltks"] += " " + rag_tokenizer.fine_grained_tokenize(qst)
callback(msg="Question generation completed in {:.2f}s".format(timer()-st))
callback(msg="Question generation completed in {:.2f}s".format(timer() - st))

return docs


def init_kb(row, vector_size: int):
idxnm = search.index_name(row["tenant_id"])
return docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)
return settings.docStoreConn.createIdx(idxnm, row["kb_id"], vector_size)


def embedding(docs, mdl, parser_config=None, callback=None):
@ -313,7 +318,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
vector_size = len(vts[0])
vctr_nm = "q_%d_vec" % vector_size
chunks = []
for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm]):
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm]):
chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))

raptor = Raptor(
@ -384,7 +390,8 @@ def main():
# TODO: exception handler
## set_progress(r["did"], -1, "ERROR: ")
callback(
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks), timer() - st)
msg="Finished slicing files ({} chunks in {:.2f}s). Start to embedding the content.".format(len(cks),
timer() - st)
)
st = timer()
try:
@ -403,18 +410,18 @@ def main():
es_r = ""
es_bulk_size = 4
for b in range(0, len(cks), es_bulk_size):
es_r = docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
es_r = settings.docStoreConn.insert(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]), r["kb_id"])
if b % 128 == 0:
callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")

logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
if es_r:
callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!")
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
logging.error('Insert chunk error: ' + str(es_r))
else:
if TaskService.do_cancel(r["id"]):
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
settings.docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"])
continue
callback(msg="Indexing elapsed in {:.2f}s.".format(timer() - st))
callback(1., "Done!")
@ -435,7 +442,7 @@ def report_status():
if PENDING_TASKS > 0:
head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
if head_info is not None:
seconds = int(head_info[0].split("-")[0])/1000
seconds = int(head_info[0].split("-")[0]) / 1000
HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
HEAD_DETAIL = head_info[1]

@ -452,7 +459,7 @@ def report_status():
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")

expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60*30)
expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
if expired > 0:
REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
except Exception: