From 7e611ffbf3e018b98803123319d2e8bae81eaccd Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Sat, 14 Sep 2024 21:48:44 +0800 Subject: [PATCH] multi-retrival use dataset's top-k (#8416) --- api/core/rag/retrieval/dataset_retrieval.py | 2 +- .../tool/dataset_retriever/dataset_multi_retriever_tool.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 124c58f0fe..286ecd4c03 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -426,7 +426,7 @@ class DatasetRetrieval: retrieval_method=retrieval_model["search_method"], dataset_id=dataset.id, query=query, - top_k=top_k, + top_k=retrieval_model.get("top_k") or 2, score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] else 0.0, diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index 6073b8e92e..ab7b40a253 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -165,7 +165,10 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): if dataset.indexing_technique == "economy": # use keyword table query documents = RetrievalService.retrieve( - retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + retrieval_method="keyword_search", + dataset_id=dataset.id, + query=query, + top_k=retrieval_model.get("top_k") or 2, ) if documents: all_documents.extend(documents) @@ -176,7 +179,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): retrieval_method=retrieval_model["search_method"], dataset_id=dataset.id, query=query, - top_k=self.top_k, + top_k=retrieval_model.get("top_k") or 2, score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] else 0.0,