diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 13da5514d1..ec17db5f06 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -91,7 +91,8 @@ class DatasetConfigManager: top_k=dataset_configs.get('top_k', 4), score_threshold=dataset_configs.get('score_threshold'), reranking_model=dataset_configs.get('reranking_model'), - weights=dataset_configs.get('weights') + weights=dataset_configs.get('weights'), + reranking_enabled=dataset_configs.get('reranking_enabled', True), ) ) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 9133a35c08..a490ddd670 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -158,10 +158,11 @@ class DatasetRetrieveConfigEntity(BaseModel): retrieve_strategy: RetrieveStrategy top_k: Optional[int] = None - score_threshold: Optional[float] = None + score_threshold: Optional[float] = .0 rerank_mode: Optional[str] = 'reranking_model' reranking_model: Optional[dict] = None weights: Optional[dict] = None + reranking_enabled: Optional[bool] = True diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index d51ea2942a..a69fcffbb4 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -138,6 +138,7 @@ class DatasetRetrieval: retrieve_config.rerank_mode, retrieve_config.reranking_model, retrieve_config.weights, + retrieve_config.reranking_enabled, message_id, ) @@ -606,7 +607,7 @@ class DatasetRetrieval: top_k: int, score_threshold: float) -> list[Document]: filter_documents = [] for document in all_documents: - if document.metadata['score'] >= score_threshold: + if score_threshold and document.metadata['score'] >= score_threshold: filter_documents.append(document) if not filter_documents: return [] diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 29d55df8c3..f589cd2097 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -208,7 +208,8 @@ def test__convert_to_knowledge_retrieval_node_for_chatbot(): reranking_model={ 'reranking_provider_name': 'cohere', 'reranking_model_name': 'rerank-english-v2.0' - } + }, + reranking_enabled=True ) ) @@ -251,7 +252,8 @@ def test__convert_to_knowledge_retrieval_node_for_workflow_app(): reranking_model={ 'reranking_provider_name': 'cohere', 'reranking_model_name': 'rerank-english-v2.0' - } + }, + reranking_enabled=True ) )