fix score threshold limit be None (#6900)

This commit is contained in:
Jyong 2024-08-02 12:10:51 +08:00 committed by GitHub
parent 56af1a0adf
commit 44801df8f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 10 additions and 5 deletions

View File

@ -91,7 +91,8 @@ class DatasetConfigManager:
top_k=dataset_configs.get('top_k', 4), top_k=dataset_configs.get('top_k', 4),
score_threshold=dataset_configs.get('score_threshold'), score_threshold=dataset_configs.get('score_threshold'),
reranking_model=dataset_configs.get('reranking_model'), reranking_model=dataset_configs.get('reranking_model'),
weights=dataset_configs.get('weights') weights=dataset_configs.get('weights'),
reranking_enabled=dataset_configs.get('reranking_enabled', True),
) )
) )

View File

@ -158,10 +158,11 @@ class DatasetRetrieveConfigEntity(BaseModel):
retrieve_strategy: RetrieveStrategy retrieve_strategy: RetrieveStrategy
top_k: Optional[int] = None top_k: Optional[int] = None
score_threshold: Optional[float] = None score_threshold: Optional[float] = .0
rerank_mode: Optional[str] = 'reranking_model' rerank_mode: Optional[str] = 'reranking_model'
reranking_model: Optional[dict] = None reranking_model: Optional[dict] = None
weights: Optional[dict] = None weights: Optional[dict] = None
reranking_enabled: Optional[bool] = True

View File

@ -138,6 +138,7 @@ class DatasetRetrieval:
retrieve_config.rerank_mode, retrieve_config.rerank_mode,
retrieve_config.reranking_model, retrieve_config.reranking_model,
retrieve_config.weights, retrieve_config.weights,
retrieve_config.reranking_enabled,
message_id, message_id,
) )
@ -606,7 +607,7 @@ class DatasetRetrieval:
top_k: int, score_threshold: float) -> list[Document]: top_k: int, score_threshold: float) -> list[Document]:
filter_documents = [] filter_documents = []
for document in all_documents: for document in all_documents:
if document.metadata['score'] >= score_threshold: if score_threshold and document.metadata['score'] >= score_threshold:
filter_documents.append(document) filter_documents.append(document)
if not filter_documents: if not filter_documents:
return [] return []

View File

@ -208,7 +208,8 @@ def test__convert_to_knowledge_retrieval_node_for_chatbot():
reranking_model={ reranking_model={
'reranking_provider_name': 'cohere', 'reranking_provider_name': 'cohere',
'reranking_model_name': 'rerank-english-v2.0' 'reranking_model_name': 'rerank-english-v2.0'
} },
reranking_enabled=True
) )
) )
@ -251,7 +252,8 @@ def test__convert_to_knowledge_retrieval_node_for_workflow_app():
reranking_model={ reranking_model={
'reranking_provider_name': 'cohere', 'reranking_provider_name': 'cohere',
'reranking_model_name': 'rerank-english-v2.0' 'reranking_model_name': 'rerank-english-v2.0'
} },
reranking_enabled=True
) )
) )