mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 15:09:00 +08:00
fix: metadata filtering condition variable unassigned; fix External K… (#19208)
This commit is contained in:
parent
d1c08a810b
commit
bfa652f2d0
@ -209,6 +209,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
@ -219,6 +220,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
metadata_filtering_conditions=args["metadata_filtering_conditions"],
|
||||
)
|
||||
|
||||
return response
|
||||
|
@ -91,6 +91,8 @@ class BaseAgentRunner(AppRunner):
|
||||
return_resource=app_config.additional_features.show_retrieve_source,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
hit_callback=hit_callback,
|
||||
user_id=user_id,
|
||||
inputs=cast(dict, application_generate_entity.inputs),
|
||||
)
|
||||
# get how many agent thoughts have been created
|
||||
self.agent_thought_count = (
|
||||
|
@ -69,13 +69,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
self._prompt_messages_tools = prompt_messages_tools
|
||||
|
||||
# fix metadata filter not work
|
||||
if app_config.dataset is not None:
|
||||
metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions
|
||||
for key, dataset_retriever_tool in tool_instances.items():
|
||||
if hasattr(dataset_retriever_tool, "retrieval_tool"):
|
||||
dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions
|
||||
|
||||
function_call_state = True
|
||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||
final_answer = ""
|
||||
|
@ -45,13 +45,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||
|
||||
# fix metadata filter not work
|
||||
if app_config.dataset is not None:
|
||||
metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions
|
||||
for key, dataset_retriever_tool in tool_instances.items():
|
||||
if hasattr(dataset_retriever_tool, "retrieval_tool"):
|
||||
dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions
|
||||
|
||||
assert app_config.agent
|
||||
|
||||
iteration_step = 1
|
||||
|
@ -10,6 +10,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.embedding.retrieval import RetrievalSegments
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank_type import RerankMode
|
||||
@ -119,12 +120,25 @@ class RetrievalService:
|
||||
return all_documents
|
||||
|
||||
@classmethod
|
||||
def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
|
||||
def external_retrieve(
|
||||
cls,
|
||||
dataset_id: str,
|
||||
query: str,
|
||||
external_retrieval_model: Optional[dict] = None,
|
||||
metadata_filtering_conditions: Optional[dict] = None,
|
||||
):
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
return []
|
||||
metadata_condition = (
|
||||
MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
|
||||
)
|
||||
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
dataset.tenant_id, dataset_id, query, external_retrieval_model or {}
|
||||
dataset.tenant_id,
|
||||
dataset_id,
|
||||
query,
|
||||
external_retrieval_model or {},
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
return all_documents
|
||||
|
||||
|
@ -149,7 +149,7 @@ class DatasetRetrieval:
|
||||
else:
|
||||
inputs = {}
|
||||
available_datasets_ids = [dataset.id for dataset in available_datasets]
|
||||
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
|
||||
metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
|
||||
available_datasets_ids,
|
||||
query,
|
||||
tenant_id,
|
||||
@ -649,6 +649,8 @@ class DatasetRetrieval:
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
user_id: str,
|
||||
inputs: dict,
|
||||
) -> Optional[list[DatasetRetrieverBaseTool]]:
|
||||
"""
|
||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||
@ -706,6 +708,9 @@ class DatasetRetrieval:
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
retriever_from=invoke_from.to_source(),
|
||||
retrieve_config=retrieve_config,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
)
|
||||
|
||||
tools.append(tool)
|
||||
@ -826,7 +831,7 @@ class DatasetRetrieval:
|
||||
)
|
||||
return filter_documents[:top_k] if top_k else filter_documents
|
||||
|
||||
def _get_metadata_filter_condition(
|
||||
def get_metadata_filter_condition(
|
||||
self,
|
||||
dataset_ids: list,
|
||||
query: str,
|
||||
@ -876,20 +881,31 @@ class DatasetRetrieval:
|
||||
)
|
||||
elif metadata_filtering_mode == "manual":
|
||||
if metadata_filtering_conditions:
|
||||
metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
|
||||
conditions = []
|
||||
for sequence, condition in enumerate(metadata_filtering_conditions.conditions): # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
|
||||
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self._replace_metadata_filter_value(expected_value, inputs)
|
||||
filters = self._process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=metadata_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = self._process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
if filters:
|
||||
|
@ -1,11 +1,12 @@
|
||||
from typing import Any
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities.metadata_entities import MetadataCondition
|
||||
from core.rag.models.document import Document as RetrievalDocument
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
@ -34,7 +35,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
args_schema: type[BaseModel] = DatasetRetrieverToolInput
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
dataset_id: str
|
||||
metadata_filtering_conditions: MetadataCondition
|
||||
user_id: Optional[str] = None
|
||||
retrieve_config: DatasetRetrieveConfigEntity
|
||||
inputs: dict
|
||||
|
||||
@classmethod
|
||||
def from_dataset(cls, dataset: Dataset, **kwargs):
|
||||
@ -48,7 +51,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
description=description,
|
||||
metadata_filtering_conditions=MetadataCondition(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -61,6 +63,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
return ""
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.on_query(query, dataset.id)
|
||||
dataset_retrieval = DatasetRetrieval()
|
||||
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
|
||||
[dataset.id],
|
||||
query,
|
||||
self.tenant_id,
|
||||
self.user_id or "unknown",
|
||||
cast(str, self.retrieve_config.metadata_filtering_mode),
|
||||
cast(ModelConfig, self.retrieve_config.metadata_model_config),
|
||||
self.retrieve_config.metadata_filtering_conditions,
|
||||
self.inputs,
|
||||
)
|
||||
if metadata_filter_document_ids:
|
||||
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
|
||||
else:
|
||||
document_ids_filter = None
|
||||
if dataset.provider == "external":
|
||||
results = []
|
||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
@ -68,7 +85,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
external_retrieval_parameters=dataset.retrieval_model,
|
||||
metadata_condition=self.metadata_filtering_conditions,
|
||||
metadata_condition=metadata_condition,
|
||||
)
|
||||
for external_document in external_documents:
|
||||
document = RetrievalDocument(
|
||||
@ -104,12 +121,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
|
||||
return str("\n".join([item.page_content for item in results]))
|
||||
else:
|
||||
if metadata_condition and not document_ids_filter:
|
||||
return ""
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
|
||||
retrieval_method="keyword_search",
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=self.top_k,
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
return str("\n".join([document.page_content for document in documents]))
|
||||
else:
|
||||
@ -128,6 +151,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights"),
|
||||
document_ids_filter=document_ids_filter,
|
||||
)
|
||||
else:
|
||||
documents = []
|
||||
|
@ -34,6 +34,8 @@ class DatasetRetrieverTool(Tool):
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
user_id: str,
|
||||
inputs: dict,
|
||||
) -> list["DatasetRetrieverTool"]:
|
||||
"""
|
||||
get dataset tool
|
||||
@ -57,6 +59,8 @@ class DatasetRetrieverTool(Tool):
|
||||
return_resource=return_resource,
|
||||
invoke_from=invoke_from,
|
||||
hit_callback=hit_callback,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
)
|
||||
if retrieval_tools is None or len(retrieval_tools) == 0:
|
||||
return []
|
||||
|
@ -356,12 +356,12 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
)
|
||||
elif node_data.metadata_filtering_mode == "manual":
|
||||
if node_data.metadata_filtering_conditions:
|
||||
metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
|
||||
conditions = []
|
||||
if node_data.metadata_filtering_conditions:
|
||||
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
||||
metadata_name = condition.name
|
||||
expected_value = condition.value
|
||||
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
|
||||
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||
expected_value
|
||||
@ -372,13 +372,24 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
|
||||
else:
|
||||
raise ValueError("Invalid expected metadata value type")
|
||||
filters = self._process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
conditions.append(
|
||||
Condition(
|
||||
name=metadata_name,
|
||||
comparison_operator=condition.comparison_operator,
|
||||
value=expected_value,
|
||||
)
|
||||
)
|
||||
filters = self._process_metadata_filter_func(
|
||||
sequence,
|
||||
condition.comparison_operator,
|
||||
metadata_name,
|
||||
expected_value,
|
||||
filters,
|
||||
)
|
||||
metadata_condition = MetadataCondition(
|
||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
|
||||
conditions=conditions,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid metadata filtering mode")
|
||||
if filters:
|
||||
|
@ -69,6 +69,7 @@ class HitTestingService:
|
||||
query: str,
|
||||
account: Account,
|
||||
external_retrieval_model: dict,
|
||||
metadata_filtering_conditions: dict,
|
||||
) -> dict:
|
||||
if dataset.provider != "external":
|
||||
return {
|
||||
@ -82,6 +83,7 @@ class HitTestingService:
|
||||
dataset_id=dataset.id,
|
||||
query=cls.escape_query_for_search(query),
|
||||
external_retrieval_model=external_retrieval_model,
|
||||
metadata_filtering_conditions=metadata_filtering_conditions,
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
|
Loading…
x
Reference in New Issue
Block a user