Fix retrieval API error and add multi-kb search (#1928)

### What problem does this PR solve?
Type of change
 Bug Fix (Import necessary class for retrieval API )
 New Feature (Add multi-KB search to retrieval API)
This commit is contained in:
wwwlll 2024-08-13 15:30:51 +08:00 committed by GitHub
parent 7a08e91909
commit 06700850df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -18,9 +18,10 @@ import os
import re import re
from datetime import datetime, timedelta from datetime import datetime, timedelta
from flask import request, Response from flask import request, Response
from api.db.services.llm_service import TenantLLMService
from flask_login import login_required, current_user from flask_login import login_required, current_user
from api.db import FileType, ParserType, FileSource from api.db import FileType, LLMType, ParserType, FileSource
from api.db.db_models import APIToken, API4Conversation, Task, File from api.db.db_models import APIToken, API4Conversation, Task, File
from api.db.services import duplicate_name from api.db.services import duplicate_name
from api.db.services.api_service import APITokenService, API4ConversationService from api.db.services.api_service import APITokenService, API4ConversationService
@ -37,6 +38,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
from itsdangerous import URLSafeTimedSerializer from itsdangerous import URLSafeTimedSerializer
from api.utils.file_utils import filename_type, thumbnail from api.utils.file_utils import filename_type, thumbnail
from rag.nlp import keyword_extraction
from rag.utils.minio_conn import MINIO from rag.utils.minio_conn import MINIO
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
@ -694,7 +696,7 @@ def retrieval():
data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
req = request.json req = request.json
kb_id = req.get("kb_id") kb_ids = req.get("kb_id",[])
doc_ids = req.get("doc_ids", []) doc_ids = req.get("doc_ids", [])
question = req.get("question") question = req.get("question")
page = int(req.get("page", 1)) page = int(req.get("page", 1))
@ -704,32 +706,30 @@ def retrieval():
top = int(req.get("top_k", 1024)) top = int(req.get("top_k", 1024))
try: try:
e, kb = KnowledgebaseService.get_by_id(kb_id) kbs = KnowledgebaseService.get_by_ids(kb_ids)
if not e: embd_nms = list(set([kb.embd_id for kb in kbs]))
return get_data_error_result(retmsg="Knowledgebase not found!") if len(embd_nms) != 1:
return get_json_result(
data=False, retmsg='Knowledge bases use different embedding models or does not exist."', retcode=RetCode.AUTHENTICATION_ERROR)
embd_mdl = TenantLLMService.model_instance( embd_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id)
rerank_mdl = None rerank_mdl = None
if req.get("rerank_id"): if req.get("rerank_id"):
rerank_mdl = TenantLLMService.model_instance( rerank_mdl = TenantLLMService.model_instance(
kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
if req.get("keyword", False): if req.get("keyword", False):
chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT) chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT)
question += keyword_extraction(chat_mdl, question) question += keyword_extraction(chat_mdl, question)
ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold, vector_similarity_weight, top,
similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl)
doc_ids, rerank_mdl=rerank_mdl)
for c in ranks["chunks"]: for c in ranks["chunks"]:
if "vector" in c: if "vector" in c:
del c["vector"] del c["vector"]
return get_json_result(data=ranks) return get_json_result(data=ranks)
except Exception as e: except Exception as e:
if str(e).find("not_found") > 0: if str(e).find("not_found") > 0:
return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!', return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
retcode=RetCode.DATA_ERROR) retcode=RetCode.DATA_ERROR)
return server_error_response(e) return server_error_response(e)