diff --git a/api/core/orchestrator_rule_parser.py b/api/core/orchestrator_rule_parser.py index 815c671ea0..8c6c3959d7 100644 --- a/api/core/orchestrator_rule_parser.py +++ b/api/core/orchestrator_rule_parser.py @@ -40,7 +40,7 @@ default_retrieval_model = { 'reranking_model_name': '' }, 'top_k': 2, - 'score_threshold_enable': False + 'score_threshold_enabled': False } class OrchestratorRuleParser: @@ -220,8 +220,8 @@ class OrchestratorRuleParser: # top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens) score_threshold = None - score_threshold_enable = retrieval_model_config.get("score_threshold_enable") - if score_threshold_enable: + score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") + if score_threshold_enabled: score_threshold = retrieval_model_config.get("score_threshold") tool = DatasetRetrieverTool.from_dataset( @@ -239,7 +239,7 @@ class OrchestratorRuleParser: dataset_ids=dataset_ids, tenant_id=kwargs['tenant_id'], top_k=dataset_configs.get('top_k', 2), - score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enable', False) else None, + score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enabled', False) else None, callbacks=[DatasetToolCallbackHandler(conversation_message_task)], conversation_message_task=conversation_message_task, return_resource=return_resource, diff --git a/api/core/tool/dataset_multi_retriever_tool.py b/api/core/tool/dataset_multi_retriever_tool.py index 5cf120b63b..42dab1ca98 100644 --- a/api/core/tool/dataset_multi_retriever_tool.py +++ b/api/core/tool/dataset_multi_retriever_tool.py @@ -24,7 +24,7 @@ default_retrieval_model = { 'reranking_model_name': '' }, 'top_k': 2, - 'score_threshold_enable': False + 'score_threshold_enabled': False } @@ -216,7 +216,7 @@ class DatasetMultiRetrieverTool(BaseTool): 'embeddings': embeddings, 'score_threshold': retrieval_model[ 'score_threshold'] if retrieval_model[ - 'score_threshold_enable'] else None, + 'score_threshold_enabled'] else None, 'top_k': self.top_k, 'reranking_model': retrieval_model[ 'reranking_model'] if retrieval_model[ diff --git a/api/core/tool/dataset_retriever_tool.py b/api/core/tool/dataset_retriever_tool.py index 822a6562be..f76a8904b9 100644 --- a/api/core/tool/dataset_retriever_tool.py +++ b/api/core/tool/dataset_retriever_tool.py @@ -25,7 +25,7 @@ default_retrieval_model = { 'reranking_model_name': '' }, 'top_k': 2, - 'score_threshold_enable': False + 'score_threshold_enabled': False } @@ -110,7 +110,7 @@ class DatasetRetrieverTool(BaseTool): 'query': query, 'top_k': self.top_k, 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ - 'score_threshold_enable'] else None, + 'score_threshold_enabled'] else None, 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ 'reranking_enable'] else None, 'all_documents': documents, @@ -129,7 +129,7 @@ class DatasetRetrieverTool(BaseTool): 'search_method': retrieval_model['search_method'], 'embeddings': embeddings, 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ - 'score_threshold_enable'] else None, + 'score_threshold_enabled'] else None, 'top_k': self.top_k, 'reranking_model': retrieval_model['reranking_model'] if retrieval_model[ 'reranking_enable'] else None, @@ -148,7 +148,7 @@ class DatasetRetrieverTool(BaseTool): model_name=retrieval_model['reranking_model']['reranking_model_name'] ) documents = hybrid_rerank.rerank(query, documents, - retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, + retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, self.top_k) else: documents = [] diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index d7be65be01..1382871ae7 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -22,7 +22,7 @@ dataset_retrieval_model_fields = { 'reranking_enable': fields.Boolean, 'reranking_model': fields.Nested(reranking_model_fields), 'top_k': fields.Integer, - 'score_threshold_enable': fields.Boolean, + 'score_threshold_enabled': fields.Boolean, 'score_threshold': fields.Float } diff --git a/api/models/dataset.py b/api/models/dataset.py index 5fbf035f84..a03a5cd38d 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -104,7 +104,7 @@ class Dataset(db.Model): 'reranking_model_name': '' }, 'top_k': 2, - 'score_threshold_enable': False + 'score_threshold_enabled': False } return self.retrieval_model if self.retrieval_model else default_retrieval_model diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index defe539ae9..4093a8f72f 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -485,7 +485,7 @@ class DocumentService: 'reranking_model_name': '' }, 'top_k': 2, - 'score_threshold_enable': False + 'score_threshold_enabled': False } dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model @@ -769,7 +769,7 @@ class DocumentService: 'reranking_model_name': '' }, 'top_k': 2, - 'score_threshold_enable': False + 'score_threshold_enabled': False } retrieval_model = default_retrieval_model # save dataset diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 831a37d670..9de72ad31f 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -25,7 +25,7 @@ default_retrieval_model = { 'reranking_model_name': '' }, 'top_k': 2, - 'score_threshold_enable': False + 'score_threshold_enabled': False } class HitTestingService: @@ -64,7 +64,7 @@ class HitTestingService: 'dataset_id': str(dataset.id), 'query': query, 'top_k': retrieval_model['top_k'], - 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, + 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, 'all_documents': all_documents, 'search_method': retrieval_model['search_method'], @@ -81,7 +81,7 @@ class HitTestingService: 'query': query, 'search_method': retrieval_model['search_method'], 'embeddings': embeddings, - 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, + 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, 'top_k': retrieval_model['top_k'], 'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None, 'all_documents': all_documents @@ -99,7 +99,7 @@ class HitTestingService: model_name=retrieval_model['reranking_model']['reranking_model_name'] ) all_documents = hybrid_rerank.rerank(query, all_documents, - retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, + retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None, retrieval_model['top_k']) end = time.perf_counter() diff --git a/api/services/retrieval_service.py b/api/services/retrieval_service.py index f12533f2b0..fbd96dfe20 100644 --- a/api/services/retrieval_service.py +++ b/api/services/retrieval_service.py @@ -15,7 +15,7 @@ default_retrieval_model = { 'reranking_model_name': '' }, 'top_k': 2, - 'score_threshold_enable': False + 'score_threshold_enabled': False }