mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 01:15:56 +08:00
fix: metadata filtering condition variable unassigned; fix External K… (#19208)
This commit is contained in:
parent
d1c08a810b
commit
bfa652f2d0
@ -209,6 +209,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
|||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("query", type=str, location="json")
|
parser.add_argument("query", type=str, location="json")
|
||||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||||
|
parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
HitTestingService.hit_testing_args_check(args)
|
HitTestingService.hit_testing_args_check(args)
|
||||||
@ -219,6 +220,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
|||||||
query=args["query"],
|
query=args["query"],
|
||||||
account=current_user,
|
account=current_user,
|
||||||
external_retrieval_model=args["external_retrieval_model"],
|
external_retrieval_model=args["external_retrieval_model"],
|
||||||
|
metadata_filtering_conditions=args["metadata_filtering_conditions"],
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
@ -91,6 +91,8 @@ class BaseAgentRunner(AppRunner):
|
|||||||
return_resource=app_config.additional_features.show_retrieve_source,
|
return_resource=app_config.additional_features.show_retrieve_source,
|
||||||
invoke_from=application_generate_entity.invoke_from,
|
invoke_from=application_generate_entity.invoke_from,
|
||||||
hit_callback=hit_callback,
|
hit_callback=hit_callback,
|
||||||
|
user_id=user_id,
|
||||||
|
inputs=cast(dict, application_generate_entity.inputs),
|
||||||
)
|
)
|
||||||
# get how many agent thoughts have been created
|
# get how many agent thoughts have been created
|
||||||
self.agent_thought_count = (
|
self.agent_thought_count = (
|
||||||
|
@ -69,13 +69,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
|||||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||||
self._prompt_messages_tools = prompt_messages_tools
|
self._prompt_messages_tools = prompt_messages_tools
|
||||||
|
|
||||||
# fix metadata filter not work
|
|
||||||
if app_config.dataset is not None:
|
|
||||||
metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions
|
|
||||||
for key, dataset_retriever_tool in tool_instances.items():
|
|
||||||
if hasattr(dataset_retriever_tool, "retrieval_tool"):
|
|
||||||
dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions
|
|
||||||
|
|
||||||
function_call_state = True
|
function_call_state = True
|
||||||
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
|
||||||
final_answer = ""
|
final_answer = ""
|
||||||
|
@ -45,13 +45,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
# convert tools into ModelRuntime Tool format
|
# convert tools into ModelRuntime Tool format
|
||||||
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
tool_instances, prompt_messages_tools = self._init_prompt_tools()
|
||||||
|
|
||||||
# fix metadata filter not work
|
|
||||||
if app_config.dataset is not None:
|
|
||||||
metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions
|
|
||||||
for key, dataset_retriever_tool in tool_instances.items():
|
|
||||||
if hasattr(dataset_retriever_tool, "retrieval_tool"):
|
|
||||||
dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions
|
|
||||||
|
|
||||||
assert app_config.agent
|
assert app_config.agent
|
||||||
|
|
||||||
iteration_step = 1
|
iteration_step = 1
|
||||||
|
@ -10,6 +10,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
|||||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.embedding.retrieval import RetrievalSegments
|
from core.rag.embedding.retrieval import RetrievalSegments
|
||||||
|
from core.rag.entities.metadata_entities import MetadataCondition
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.index_type import IndexType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from core.rag.rerank.rerank_type import RerankMode
|
from core.rag.rerank.rerank_type import RerankMode
|
||||||
@ -119,12 +120,25 @@ class RetrievalService:
|
|||||||
return all_documents
|
return all_documents
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
|
def external_retrieve(
|
||||||
|
cls,
|
||||||
|
dataset_id: str,
|
||||||
|
query: str,
|
||||||
|
external_retrieval_model: Optional[dict] = None,
|
||||||
|
metadata_filtering_conditions: Optional[dict] = None,
|
||||||
|
):
|
||||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||||
if not dataset:
|
if not dataset:
|
||||||
return []
|
return []
|
||||||
|
metadata_condition = (
|
||||||
|
MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
|
||||||
|
)
|
||||||
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||||
dataset.tenant_id, dataset_id, query, external_retrieval_model or {}
|
dataset.tenant_id,
|
||||||
|
dataset_id,
|
||||||
|
query,
|
||||||
|
external_retrieval_model or {},
|
||||||
|
metadata_condition=metadata_condition,
|
||||||
)
|
)
|
||||||
return all_documents
|
return all_documents
|
||||||
|
|
||||||
|
@ -149,7 +149,7 @@ class DatasetRetrieval:
|
|||||||
else:
|
else:
|
||||||
inputs = {}
|
inputs = {}
|
||||||
available_datasets_ids = [dataset.id for dataset in available_datasets]
|
available_datasets_ids = [dataset.id for dataset in available_datasets]
|
||||||
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
|
metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
|
||||||
available_datasets_ids,
|
available_datasets_ids,
|
||||||
query,
|
query,
|
||||||
tenant_id,
|
tenant_id,
|
||||||
@ -649,6 +649,8 @@ class DatasetRetrieval:
|
|||||||
return_resource: bool,
|
return_resource: bool,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
hit_callback: DatasetIndexToolCallbackHandler,
|
hit_callback: DatasetIndexToolCallbackHandler,
|
||||||
|
user_id: str,
|
||||||
|
inputs: dict,
|
||||||
) -> Optional[list[DatasetRetrieverBaseTool]]:
|
) -> Optional[list[DatasetRetrieverBaseTool]]:
|
||||||
"""
|
"""
|
||||||
A dataset tool is a tool that can be used to retrieve information from a dataset
|
A dataset tool is a tool that can be used to retrieve information from a dataset
|
||||||
@ -706,6 +708,9 @@ class DatasetRetrieval:
|
|||||||
hit_callbacks=[hit_callback],
|
hit_callbacks=[hit_callback],
|
||||||
return_resource=return_resource,
|
return_resource=return_resource,
|
||||||
retriever_from=invoke_from.to_source(),
|
retriever_from=invoke_from.to_source(),
|
||||||
|
retrieve_config=retrieve_config,
|
||||||
|
user_id=user_id,
|
||||||
|
inputs=inputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
tools.append(tool)
|
tools.append(tool)
|
||||||
@ -826,7 +831,7 @@ class DatasetRetrieval:
|
|||||||
)
|
)
|
||||||
return filter_documents[:top_k] if top_k else filter_documents
|
return filter_documents[:top_k] if top_k else filter_documents
|
||||||
|
|
||||||
def _get_metadata_filter_condition(
|
def get_metadata_filter_condition(
|
||||||
self,
|
self,
|
||||||
dataset_ids: list,
|
dataset_ids: list,
|
||||||
query: str,
|
query: str,
|
||||||
@ -876,20 +881,31 @@ class DatasetRetrieval:
|
|||||||
)
|
)
|
||||||
elif metadata_filtering_mode == "manual":
|
elif metadata_filtering_mode == "manual":
|
||||||
if metadata_filtering_conditions:
|
if metadata_filtering_conditions:
|
||||||
metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
|
conditions = []
|
||||||
for sequence, condition in enumerate(metadata_filtering_conditions.conditions): # type: ignore
|
for sequence, condition in enumerate(metadata_filtering_conditions.conditions): # type: ignore
|
||||||
metadata_name = condition.name
|
metadata_name = condition.name
|
||||||
expected_value = condition.value
|
expected_value = condition.value
|
||||||
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
|
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||||
if isinstance(expected_value, str):
|
if isinstance(expected_value, str):
|
||||||
expected_value = self._replace_metadata_filter_value(expected_value, inputs)
|
expected_value = self._replace_metadata_filter_value(expected_value, inputs)
|
||||||
filters = self._process_metadata_filter_func(
|
conditions.append(
|
||||||
sequence,
|
Condition(
|
||||||
condition.comparison_operator,
|
name=metadata_name,
|
||||||
metadata_name,
|
comparison_operator=condition.comparison_operator,
|
||||||
expected_value,
|
value=expected_value,
|
||||||
filters,
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
filters = self._process_metadata_filter_func(
|
||||||
|
sequence,
|
||||||
|
condition.comparison_operator,
|
||||||
|
metadata_name,
|
||||||
|
expected_value,
|
||||||
|
filters,
|
||||||
|
)
|
||||||
|
metadata_condition = MetadataCondition(
|
||||||
|
logical_operator=metadata_filtering_conditions.logical_operator,
|
||||||
|
conditions=conditions,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid metadata filtering mode")
|
raise ValueError("Invalid metadata filtering mode")
|
||||||
if filters:
|
if filters:
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
from typing import Any
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
from core.rag.entities.context_entities import DocumentContext
|
from core.rag.entities.context_entities import DocumentContext
|
||||||
from core.rag.entities.metadata_entities import MetadataCondition
|
|
||||||
from core.rag.models.document import Document as RetrievalDocument
|
from core.rag.models.document import Document as RetrievalDocument
|
||||||
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -34,7 +35,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
args_schema: type[BaseModel] = DatasetRetrieverToolInput
|
args_schema: type[BaseModel] = DatasetRetrieverToolInput
|
||||||
description: str = "use this to retrieve a dataset. "
|
description: str = "use this to retrieve a dataset. "
|
||||||
dataset_id: str
|
dataset_id: str
|
||||||
metadata_filtering_conditions: MetadataCondition
|
user_id: Optional[str] = None
|
||||||
|
retrieve_config: DatasetRetrieveConfigEntity
|
||||||
|
inputs: dict
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dataset(cls, dataset: Dataset, **kwargs):
|
def from_dataset(cls, dataset: Dataset, **kwargs):
|
||||||
@ -48,7 +51,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
tenant_id=dataset.tenant_id,
|
tenant_id=dataset.tenant_id,
|
||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
description=description,
|
description=description,
|
||||||
metadata_filtering_conditions=MetadataCondition(),
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -61,6 +63,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
return ""
|
return ""
|
||||||
for hit_callback in self.hit_callbacks:
|
for hit_callback in self.hit_callbacks:
|
||||||
hit_callback.on_query(query, dataset.id)
|
hit_callback.on_query(query, dataset.id)
|
||||||
|
dataset_retrieval = DatasetRetrieval()
|
||||||
|
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
|
||||||
|
[dataset.id],
|
||||||
|
query,
|
||||||
|
self.tenant_id,
|
||||||
|
self.user_id or "unknown",
|
||||||
|
cast(str, self.retrieve_config.metadata_filtering_mode),
|
||||||
|
cast(ModelConfig, self.retrieve_config.metadata_model_config),
|
||||||
|
self.retrieve_config.metadata_filtering_conditions,
|
||||||
|
self.inputs,
|
||||||
|
)
|
||||||
|
if metadata_filter_document_ids:
|
||||||
|
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
|
||||||
|
else:
|
||||||
|
document_ids_filter = None
|
||||||
if dataset.provider == "external":
|
if dataset.provider == "external":
|
||||||
results = []
|
results = []
|
||||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||||
@ -68,7 +85,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
query=query,
|
query=query,
|
||||||
external_retrieval_parameters=dataset.retrieval_model,
|
external_retrieval_parameters=dataset.retrieval_model,
|
||||||
metadata_condition=self.metadata_filtering_conditions,
|
metadata_condition=metadata_condition,
|
||||||
)
|
)
|
||||||
for external_document in external_documents:
|
for external_document in external_documents:
|
||||||
document = RetrievalDocument(
|
document = RetrievalDocument(
|
||||||
@ -104,12 +121,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
|
|
||||||
return str("\n".join([item.page_content for item in results]))
|
return str("\n".join([item.page_content for item in results]))
|
||||||
else:
|
else:
|
||||||
|
if metadata_condition and not document_ids_filter:
|
||||||
|
return ""
|
||||||
# get retrieval model , if the model is not setting , using default
|
# get retrieval model , if the model is not setting , using default
|
||||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||||
if dataset.indexing_technique == "economy":
|
if dataset.indexing_technique == "economy":
|
||||||
# use keyword table query
|
# use keyword table query
|
||||||
documents = RetrievalService.retrieve(
|
documents = RetrievalService.retrieve(
|
||||||
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
|
retrieval_method="keyword_search",
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
query=query,
|
||||||
|
top_k=self.top_k,
|
||||||
|
document_ids_filter=document_ids_filter,
|
||||||
)
|
)
|
||||||
return str("\n".join([document.page_content for document in documents]))
|
return str("\n".join([document.page_content for document in documents]))
|
||||||
else:
|
else:
|
||||||
@ -128,6 +151,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
else None,
|
else None,
|
||||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||||
weights=retrieval_model.get("weights"),
|
weights=retrieval_model.get("weights"),
|
||||||
|
document_ids_filter=document_ids_filter,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
documents = []
|
documents = []
|
||||||
|
@ -34,6 +34,8 @@ class DatasetRetrieverTool(Tool):
|
|||||||
return_resource: bool,
|
return_resource: bool,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
hit_callback: DatasetIndexToolCallbackHandler,
|
hit_callback: DatasetIndexToolCallbackHandler,
|
||||||
|
user_id: str,
|
||||||
|
inputs: dict,
|
||||||
) -> list["DatasetRetrieverTool"]:
|
) -> list["DatasetRetrieverTool"]:
|
||||||
"""
|
"""
|
||||||
get dataset tool
|
get dataset tool
|
||||||
@ -57,6 +59,8 @@ class DatasetRetrieverTool(Tool):
|
|||||||
return_resource=return_resource,
|
return_resource=return_resource,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
hit_callback=hit_callback,
|
hit_callback=hit_callback,
|
||||||
|
user_id=user_id,
|
||||||
|
inputs=inputs,
|
||||||
)
|
)
|
||||||
if retrieval_tools is None or len(retrieval_tools) == 0:
|
if retrieval_tools is None or len(retrieval_tools) == 0:
|
||||||
return []
|
return []
|
||||||
|
@ -356,12 +356,12 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||||||
)
|
)
|
||||||
elif node_data.metadata_filtering_mode == "manual":
|
elif node_data.metadata_filtering_mode == "manual":
|
||||||
if node_data.metadata_filtering_conditions:
|
if node_data.metadata_filtering_conditions:
|
||||||
metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
|
conditions = []
|
||||||
if node_data.metadata_filtering_conditions:
|
if node_data.metadata_filtering_conditions:
|
||||||
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
|
||||||
metadata_name = condition.name
|
metadata_name = condition.name
|
||||||
expected_value = condition.value
|
expected_value = condition.value
|
||||||
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
|
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
|
||||||
if isinstance(expected_value, str):
|
if isinstance(expected_value, str):
|
||||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||||
expected_value
|
expected_value
|
||||||
@ -372,13 +372,24 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
|
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid expected metadata value type")
|
raise ValueError("Invalid expected metadata value type")
|
||||||
filters = self._process_metadata_filter_func(
|
conditions.append(
|
||||||
sequence,
|
Condition(
|
||||||
condition.comparison_operator,
|
name=metadata_name,
|
||||||
metadata_name,
|
comparison_operator=condition.comparison_operator,
|
||||||
expected_value,
|
value=expected_value,
|
||||||
filters,
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
filters = self._process_metadata_filter_func(
|
||||||
|
sequence,
|
||||||
|
condition.comparison_operator,
|
||||||
|
metadata_name,
|
||||||
|
expected_value,
|
||||||
|
filters,
|
||||||
|
)
|
||||||
|
metadata_condition = MetadataCondition(
|
||||||
|
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
|
||||||
|
conditions=conditions,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid metadata filtering mode")
|
raise ValueError("Invalid metadata filtering mode")
|
||||||
if filters:
|
if filters:
|
||||||
|
@ -69,6 +69,7 @@ class HitTestingService:
|
|||||||
query: str,
|
query: str,
|
||||||
account: Account,
|
account: Account,
|
||||||
external_retrieval_model: dict,
|
external_retrieval_model: dict,
|
||||||
|
metadata_filtering_conditions: dict,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if dataset.provider != "external":
|
if dataset.provider != "external":
|
||||||
return {
|
return {
|
||||||
@ -82,6 +83,7 @@ class HitTestingService:
|
|||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
query=cls.escape_query_for_search(query),
|
query=cls.escape_query_for_search(query),
|
||||||
external_retrieval_model=external_retrieval_model,
|
external_retrieval_model=external_retrieval_model,
|
||||||
|
metadata_filtering_conditions=metadata_filtering_conditions,
|
||||||
)
|
)
|
||||||
|
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user