fix: metadata filtering condition variable unassigned; fix External K… (#19208)

Will 2025-05-07 14:52:09 +08:00 committed by GitHub
parent d1c08a810b
commit bfa652f2d0
10 changed files with 101 additions and 40 deletions

View File

@@ -209,6 +209,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument("query", type=str, location="json")
         parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
+        parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
         args = parser.parse_args()
         HitTestingService.hit_testing_args_check(args)
@@ -219,6 +220,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
             query=args["query"],
             account=current_user,
             external_retrieval_model=args["external_retrieval_model"],
+            metadata_filtering_conditions=args["metadata_filtering_conditions"],
         )
         return response
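The endpoint now accepts an optional metadata_filtering_conditions object next to the query. A rough sketch of a request body follows; the nested field names (logical_operator, conditions, name, comparison_operator, value) are inferred from how MetadataCondition and Condition are constructed later in this diff, and the concrete operator strings are illustrative only.

# Hypothetical request body for external-knowledge hit testing; every value is a
# placeholder and the operator names are examples, not an exhaustive list.
payload = {
    "query": "refund policy for enterprise plans",
    "external_retrieval_model": {"top_k": 5, "score_threshold": 0.5},
    "metadata_filtering_conditions": {
        "logical_operator": "and",
        "conditions": [
            {"name": "language", "comparison_operator": "is", "value": "en"},
            {"name": "category", "comparison_operator": "contains", "value": "billing"},
        ],
    },
}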

View File

@@ -91,6 +91,8 @@ class BaseAgentRunner(AppRunner):
             return_resource=app_config.additional_features.show_retrieve_source,
             invoke_from=application_generate_entity.invoke_from,
             hit_callback=hit_callback,
+            user_id=user_id,
+            inputs=cast(dict, application_generate_entity.inputs),
         )
         # get how many agent thoughts have been created
         self.agent_thought_count = (

View File

@@ -69,13 +69,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         tool_instances, prompt_messages_tools = self._init_prompt_tools()
         self._prompt_messages_tools = prompt_messages_tools
-        # fix metadata filter not work
-        if app_config.dataset is not None:
-            metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions
-            for key, dataset_retriever_tool in tool_instances.items():
-                if hasattr(dataset_retriever_tool, "retrieval_tool"):
-                    dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions
         function_call_state = True
         llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
         final_answer = ""

View File

@@ -45,13 +45,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
         # convert tools into ModelRuntime Tool format
         tool_instances, prompt_messages_tools = self._init_prompt_tools()
-        # fix metadata filter not work
-        if app_config.dataset is not None:
-            metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions
-            for key, dataset_retriever_tool in tool_instances.items():
-                if hasattr(dataset_retriever_tool, "retrieval_tool"):
-                    dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions
         assert app_config.agent
         iteration_step = 1

View File

@@ -10,6 +10,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.embedding.retrieval import RetrievalSegments
+from core.rag.entities.metadata_entities import MetadataCondition
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.models.document import Document
 from core.rag.rerank.rerank_type import RerankMode
@@ -119,12 +120,25 @@ class RetrievalService:
         return all_documents

     @classmethod
-    def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
+    def external_retrieve(
+        cls,
+        dataset_id: str,
+        query: str,
+        external_retrieval_model: Optional[dict] = None,
+        metadata_filtering_conditions: Optional[dict] = None,
+    ):
         dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
         if not dataset:
             return []
+        metadata_condition = (
+            MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
+        )
         all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
-            dataset.tenant_id, dataset_id, query, external_retrieval_model or {}
+            dataset.tenant_id,
+            dataset_id,
+            query,
+            external_retrieval_model or {},
+            metadata_condition=metadata_condition,
         )
         return all_documents
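A minimal usage sketch of the widened signature, assuming the plain-dict shape shown earlier; when the argument is omitted, metadata_condition stays None and external retrieval behaves exactly as before.

# Sketch only: forwarding manual filter conditions through external_retrieve.
# The dataset id and condition values are placeholders.
documents = RetrievalService.external_retrieve(
    dataset_id="<external-dataset-id>",
    query="refund policy",
    external_retrieval_model={"top_k": 5},
    metadata_filtering_conditions={
        "logical_operator": "and",
        "conditions": [{"name": "language", "comparison_operator": "is", "value": "en"}],
    },
)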

View File

@@ -149,7 +149,7 @@ class DatasetRetrieval:
             else:
                 inputs = {}
             available_datasets_ids = [dataset.id for dataset in available_datasets]
-            metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
+            metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
                 available_datasets_ids,
                 query,
                 tenant_id,
@@ -649,6 +649,8 @@ class DatasetRetrieval:
         return_resource: bool,
         invoke_from: InvokeFrom,
         hit_callback: DatasetIndexToolCallbackHandler,
+        user_id: str,
+        inputs: dict,
     ) -> Optional[list[DatasetRetrieverBaseTool]]:
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset
@@ -706,6 +708,9 @@ class DatasetRetrieval:
                     hit_callbacks=[hit_callback],
                     return_resource=return_resource,
                     retriever_from=invoke_from.to_source(),
+                    retrieve_config=retrieve_config,
+                    user_id=user_id,
+                    inputs=inputs,
                 )
                 tools.append(tool)
@@ -826,7 +831,7 @@ class DatasetRetrieval:
         )
         return filter_documents[:top_k] if top_k else filter_documents

-    def _get_metadata_filter_condition(
+    def get_metadata_filter_condition(
         self,
         dataset_ids: list,
         query: str,
@@ -876,20 +881,31 @@ class DatasetRetrieval:
            )
         elif metadata_filtering_mode == "manual":
             if metadata_filtering_conditions:
-                metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
+                conditions = []
                 for sequence, condition in enumerate(metadata_filtering_conditions.conditions):  # type: ignore
                     metadata_name = condition.name
                     expected_value = condition.value
-                    if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
+                    if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
                         if isinstance(expected_value, str):
                             expected_value = self._replace_metadata_filter_value(expected_value, inputs)
-                        filters = self._process_metadata_filter_func(
-                            sequence,
-                            condition.comparison_operator,
-                            metadata_name,
-                            expected_value,
-                            filters,
-                        )
+                    conditions.append(
+                        Condition(
+                            name=metadata_name,
+                            comparison_operator=condition.comparison_operator,
+                            value=expected_value,
+                        )
+                    )
+                    filters = self._process_metadata_filter_func(
+                        sequence,
+                        condition.comparison_operator,
+                        metadata_name,
+                        expected_value,
+                        filters,
+                    )
+                metadata_condition = MetadataCondition(
+                    logical_operator=metadata_filtering_conditions.logical_operator,
+                    conditions=conditions,
+                )
         else:
             raise ValueError("Invalid metadata filtering mode")
         if filters:
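The "manual" branch previously built metadata_condition straight from model_dump(), so condition values that contained input templates reached external retrieval unsubstituted, and the substitution only ever affected the local filters. The rework resolves each value first, builds a Condition per entry, and assembles the MetadataCondition at the end; the empty/not-empty guard is also inverted so substitution only runs when a concrete value is present. A standalone illustration of that rebuilt loop is below, using plain stand-ins for Condition/MetadataCondition and a hypothetical substitute() in place of _replace_metadata_filter_value; it is a sketch of the idea, not project code.

from dataclasses import dataclass, field


@dataclass
class StubCondition:
    name: str
    comparison_operator: str
    value: object = None


@dataclass
class StubMetadataCondition:
    logical_operator: str = "and"
    conditions: list = field(default_factory=list)


def substitute(value: str, inputs: dict) -> str:
    # stand-in for substituting user inputs into a templated condition value
    for key, replacement in inputs.items():
        value = value.replace("{{" + key + "}}", str(replacement))
    return value


def build_metadata_condition(raw: StubMetadataCondition, inputs: dict) -> StubMetadataCondition:
    conditions = []
    for cond in raw.conditions:
        value = cond.value
        if value is not None and cond.comparison_operator not in ("empty", "not empty"):
            if isinstance(value, str):
                value = substitute(value, inputs)
        # the resolved value is what now reaches the MetadataCondition handed to
        # external retrieval, instead of the raw template text
        conditions.append(StubCondition(cond.name, cond.comparison_operator, value))
    return StubMetadataCondition(raw.logical_operator, conditions)


# Example: a templated condition value is resolved before the condition is kept.
raw = StubMetadataCondition("and", [StubCondition("language", "is", "{{lang}}")])
resolved = build_metadata_condition(raw, {"lang": "en"})
assert resolved.conditions[0].value == "en"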

View File

@@ -1,11 +1,12 @@
-from typing import Any
+from typing import Any, Optional, cast

 from pydantic import BaseModel, Field

+from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.entities.context_entities import DocumentContext
-from core.rag.entities.metadata_entities import MetadataCondition
 from core.rag.models.document import Document as RetrievalDocument
+from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
@@ -34,7 +35,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
     args_schema: type[BaseModel] = DatasetRetrieverToolInput
     description: str = "use this to retrieve a dataset. "
     dataset_id: str
-    metadata_filtering_conditions: MetadataCondition
+    user_id: Optional[str] = None
+    retrieve_config: DatasetRetrieveConfigEntity
+    inputs: dict

     @classmethod
     def from_dataset(cls, dataset: Dataset, **kwargs):
@@ -48,7 +51,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
             tenant_id=dataset.tenant_id,
             dataset_id=dataset.id,
             description=description,
-            metadata_filtering_conditions=MetadataCondition(),
             **kwargs,
         )
@@ -61,6 +63,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
             return ""
         for hit_callback in self.hit_callbacks:
             hit_callback.on_query(query, dataset.id)
+        dataset_retrieval = DatasetRetrieval()
+        metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
+            [dataset.id],
+            query,
+            self.tenant_id,
+            self.user_id or "unknown",
+            cast(str, self.retrieve_config.metadata_filtering_mode),
+            cast(ModelConfig, self.retrieve_config.metadata_model_config),
+            self.retrieve_config.metadata_filtering_conditions,
+            self.inputs,
+        )
+        if metadata_filter_document_ids:
+            document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
+        else:
+            document_ids_filter = None
         if dataset.provider == "external":
             results = []
             external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
@@ -68,7 +85,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                 dataset_id=dataset.id,
                 query=query,
                 external_retrieval_parameters=dataset.retrieval_model,
-                metadata_condition=self.metadata_filtering_conditions,
+                metadata_condition=metadata_condition,
             )
             for external_document in external_documents:
                 document = RetrievalDocument(
@@ -104,12 +121,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
             return str("\n".join([item.page_content for item in results]))
         else:
+            if metadata_condition and not document_ids_filter:
+                return ""
             # get retrieval model , if the model is not setting , using default
             retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
             if dataset.indexing_technique == "economy":
                 # use keyword table query
                 documents = RetrievalService.retrieve(
-                    retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
+                    retrieval_method="keyword_search",
+                    dataset_id=dataset.id,
+                    query=query,
+                    top_k=self.top_k,
+                    document_ids_filter=document_ids_filter,
                 )
                 return str("\n".join([document.page_content for document in documents]))
             else:
@@ -128,6 +151,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                     else None,
                     reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                     weights=retrieval_model.get("weights"),
+                    document_ids_filter=document_ids_filter,
                 )
             else:
                 documents = []
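Because the tool now resolves metadata filters itself per query, it receives the retrieval config, the invoking user, and the app inputs at construction time, which is why the two agent runners above no longer patch metadata_filtering_conditions onto each tool instance. A hedged construction sketch follows; only the keyword arguments that appear in this diff are shown, the surrounding objects (dataset, hit_callback, retrieve_config, user_id, inputs) are assumed to be in scope, and any other arguments the real call site passes are omitted.

# Sketch of the new wiring; placeholder values, not the project's call site.
tool = DatasetRetrieverTool.from_dataset(
    dataset=dataset,                        # Dataset ORM object, assumed in scope
    hit_callbacks=[hit_callback],
    return_resource=return_resource,
    retriever_from=invoke_from.to_source(),
    retrieve_config=retrieve_config,        # DatasetRetrieveConfigEntity with the metadata_filtering_* settings
    user_id=user_id,                        # _run falls back to "unknown" when this is missing
    inputs=inputs,                          # app inputs used to substitute templated condition values
)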

View File

@@ -34,6 +34,8 @@ class DatasetRetrieverTool(Tool):
         return_resource: bool,
         invoke_from: InvokeFrom,
         hit_callback: DatasetIndexToolCallbackHandler,
+        user_id: str,
+        inputs: dict,
     ) -> list["DatasetRetrieverTool"]:
         """
         get dataset tool
@@ -57,6 +59,8 @@ class DatasetRetrieverTool(Tool):
             return_resource=return_resource,
             invoke_from=invoke_from,
             hit_callback=hit_callback,
+            user_id=user_id,
+            inputs=inputs,
         )
         if retrieval_tools is None or len(retrieval_tools) == 0:
             return []

View File

@@ -356,12 +356,12 @@ class KnowledgeRetrievalNode(LLMNode):
                )
        elif node_data.metadata_filtering_mode == "manual":
            if node_data.metadata_filtering_conditions:
-                metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
+                conditions = []
                if node_data.metadata_filtering_conditions:
                    for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions):  # type: ignore
                        metadata_name = condition.name
                        expected_value = condition.value
-                        if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
+                        if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
                            if isinstance(expected_value, str):
                                expected_value = self.graph_runtime_state.variable_pool.convert_template(
                                    expected_value
@@ -372,13 +372,24 @@ class KnowledgeRetrievalNode(LLMNode):
                                    expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()  # type: ignore
                                else:
                                    raise ValueError("Invalid expected metadata value type")
-                        filters = self._process_metadata_filter_func(
-                            sequence,
-                            condition.comparison_operator,
-                            metadata_name,
-                            expected_value,
-                            filters,
-                        )
+                        conditions.append(
+                            Condition(
+                                name=metadata_name,
+                                comparison_operator=condition.comparison_operator,
+                                value=expected_value,
+                            )
+                        )
+                        filters = self._process_metadata_filter_func(
+                            sequence,
+                            condition.comparison_operator,
+                            metadata_name,
+                            expected_value,
+                            filters,
+                        )
+                metadata_condition = MetadataCondition(
+                    logical_operator=node_data.metadata_filtering_conditions.logical_operator,
+                    conditions=conditions,
+                )
        else:
            raise ValueError("Invalid metadata filtering mode")
        if filters:

View File

@@ -69,6 +69,7 @@ class HitTestingService:
         query: str,
         account: Account,
         external_retrieval_model: dict,
+        metadata_filtering_conditions: dict,
     ) -> dict:
         if dataset.provider != "external":
             return {
@@ -82,6 +83,7 @@ class HitTestingService:
             dataset_id=dataset.id,
             query=cls.escape_query_for_search(query),
             external_retrieval_model=external_retrieval_model,
+            metadata_filtering_conditions=metadata_filtering_conditions,
         )
         end = time.perf_counter()