From 44801df8f82f89777bd9a4ad18ad23ee8062fff7 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Fri, 2 Aug 2024 12:10:51 +0800 Subject: [PATCH] fix score threshold limit be None (#6900) --- .../app/app_config/easy_ui_based_app/dataset/manager.py | 3 ++- api/core/app/app_config/entities.py | 3 ++- api/core/rag/retrieval/dataset_retrieval.py | 3 ++- .../unit_tests/services/workflow/test_workflow_converter.py | 6 ++++-- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 13da5514d1..ec17db5f06 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -91,7 +91,8 @@ class DatasetConfigManager: top_k=dataset_configs.get('top_k', 4), score_threshold=dataset_configs.get('score_threshold'), reranking_model=dataset_configs.get('reranking_model'), - weights=dataset_configs.get('weights') + weights=dataset_configs.get('weights'), + reranking_enabled=dataset_configs.get('reranking_enabled', True), ) ) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 9133a35c08..a490ddd670 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -158,10 +158,11 @@ class DatasetRetrieveConfigEntity(BaseModel): retrieve_strategy: RetrieveStrategy top_k: Optional[int] = None - score_threshold: Optional[float] = None + score_threshold: Optional[float] = .0 rerank_mode: Optional[str] = 'reranking_model' reranking_model: Optional[dict] = None weights: Optional[dict] = None + reranking_enabled: Optional[bool] = True diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index d51ea2942a..a69fcffbb4 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -138,6 +138,7 @@ class DatasetRetrieval: retrieve_config.rerank_mode, retrieve_config.reranking_model, retrieve_config.weights, + retrieve_config.reranking_enabled, message_id, ) @@ -606,7 +607,7 @@ class DatasetRetrieval: top_k: int, score_threshold: float) -> list[Document]: filter_documents = [] for document in all_documents: - if document.metadata['score'] >= score_threshold: + if score_threshold and document.metadata['score'] >= score_threshold: filter_documents.append(document) if not filter_documents: return [] diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 29d55df8c3..f589cd2097 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -208,7 +208,8 @@ def test__convert_to_knowledge_retrieval_node_for_chatbot(): reranking_model={ 'reranking_provider_name': 'cohere', 'reranking_model_name': 'rerank-english-v2.0' - } + }, + reranking_enabled=True ) ) @@ -251,7 +252,8 @@ def test__convert_to_knowledge_retrieval_node_for_workflow_app(): reranking_model={ 'reranking_provider_name': 'cohere', 'reranking_model_name': 'rerank-english-v2.0' - } + }, + reranking_enabled=True ) )