From a6ea15e63cd9d9520f94cb085f8dc6fc28ca6297 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 30 May 2025 14:36:44 +0800 Subject: [PATCH] Refactor/message cycle manage and knowledge retrieval (#20460) Signed-off-by: -LAN- --- .../advanced_chat/generate_task_pipeline.py | 44 +++++------ .../apps/workflow/generate_task_pipeline.py | 4 - api/core/app/entities/queue_entities.py | 5 +- api/core/app/entities/task_entities.py | 23 +++++- .../easy_ui_based_generate_task_pipeline.py | 45 ++++++------ ...cle_manage.py => message_cycle_manager.py} | 29 +++++--- .../index_tool_callback_handler.py | 5 +- api/core/rag/entities/citation_metadata.py | 23 ++++++ api/core/rag/retrieval/dataset_retrieval.py | 63 ++++++++-------- .../dataset_multi_retriever_tool.py | 39 +++++----- .../dataset_retriever_tool.py | 73 ++++++++++--------- .../workflow/graph_engine/entities/event.py | 5 +- api/core/workflow/nodes/event/event.py | 4 +- api/core/workflow/nodes/llm/node.py | 41 ++++++----- 14 files changed, 222 insertions(+), 181 deletions(-) rename api/core/app/task_pipeline/{message_cycle_manage.py => message_cycle_manager.py} (85%) create mode 100644 api/core/rag/entities/citation_metadata.py diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 112e178b82..8c5645bbb7 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -1,4 +1,3 @@ -import json import logging import time from collections.abc import Generator, Mapping @@ -57,10 +56,9 @@ from core.app.entities.task_entities import ( WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline -from core.app.task_pipeline.message_cycle_manage import MessageCycleManage +from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType from core.workflow.enums import SystemVariableKey @@ -141,7 +139,7 @@ class AdvancedChatAppGenerateTaskPipeline: ) self._task_state = WorkflowTaskState() - self._message_cycle_manager = MessageCycleManage( + self._message_cycle_manager = MessageCycleManager( application_generate_entity=application_generate_entity, task_state=self._task_state ) @@ -162,7 +160,7 @@ class AdvancedChatAppGenerateTaskPipeline: :return: """ # start generate conversation name thread - self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name( + self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name( conversation_id=self._conversation_id, query=self._application_generate_entity.query ) @@ -605,22 +603,18 @@ class AdvancedChatAppGenerateTaskPipeline: yield self._message_end_to_stream_response() break elif isinstance(event, QueueRetrieverResourcesEvent): - self._message_cycle_manager._handle_retriever_resources(event) + self._message_cycle_manager.handle_retriever_resources(event) with Session(db.engine, expire_on_commit=False) as session: message = self._get_message(session=session) - message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) + message.message_metadata = self._task_state.metadata.model_dump_json() session.commit() elif isinstance(event, QueueAnnotationReplyEvent): - self._message_cycle_manager._handle_annotation_reply(event) + self._message_cycle_manager.handle_annotation_reply(event) with Session(db.engine, expire_on_commit=False) as session: message = self._get_message(session=session) - message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) + message.message_metadata = self._task_state.metadata.model_dump_json() session.commit() elif isinstance(event, QueueTextChunkEvent): delta_text = event.text @@ -637,12 +631,12 @@ class AdvancedChatAppGenerateTaskPipeline: tts_publisher.publish(queue_message) self._task_state.answer += delta_text - yield self._message_cycle_manager._message_to_stream_response( + yield self._message_cycle_manager.message_to_stream_response( answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector ) elif isinstance(event, QueueMessageReplaceEvent): # published by moderation - yield self._message_cycle_manager._message_replace_to_stream_response( + yield self._message_cycle_manager.message_replace_to_stream_response( answer=event.text, reason=event.reason ) elif isinstance(event, QueueAdvancedChatMessageEndEvent): @@ -654,7 +648,7 @@ class AdvancedChatAppGenerateTaskPipeline: ) if output_moderation_answer: self._task_state.answer = output_moderation_answer - yield self._message_cycle_manager._message_replace_to_stream_response( + yield self._message_cycle_manager.message_replace_to_stream_response( answer=output_moderation_answer, reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, ) @@ -683,9 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline: message = self._get_message(session=session) message.answer = self._task_state.answer message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at - message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) + message.message_metadata = self._task_state.metadata.model_dump_json() message_files = [ MessageFile( message_id=message.id, @@ -713,9 +705,9 @@ class AdvancedChatAppGenerateTaskPipeline: message.answer_price_unit = usage.completion_price_unit message.total_price = usage.total_price message.currency = usage.currency - self._task_state.metadata["usage"] = jsonable_encoder(usage) + self._task_state.metadata.usage = usage else: - self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) + self._task_state.metadata.usage = LLMUsage.empty_usage() message_was_created.send( message, application_generate_entity=self._application_generate_entity, @@ -726,18 +718,16 @@ class AdvancedChatAppGenerateTaskPipeline: Message end to stream response. :return: """ - extras = {} - if self._task_state.metadata: - extras["metadata"] = self._task_state.metadata.copy() + extras = self._task_state.metadata.model_dump() - if "annotation_reply" in extras["metadata"]: - del extras["metadata"]["annotation_reply"] + if self._task_state.metadata.annotation_reply: + del extras["annotation_reply"] return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, id=self._message_id, files=self._recorded_files, - metadata=extras.get("metadata", {}), + metadata=extras, ) def _handle_output_moderation_chunk(self, text: str) -> bool: diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index fb0ea80a3a..1734dbb598 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -50,7 +50,6 @@ from core.app.entities.task_entities import ( WorkflowAppStreamResponse, WorkflowFinishStreamResponse, WorkflowStartStreamResponse, - WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk @@ -130,9 +129,7 @@ class WorkflowAppGenerateTaskPipeline: ) self._application_generate_entity = application_generate_entity - self._workflow_id = workflow.id self._workflow_features_dict = workflow.features_dict - self._task_state = WorkflowTaskState() self._workflow_run_id = "" def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: @@ -543,7 +540,6 @@ class WorkflowAppGenerateTaskPipeline: if tts_publisher: tts_publisher.publish(queue_message) - self._task_state.answer += delta_text yield self._text_chunk_to_stream_response( delta_text, from_variable_selector=event.from_variable_selector ) diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index bc8573013a..42e6a1519c 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from datetime import datetime from enum import Enum, StrEnum from typing import Any, Optional @@ -6,6 +6,7 @@ from typing import Any, Optional from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState @@ -283,7 +284,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES - retriever_resources: list[dict] + retriever_resources: Sequence[RetrievalSourceMetadata] in_iteration_id: Optional[str] = None """iteration id if node is in iteration""" in_loop_id: Optional[str] = None diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 70748b085d..25c889e922 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -2,20 +2,37 @@ from collections.abc import Mapping, Sequence from enum import Enum from typing import Any, Optional -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field -from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +class AnnotationReplyAccount(BaseModel): + id: str + name: str + + +class AnnotationReply(BaseModel): + id: str + account: AnnotationReplyAccount + + +class TaskStateMetadata(BaseModel): + annotation_reply: AnnotationReply | None = None + retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list) + usage: LLMUsage | None = None + + class TaskState(BaseModel): """ TaskState entity """ - metadata: dict = {} + metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata) class EasyUITaskState(TaskState): diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 6c768fd86c..1ea50a5778 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -1,4 +1,3 @@ -import json import logging import time from collections.abc import Generator @@ -43,7 +42,7 @@ from core.app.entities.task_entities import ( StreamResponse, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline -from core.app.task_pipeline.message_cycle_manage import MessageCycleManage +from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -51,7 +50,6 @@ from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, ) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.prompt.utils.prompt_message_util import PromptMessageUtil @@ -63,7 +61,7 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) -class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage): +class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): """ EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. """ @@ -104,6 +102,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan ) ) + self._message_cycle_manager = MessageCycleManager( + application_generate_entity=application_generate_entity, + task_state=self._task_state, + ) + self._conversation_name_generate_thread: Optional[Thread] = None def process( @@ -115,7 +118,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan ]: if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread - self._conversation_name_generate_thread = self._generate_conversation_name( + self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name( conversation_id=self._conversation_id, query=self._application_generate_entity.query or "" ) @@ -136,9 +139,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, MessageEndStreamResponse): - extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} + extras = {"usage": self._task_state.llm_result.usage.model_dump()} if self._task_state.metadata: - extras["metadata"] = self._task_state.metadata + extras["metadata"] = self._task_state.metadata.model_dump() response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] if self._conversation_mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( @@ -277,7 +280,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan ) if output_moderation_answer: self._task_state.llm_result.message.content = output_moderation_answer - yield self._message_replace_to_stream_response(answer=output_moderation_answer) + yield self._message_cycle_manager.message_replace_to_stream_response( + answer=output_moderation_answer + ) with Session(db.engine) as session: # Save message @@ -286,9 +291,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan message_end_resp = self._message_end_to_stream_response() yield message_end_resp elif isinstance(event, QueueRetrieverResourcesEvent): - self._handle_retriever_resources(event) + self._message_cycle_manager.handle_retriever_resources(event) elif isinstance(event, QueueAnnotationReplyEvent): - annotation = self._handle_annotation_reply(event) + annotation = self._message_cycle_manager.handle_annotation_reply(event) if annotation: self._task_state.llm_result.message.content = annotation.content elif isinstance(event, QueueAgentThoughtEvent): @@ -296,7 +301,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if agent_thought_response is not None: yield agent_thought_response elif isinstance(event, QueueMessageFileEvent): - response = self._message_file_to_stream_response(event) + response = self._message_cycle_manager.message_file_to_stream_response(event) if response: yield response elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent): @@ -318,7 +323,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - yield self._message_to_stream_response( + yield self._message_cycle_manager.message_to_stream_response( answer=cast(str, delta_text), message_id=self._message_id, ) @@ -328,7 +333,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan message_id=self._message_id, ) elif isinstance(event, QueueMessageReplaceEvent): - yield self._message_replace_to_stream_response(answer=event.text) + yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): yield self._ping_stream_response() else: @@ -372,9 +377,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan message.provider_response_latency = time.perf_counter() - self._start_at message.total_price = usage.total_price message.currency = usage.currency - message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) + message.message_metadata = self._task_state.metadata.model_dump_json() if trace_manager: trace_manager.add_trace_task( @@ -423,16 +426,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan Message end to stream response. :return: """ - self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage) - - extras = {} - if self._task_state.metadata: - extras["metadata"] = self._task_state.metadata - + self._task_state.metadata.usage = self._task_state.llm_result.usage + metadata_dict = self._task_state.metadata.model_dump() return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, id=self._message_id, - metadata=extras.get("metadata", {}), + metadata=metadata_dict, ) def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manager.py similarity index 85% rename from api/core/app/task_pipeline/message_cycle_manage.py rename to api/core/app/task_pipeline/message_cycle_manager.py index a6d826f08b..2343081eaf 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -17,6 +17,8 @@ from core.app.entities.queue_entities import ( QueueRetrieverResourcesEvent, ) from core.app.entities.task_entities import ( + AnnotationReply, + AnnotationReplyAccount, EasyUITaskState, MessageFileStreamResponse, MessageReplaceStreamResponse, @@ -30,7 +32,7 @@ from models.model import AppMode, Conversation, MessageAnnotation, MessageFile from services.annotation_service import AppAnnotationService -class MessageCycleManage: +class MessageCycleManager: def __init__( self, *, @@ -45,7 +47,7 @@ class MessageCycleManage: self._application_generate_entity = application_generate_entity self._task_state = task_state - def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: + def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: """ Generate conversation name. :param conversation_id: conversation id @@ -102,7 +104,7 @@ class MessageCycleManage: db.session.commit() db.session.close() - def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: + def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: """ Handle annotation reply. :param event: event @@ -111,25 +113,28 @@ class MessageCycleManage: annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account - self._task_state.metadata["annotation_reply"] = { - "id": annotation.id, - "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"}, - } + self._task_state.metadata.annotation_reply = AnnotationReply( + id=annotation.id, + account=AnnotationReplyAccount( + id=annotation.account_id, + name=account.name if account else "Dify user", + ), + ) return annotation return None - def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: + def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: """ Handle retriever resources. :param event: event :return: """ if self._application_generate_entity.app_config.additional_features.show_retrieve_source: - self._task_state.metadata["retriever_resources"] = event.retriever_resources + self._task_state.metadata.retriever_resources = event.retriever_resources - def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: + def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: """ Message file to stream response. :param event: event @@ -166,7 +171,7 @@ class MessageCycleManage: return None - def _message_to_stream_response( + def message_to_stream_response( self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None ) -> MessageStreamResponse: """ @@ -182,7 +187,7 @@ class MessageCycleManage: from_variable_selector=from_variable_selector, ) - def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: + def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: """ Message replace to stream response. :param answer: answer diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 13c22213c4..a3a7b4b812 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,8 +1,10 @@ import logging +from collections.abc import Sequence from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.index_processor.constant.index_type import IndexType from core.rag.models.document import Document from extensions.ext_database import db @@ -85,7 +87,8 @@ class DatasetIndexToolCallbackHandler: db.session.commit() - def return_retriever_resource_info(self, resource: list): + # TODO(-LAN-): Improve type check + def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]): """Handle return_retriever_resource_info.""" self._queue_manager.publish( QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py new file mode 100644 index 0000000000..00120425c9 --- /dev/null +++ b/api/core/rag/entities/citation_metadata.py @@ -0,0 +1,23 @@ +from typing import Any, Optional + +from pydantic import BaseModel + + +class RetrievalSourceMetadata(BaseModel): + position: Optional[int] = None + dataset_id: Optional[str] = None + dataset_name: Optional[str] = None + document_id: Optional[str] = None + document_name: Optional[str] = None + data_source_type: Optional[str] = None + segment_id: Optional[str] = None + retriever_from: Optional[str] = None + score: Optional[float] = None + hit_count: Optional[int] = None + word_count: Optional[int] = None + segment_position: Optional[int] = None + index_node_hash: Optional[str] = None + content: Optional[str] = None + page: Optional[int] = None + doc_metadata: Optional[dict[str, Any]] = None + title: Optional[str] = None diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index c4adf6de4d..f668148428 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -35,6 +35,7 @@ from core.prompt.simple_prompt_transform import ModelMode from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.index_processor.constant.index_type import IndexType @@ -198,21 +199,21 @@ class DatasetRetrieval: dify_documents = [item for item in all_documents if item.provider == "dify"] external_documents = [item for item in all_documents if item.provider == "external"] - document_context_list = [] - retrieval_resource_list = [] + document_context_list: list[DocumentContext] = [] + retrieval_resource_list: list[RetrievalSourceMetadata] = [] # deal with external documents for item in external_documents: document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score"))) - source = { - "dataset_id": item.metadata.get("dataset_id"), - "dataset_name": item.metadata.get("dataset_name"), - "document_id": item.metadata.get("document_id") or item.metadata.get("title"), - "document_name": item.metadata.get("title"), - "data_source_type": "external", - "retriever_from": invoke_from.to_source(), - "score": item.metadata.get("score"), - "content": item.page_content, - } + source = RetrievalSourceMetadata( + dataset_id=item.metadata.get("dataset_id"), + dataset_name=item.metadata.get("dataset_name"), + document_id=item.metadata.get("document_id") or item.metadata.get("title"), + document_name=item.metadata.get("title"), + data_source_type="external", + retriever_from=invoke_from.to_source(), + score=item.metadata.get("score"), + content=item.page_content, + ) retrieval_resource_list.append(source) # deal with dify documents if dify_documents: @@ -248,32 +249,32 @@ class DatasetRetrieval: .first() ) if dataset and document: - source = { - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, - "segment_id": segment.id, - "retriever_from": invoke_from.to_source(), - "score": record.score or 0.0, - "doc_metadata": document.doc_metadata, - } + source = RetrievalSourceMetadata( + dataset_id=dataset.id, + dataset_name=dataset.name, + document_id=document.id, + document_name=document.name, + data_source_type=document.data_source_type, + segment_id=segment.id, + retriever_from=invoke_from.to_source(), + score=record.score or 0.0, + doc_metadata=document.doc_metadata, + ) if invoke_from.to_source() == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash + source.hit_count = segment.hit_count + source.word_count = segment.word_count + source.segment_position = segment.position + source.index_node_hash = segment.index_node_hash if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: - source["content"] = segment.content + source.content = segment.content retrieval_resource_list.append(source) if hit_callback and retrieval_resource_list: - retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True) + retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True) for position, item in enumerate(retrieval_resource_list, start=1): - item["position"] = position + item.position = position hit_callback.return_retriever_resource_info(retrieval_resource_list) if document_context_list: document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 04437ea6d8..93d3fcc49d 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -107,7 +108,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): else: document_context_list.append(segment.get_sign_content()) if self.return_resource: - context_list = [] + context_list: list[RetrievalSourceMetadata] = [] resource_number = 1 for segment in sorted_segments: dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() @@ -121,28 +122,28 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): .first() ) if dataset and document: - source = { - "position": resource_number, - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, - "segment_id": segment.id, - "retriever_from": self.retriever_from, - "score": document_score_list.get(segment.index_node_id, None), - "doc_metadata": document.doc_metadata, - } + source = RetrievalSourceMetadata( + position=resource_number, + dataset_id=dataset.id, + dataset_name=dataset.name, + document_id=document.id, + document_name=document.name, + data_source_type=document.data_source_type, + segment_id=segment.id, + retriever_from=self.retriever_from, + score=document_score_list.get(segment.index_node_id, None), + doc_metadata=document.doc_metadata, + ) if self.retriever_from == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash + source.hit_count = segment.hit_count + source.word_count = segment.word_count + source.segment_position = segment.position + source.index_node_hash = segment.index_node_hash if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: - source["content"] = segment.content + source.content = segment.content context_list.append(source) resource_number += 1 diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 7b6882ed52..ff1d9021ce 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext from core.rag.models.document import Document as RetrievalDocument from core.rag.retrieval.dataset_retrieval import DatasetRetrieval @@ -14,7 +15,7 @@ from models.dataset import Dataset from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService -default_retrieval_model = { +default_retrieval_model: dict[str, Any] = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -79,7 +80,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): else: document_ids_filter = None if dataset.provider == "external": - results = [] + results: list[RetrievalDocument] = [] external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( tenant_id=dataset.tenant_id, dataset_id=dataset.id, @@ -100,21 +101,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): document.metadata["dataset_name"] = dataset.name results.append(document) # deal with external documents - context_list = [] + context_list: list[RetrievalSourceMetadata] = [] for position, item in enumerate(results, start=1): if item.metadata is not None: - source = { - "position": position, - "dataset_id": item.metadata.get("dataset_id"), - "dataset_name": item.metadata.get("dataset_name"), - "document_id": item.metadata.get("document_id") or item.metadata.get("title"), - "document_name": item.metadata.get("title"), - "data_source_type": "external", - "retriever_from": self.retriever_from, - "score": item.metadata.get("score"), - "title": item.metadata.get("title"), - "content": item.page_content, - } + source = RetrievalSourceMetadata( + position=position, + dataset_id=item.metadata.get("dataset_id"), + dataset_name=item.metadata.get("dataset_name"), + document_id=item.metadata.get("document_id") or item.metadata.get("title"), + document_name=item.metadata.get("title"), + data_source_type="external", + retriever_from=self.retriever_from, + score=item.metadata.get("score"), + title=item.metadata.get("title"), + content=item.page_content, + ) context_list.append(source) for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) @@ -125,7 +126,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): return "" # get retrieval model , if the model is not setting , using default retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model - retrieval_resource_list = [] + retrieval_resource_list: list[RetrievalSourceMetadata] = [] if dataset.indexing_technique == "economy": # use keyword table query documents = RetrievalService.retrieve( @@ -163,7 +164,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): for item in documents: if item.metadata is not None and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - document_context_list = [] + document_context_list: list[DocumentContext] = [] records = RetrievalService.format_retrieval_documents(documents) if records: for record in records: @@ -197,37 +198,37 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): .first() ) if dataset and document: - source = { - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, # type: ignore - "document_name": document.name, # type: ignore - "data_source_type": document.data_source_type, # type: ignore - "segment_id": segment.id, - "retriever_from": self.retriever_from, - "score": record.score or 0.0, - "doc_metadata": document.doc_metadata, # type: ignore - } + source = RetrievalSourceMetadata( + dataset_id=dataset.id, + dataset_name=dataset.name, + document_id=document.id, # type: ignore + document_name=document.name, # type: ignore + data_source_type=document.data_source_type, # type: ignore + segment_id=segment.id, + retriever_from=self.retriever_from, + score=record.score or 0.0, + doc_metadata=document.doc_metadata, # type: ignore + ) if self.retriever_from == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash + source.hit_count = segment.hit_count + source.word_count = segment.word_count + source.segment_position = segment.position + source.index_node_hash = segment.index_node_hash if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: - source["content"] = segment.content + source.content = segment.content retrieval_resource_list.append(source) if self.return_resource and retrieval_resource_list: retrieval_resource_list = sorted( retrieval_resource_list, - key=lambda x: x.get("score") or 0.0, + key=lambda x: x.score or 0.0, reverse=True, ) for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore - item["position"] = position # type: ignore + item.position = position # type: ignore for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(retrieval_resource_list) if document_context_list: diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 689a07c4f6..9a4939502e 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -1,9 +1,10 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from datetime import datetime from typing import Any, Optional from pydantic import BaseModel, Field +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.nodes import NodeType @@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent): class NodeRunRetrieverResourceEvent(BaseNodeEvent): - retriever_resources: list[dict] = Field(..., description="retriever resources") + retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") context: str = Field(..., description="context") diff --git a/api/core/workflow/nodes/event/event.py b/api/core/workflow/nodes/event/event.py index f45919caf5..b72d111f49 100644 --- a/api/core/workflow/nodes/event/event.py +++ b/api/core/workflow/nodes/event/event.py @@ -1,8 +1,10 @@ +from collections.abc import Sequence from datetime import datetime from pydantic import BaseModel, Field from core.model_runtime.entities.llm_entities import LLMUsage +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -17,7 +19,7 @@ class RunStreamChunkEvent(BaseModel): class RunRetrieverResourceEvent(BaseModel): - retriever_resources: list[dict] = Field(..., description="retriever resources") + retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") context: str = Field(..., description="context") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 38e4f7af01..0fd7c31ffb 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -43,6 +43,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ModelProviderID from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.variables import ( ArrayAnySegment, ArrayFileSegment, @@ -474,7 +475,7 @@ class LLMNode(BaseNode[LLMNodeData]): yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) elif isinstance(context_value_variable, ArraySegment): context_str = "" - original_retriever_resource = [] + original_retriever_resource: list[RetrievalSourceMetadata] = [] for item in context_value_variable.value: if isinstance(item, str): context_str += item + "\n" @@ -492,7 +493,7 @@ class LLMNode(BaseNode[LLMNodeData]): retriever_resources=original_retriever_resource, context=context_str.strip() ) - def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: + def _convert_to_original_retriever_resource(self, context_dict: dict): if ( "metadata" in context_dict and "_source" in context_dict["metadata"] @@ -500,24 +501,24 @@ class LLMNode(BaseNode[LLMNodeData]): ): metadata = context_dict.get("metadata", {}) - source = { - "position": metadata.get("position"), - "dataset_id": metadata.get("dataset_id"), - "dataset_name": metadata.get("dataset_name"), - "document_id": metadata.get("document_id"), - "document_name": metadata.get("document_name"), - "data_source_type": metadata.get("data_source_type"), - "segment_id": metadata.get("segment_id"), - "retriever_from": metadata.get("retriever_from"), - "score": metadata.get("score"), - "hit_count": metadata.get("segment_hit_count"), - "word_count": metadata.get("segment_word_count"), - "segment_position": metadata.get("segment_position"), - "index_node_hash": metadata.get("segment_index_node_hash"), - "content": context_dict.get("content"), - "page": metadata.get("page"), - "doc_metadata": metadata.get("doc_metadata"), - } + source = RetrievalSourceMetadata( + position=metadata.get("position"), + dataset_id=metadata.get("dataset_id"), + dataset_name=metadata.get("dataset_name"), + document_id=metadata.get("document_id"), + document_name=metadata.get("document_name"), + data_source_type=metadata.get("data_source_type"), + segment_id=metadata.get("segment_id"), + retriever_from=metadata.get("retriever_from"), + score=metadata.get("score"), + hit_count=metadata.get("segment_hit_count"), + word_count=metadata.get("segment_word_count"), + segment_position=metadata.get("segment_position"), + index_node_hash=metadata.get("segment_index_node_hash"), + content=context_dict.get("content"), + page=metadata.get("page"), + doc_metadata=metadata.get("doc_metadata"), + ) return source