diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index abbf4a35a4..3932e90042 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -28,7 +28,7 @@ class RetrievalService: @classmethod def retrieve(cls, retrival_method: str, dataset_id: str, query: str, top_k: int, score_threshold: Optional[float] = .0, - reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = None, + reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = 'reranking_model', weights: Optional[dict] = None): dataset = db.session.query(Dataset).filter( Dataset.id == dataset_id @@ -36,10 +36,6 @@ class RetrievalService: if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: return [] all_documents = [] - keyword_search_documents = [] - embedding_search_documents = [] - full_text_search_documents = [] - hybrid_search_documents = [] threads = [] exceptions = [] # retrieval_model source with keyword diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 2c5d920a9a..dc1f1ada11 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -278,6 +278,7 @@ class DatasetRetrieval: query=query, top_k=top_k, score_threshold=score_threshold, reranking_model=reranking_model, + reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'), weights=retrieval_model_config.get('weights', None), ) self._on_query(query, [dataset_id], app_id, user_from, user_id) @@ -431,10 +432,12 @@ class DatasetRetrieval: dataset_id=dataset.id, query=query, top_k=top_k, - score_threshold=retrieval_model['score_threshold'] + score_threshold=retrieval_model.get('score_threshold', .0) if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model['reranking_model'] + reranking_model=retrieval_model.get('reranking_model', None) if retrieval_model['reranking_enable'] else None, + reranking_mode=retrieval_model.get('reranking_mode') + if retrieval_model.get('reranking_mode') else 'reranking_model', weights=retrieval_model.get('weights', None), ) diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index 1a0933af16..7cb7c033bb 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -177,10 +177,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): dataset_id=dataset.id, query=query, top_k=self.top_k, - score_threshold=retrieval_model['score_threshold'] + score_threshold=retrieval_model.get('score_threshold', .0) if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model['reranking_model'] + reranking_model=retrieval_model.get('reranking_model', None) if retrieval_model['reranking_enable'] else None, + reranking_mode=retrieval_model.get('reranking_mode') + if retrieval_model.get('reranking_mode') else 'reranking_model', weights=retrieval_model.get('weights', None), ) diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 397ff7966e..de8ff7ad38 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -14,6 +14,7 @@ default_retrieval_model = { 'reranking_provider_name': '', 'reranking_model_name': '' }, + 'reranking_mode': 'reranking_model', 'top_k': 2, 'score_threshold_enabled': False } @@ -71,14 +72,15 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): else: if self.top_k > 0: # retrieval source - documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'), dataset_id=dataset.id, query=query, top_k=self.top_k, - score_threshold=retrieval_model['score_threshold'] + score_threshold=retrieval_model.get('score_threshold', .0) if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model['reranking_model'] - if retrieval_model['reranking_enable'] else None, + reranking_model=retrieval_model.get('reranking_model', None), + reranking_mode=retrieval_model.get('reranking_mode') + if retrieval_model.get('reranking_mode') else 'reranking_model', weights=retrieval_model.get('weights', None), ) else: diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 69274dff09..0e072a3e21 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -42,11 +42,11 @@ class HitTestingService: dataset_id=dataset.id, query=cls.escape_query_for_search(query), top_k=retrieval_model.get('top_k', 2), - score_threshold=retrieval_model['score_threshold'] + score_threshold=retrieval_model.get('score_threshold', .0) if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model['reranking_model'] - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode', None), + reranking_model=retrieval_model.get('reranking_model', None), + reranking_mode=retrieval_model.get('reranking_mode') + if retrieval_model.get('reranking_mode') else 'reranking_model', weights=retrieval_model.get('weights', None), )