From bfa652f2d041ab1bbf4da25cc90f2cd11c46becf Mon Sep 17 00:00:00 2001 From: Will Date: Wed, 7 May 2025 14:52:09 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20metadata=20filtering=20condition=20varia?= =?UTF-8?q?ble=20unassigned;=20fix=20External=20K=E2=80=A6=20(#19208)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/controllers/console/datasets/external.py | 2 ++ api/core/agent/base_agent_runner.py | 2 ++ api/core/agent/cot_agent_runner.py | 7 ---- api/core/agent/fc_agent_runner.py | 7 ---- api/core/rag/datasource/retrieval_service.py | 18 ++++++++-- api/core/rag/retrieval/dataset_retrieval.py | 36 +++++++++++++------ .../dataset_retriever_tool.py | 36 +++++++++++++++---- .../tools/utils/dataset_retriever_tool.py | 4 +++ .../knowledge_retrieval_node.py | 27 +++++++++----- api/services/hit_testing_service.py | 2 ++ 10 files changed, 101 insertions(+), 40 deletions(-) diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 30b7f63aab..cf9081e154 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -209,6 +209,7 @@ class ExternalKnowledgeHitTestingApi(Resource): parser = reqparse.RequestParser() parser.add_argument("query", type=str, location="json") parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") + parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json") args = parser.parse_args() HitTestingService.hit_testing_args_check(args) @@ -219,6 +220,7 @@ class ExternalKnowledgeHitTestingApi(Resource): query=args["query"], account=current_user, external_retrieval_model=args["external_retrieval_model"], + metadata_filtering_conditions=args["metadata_filtering_conditions"], ) return response diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index e648613605..6998e4d29a 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -91,6 +91,8 @@ class BaseAgentRunner(AppRunner): return_resource=app_config.additional_features.show_retrieve_source, invoke_from=application_generate_entity.invoke_from, hit_callback=hit_callback, + user_id=user_id, + inputs=cast(dict, application_generate_entity.inputs), ) # get how many agent thoughts have been created self.agent_thought_count = ( diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index de3b7e1ad7..feb8abf6ef 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -69,13 +69,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): tool_instances, prompt_messages_tools = self._init_prompt_tools() self._prompt_messages_tools = prompt_messages_tools - # fix metadata filter not work - if app_config.dataset is not None: - metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions - for key, dataset_retriever_tool in tool_instances.items(): - if hasattr(dataset_retriever_tool, "retrieval_tool"): - dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions - function_call_state = True llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} final_answer = "" diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 874bd6b93b..a1110e7709 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -45,13 +45,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): # convert tools into ModelRuntime Tool format tool_instances, prompt_messages_tools = self._init_prompt_tools() - # fix metadata filter not work - if app_config.dataset is not None: - metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions - for key, dataset_retriever_tool in tool_instances.items(): - if hasattr(dataset_retriever_tool, "retrieval_tool"): - dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions - assert app_config.agent iteration_step = 1 diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 46a5330bdb..01f74b4a22 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -10,6 +10,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.embedding.retrieval import RetrievalSegments +from core.rag.entities.metadata_entities import MetadataCondition from core.rag.index_processor.constant.index_type import IndexType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode @@ -119,12 +120,25 @@ class RetrievalService: return all_documents @classmethod - def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None): + def external_retrieve( + cls, + dataset_id: str, + query: str, + external_retrieval_model: Optional[dict] = None, + metadata_filtering_conditions: Optional[dict] = None, + ): dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: return [] + metadata_condition = ( + MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None + ) all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( - dataset.tenant_id, dataset_id, query, external_retrieval_model or {} + dataset.tenant_id, + dataset_id, + query, + external_retrieval_model or {}, + metadata_condition=metadata_condition, ) return all_documents diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index b1565f10f2..9216b31b8e 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -149,7 +149,7 @@ class DatasetRetrieval: else: inputs = {} available_datasets_ids = [dataset.id for dataset in available_datasets] - metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition( + metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition( available_datasets_ids, query, tenant_id, @@ -649,6 +649,8 @@ class DatasetRetrieval: return_resource: bool, invoke_from: InvokeFrom, hit_callback: DatasetIndexToolCallbackHandler, + user_id: str, + inputs: dict, ) -> Optional[list[DatasetRetrieverBaseTool]]: """ A dataset tool is a tool that can be used to retrieve information from a dataset @@ -706,6 +708,9 @@ class DatasetRetrieval: hit_callbacks=[hit_callback], return_resource=return_resource, retriever_from=invoke_from.to_source(), + retrieve_config=retrieve_config, + user_id=user_id, + inputs=inputs, ) tools.append(tool) @@ -826,7 +831,7 @@ class DatasetRetrieval: ) return filter_documents[:top_k] if top_k else filter_documents - def _get_metadata_filter_condition( + def get_metadata_filter_condition( self, dataset_ids: list, query: str, @@ -876,20 +881,31 @@ class DatasetRetrieval: ) elif metadata_filtering_mode == "manual": if metadata_filtering_conditions: - metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump()) + conditions = [] for sequence, condition in enumerate(metadata_filtering_conditions.conditions): # type: ignore metadata_name = condition.name expected_value = condition.value - if expected_value is not None or condition.comparison_operator in ("empty", "not empty"): + if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): if isinstance(expected_value, str): expected_value = self._replace_metadata_filter_value(expected_value, inputs) - filters = self._process_metadata_filter_func( - sequence, - condition.comparison_operator, - metadata_name, - expected_value, - filters, + conditions.append( + Condition( + name=metadata_name, + comparison_operator=condition.comparison_operator, + value=expected_value, ) + ) + filters = self._process_metadata_filter_func( + sequence, + condition.comparison_operator, + metadata_name, + expected_value, + filters, + ) + metadata_condition = MetadataCondition( + logical_operator=metadata_filtering_conditions.logical_operator, + conditions=conditions, + ) else: raise ValueError("Invalid metadata filtering mode") if filters: diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index dcd3d080f3..ed97b44f95 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,11 +1,12 @@ -from typing import Any +from typing import Any, Optional, cast from pydantic import BaseModel, Field +from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.context_entities import DocumentContext -from core.rag.entities.metadata_entities import MetadataCondition from core.rag.models.document import Document as RetrievalDocument +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db @@ -34,7 +35,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): args_schema: type[BaseModel] = DatasetRetrieverToolInput description: str = "use this to retrieve a dataset. " dataset_id: str - metadata_filtering_conditions: MetadataCondition + user_id: Optional[str] = None + retrieve_config: DatasetRetrieveConfigEntity + inputs: dict @classmethod def from_dataset(cls, dataset: Dataset, **kwargs): @@ -48,7 +51,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): tenant_id=dataset.tenant_id, dataset_id=dataset.id, description=description, - metadata_filtering_conditions=MetadataCondition(), **kwargs, ) @@ -61,6 +63,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): return "" for hit_callback in self.hit_callbacks: hit_callback.on_query(query, dataset.id) + dataset_retrieval = DatasetRetrieval() + metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition( + [dataset.id], + query, + self.tenant_id, + self.user_id or "unknown", + cast(str, self.retrieve_config.metadata_filtering_mode), + cast(ModelConfig, self.retrieve_config.metadata_model_config), + self.retrieve_config.metadata_filtering_conditions, + self.inputs, + ) + if metadata_filter_document_ids: + document_ids_filter = metadata_filter_document_ids.get(dataset.id, []) + else: + document_ids_filter = None if dataset.provider == "external": results = [] external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( @@ -68,7 +85,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): dataset_id=dataset.id, query=query, external_retrieval_parameters=dataset.retrieval_model, - metadata_condition=self.metadata_filtering_conditions, + metadata_condition=metadata_condition, ) for external_document in external_documents: document = RetrievalDocument( @@ -104,12 +121,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): return str("\n".join([item.page_content for item in results])) else: + if metadata_condition and not document_ids_filter: + return "" # get retrieval model , if the model is not setting , using default retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query documents = RetrievalService.retrieve( - retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + retrieval_method="keyword_search", + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + document_ids_filter=document_ids_filter, ) return str("\n".join([document.page_content for document in documents])) else: @@ -128,6 +151,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): else None, reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", weights=retrieval_model.get("weights"), + document_ids_filter=document_ids_filter, ) else: documents = [] diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py index b73dec4ebc..ec0575f6c3 100644 --- a/api/core/tools/utils/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever_tool.py @@ -34,6 +34,8 @@ class DatasetRetrieverTool(Tool): return_resource: bool, invoke_from: InvokeFrom, hit_callback: DatasetIndexToolCallbackHandler, + user_id: str, + inputs: dict, ) -> list["DatasetRetrieverTool"]: """ get dataset tool @@ -57,6 +59,8 @@ class DatasetRetrieverTool(Tool): return_resource=return_resource, invoke_from=invoke_from, hit_callback=hit_callback, + user_id=user_id, + inputs=inputs, ) if retrieval_tools is None or len(retrieval_tools) == 0: return [] diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 00dac1b7d7..5c4cac9719 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -356,12 +356,12 @@ class KnowledgeRetrievalNode(LLMNode): ) elif node_data.metadata_filtering_mode == "manual": if node_data.metadata_filtering_conditions: - metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump()) + conditions = [] if node_data.metadata_filtering_conditions: for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore metadata_name = condition.name expected_value = condition.value - if expected_value is not None or condition.comparison_operator in ("empty", "not empty"): + if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): if isinstance(expected_value, str): expected_value = self.graph_runtime_state.variable_pool.convert_template( expected_value @@ -372,13 +372,24 @@ class KnowledgeRetrievalNode(LLMNode): expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore else: raise ValueError("Invalid expected metadata value type") - filters = self._process_metadata_filter_func( - sequence, - condition.comparison_operator, - metadata_name, - expected_value, - filters, + conditions.append( + Condition( + name=metadata_name, + comparison_operator=condition.comparison_operator, + value=expected_value, ) + ) + filters = self._process_metadata_filter_func( + sequence, + condition.comparison_operator, + metadata_name, + expected_value, + filters, + ) + metadata_condition = MetadataCondition( + logical_operator=node_data.metadata_filtering_conditions.logical_operator, + conditions=conditions, + ) else: raise ValueError("Invalid metadata filtering mode") if filters: diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 0b98065f5d..56e06cc33e 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -69,6 +69,7 @@ class HitTestingService: query: str, account: Account, external_retrieval_model: dict, + metadata_filtering_conditions: dict, ) -> dict: if dataset.provider != "external": return { @@ -82,6 +83,7 @@ class HitTestingService: dataset_id=dataset.id, query=cls.escape_query_for_search(query), external_retrieval_model=external_retrieval_model, + metadata_filtering_conditions=metadata_filtering_conditions, ) end = time.perf_counter()