API: retrieval api (#1763)
### What problem does this PR solve?

Add a retrieval API on a specific knowledge base.

https://github.com/infiniflow/ragflow/issues/1102

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Parent: da11a20c92
Commit: b9a50ef4b8
```diff
@@ -20,7 +20,7 @@ from datetime import datetime, timedelta
 from flask import request, Response
 from flask_login import login_required, current_user
 
-from api.db import FileType, ParserType, FileSource
+from api.db import FileType, ParserType, FileSource, LLMType
 from api.db.db_models import APIToken, API4Conversation, Task, File
 from api.db.services import duplicate_name
 from api.db.services.api_service import APITokenService, API4ConversationService
@@ -29,6 +29,7 @@ from api.db.services.document_service import DocumentService
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
 from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.llm_service import TenantLLMService
 from api.db.services.task_service import queue_tasks, TaskService
 from api.db.services.user_service import UserTenantService
 from api.settings import RetCode, retrievaler
@@ -37,6 +38,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge
 from itsdangerous import URLSafeTimedSerializer
 
 from api.utils.file_utils import filename_type, thumbnail
+from rag.nlp import keyword_extraction
 from rag.utils.minio_conn import MINIO
 
 
@@ -587,3 +589,55 @@ def completion_faq():
 
     except Exception as e:
         return server_error_response(e)
+
+
+@manager.route('/retrieval', methods=['POST'])
+@validate_request("kb_id", "question")
+def retrieval():
+    token = request.headers.get('Authorization').split()[1]
+    objs = APIToken.query(token=token)
+    if not objs:
+        return get_json_result(
+            data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR)
+
+    req = request.json
+    kb_id = req.get("kb_id")
+    doc_ids = req.get("doc_ids", [])
+    question = req.get("question")
+    page = int(req.get("page", 1))
+    size = int(req.get("size", 30))
+    similarity_threshold = float(req.get("similarity_threshold", 0.2))
+    vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
+    top = int(req.get("top_k", 1024))
+
+    try:
+        e, kb = KnowledgebaseService.get_by_id(kb_id)
+        if not e:
+            return get_data_error_result(retmsg="Knowledgebase not found!")
+
+        embd_mdl = TenantLLMService.model_instance(
+            kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
+
+        rerank_mdl = None
+        if req.get("rerank_id"):
+            rerank_mdl = TenantLLMService.model_instance(
+                kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"])
+
+        if req.get("keyword", False):
+            chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT)
+            question += keyword_extraction(chat_mdl, question)
+
+        ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size,
+                                      similarity_threshold, vector_similarity_weight, top,
+                                      doc_ids, rerank_mdl=rerank_mdl)
+        for c in ranks["chunks"]:
+            if "vector" in c:
+                del c["vector"]
+
+        return get_json_result(data=ranks)
+    except Exception as e:
+        if str(e).find("not_found") > 0:
+            return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!',
+                                   retcode=RetCode.DATA_ERROR)
+        return server_error_response(e)
```
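For reference, here is a minimal client sketch against the new endpoint. The field names and defaults mirror what `retrieval()` reads from `request.json` in the diff above; the base URL, the `/v1/api` mount prefix, and the API token are placeholder assumptions, not values taken from the commit.

```python
# Minimal usage sketch for the new /retrieval endpoint (not part of the commit).
# Assumptions: the server is reachable at BASE_URL and the endpoint is mounted
# under a "/v1/api" prefix -- both are placeholders, as is the API token.
import requests

BASE_URL = "http://localhost:9380/v1/api"   # hypothetical deployment URL
API_TOKEN = "ragflow-xxxxxxxx"              # placeholder API token

payload = {
    "kb_id": "<knowledge-base-id>",          # required: knowledge base to search
    "question": "How is the embedding model configured?",  # required
    "doc_ids": [],                           # optional: restrict retrieval to these documents
    "page": 1,                               # optional, default 1
    "size": 30,                              # optional, default 30 chunks per page
    "similarity_threshold": 0.2,             # optional, default 0.2
    "vector_similarity_weight": 0.3,         # optional, default 0.3
    "top_k": 1024,                           # optional, default 1024
    "keyword": False,                        # optional: expand the question with LLM-extracted keywords
    # "rerank_id": "<rerank-model-name>",    # optional: enables a rerank model when set
}

# The handler takes the token from the second word of the Authorization header,
# so any "<scheme> <token>" form works; "Bearer" is used here by convention.
resp = requests.post(f"{BASE_URL}/retrieval",
                     headers={"Authorization": f"Bearer {API_TOKEN}"},
                     json=payload,
                     timeout=60)
resp.raise_for_status()
body = resp.json()
for chunk in body.get("data", {}).get("chunks", []):
    print(chunk)                             # embedding vectors are stripped server-side
```

Only `kb_id` and `question` are required (enforced by `@validate_request`); every other field falls back to the defaults shown in the handler.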