Refactor/message cycle manage and knowledge retrieval (#20460)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-05-30 14:36:44 +08:00 committed by GitHub
parent 5a991295e0
commit a6ea15e63c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 222 additions and 181 deletions

View File

@ -1,4 +1,3 @@
import json
import logging import logging
import time import time
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
@ -57,10 +56,9 @@ from core.app.entities.task_entities import (
WorkflowTaskState, WorkflowTaskState,
) )
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
@ -141,7 +139,7 @@ class AdvancedChatAppGenerateTaskPipeline:
) )
self._task_state = WorkflowTaskState() self._task_state = WorkflowTaskState()
self._message_cycle_manager = MessageCycleManage( self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity, task_state=self._task_state application_generate_entity=application_generate_entity, task_state=self._task_state
) )
@ -162,7 +160,7 @@ class AdvancedChatAppGenerateTaskPipeline:
:return: :return:
""" """
# start generate conversation name thread # start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name( self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query conversation_id=self._conversation_id, query=self._application_generate_entity.query
) )
@ -605,22 +603,18 @@ class AdvancedChatAppGenerateTaskPipeline:
yield self._message_end_to_stream_response() yield self._message_end_to_stream_response()
break break
elif isinstance(event, QueueRetrieverResourcesEvent): elif isinstance(event, QueueRetrieverResourcesEvent):
self._message_cycle_manager._handle_retriever_resources(event) self._message_cycle_manager.handle_retriever_resources(event)
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session) message = self._get_message(session=session)
message.message_metadata = ( message.message_metadata = self._task_state.metadata.model_dump_json()
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
session.commit() session.commit()
elif isinstance(event, QueueAnnotationReplyEvent): elif isinstance(event, QueueAnnotationReplyEvent):
self._message_cycle_manager._handle_annotation_reply(event) self._message_cycle_manager.handle_annotation_reply(event)
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session) message = self._get_message(session=session)
message.message_metadata = ( message.message_metadata = self._task_state.metadata.model_dump_json()
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
session.commit() session.commit()
elif isinstance(event, QueueTextChunkEvent): elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text delta_text = event.text
@ -637,12 +631,12 @@ class AdvancedChatAppGenerateTaskPipeline:
tts_publisher.publish(queue_message) tts_publisher.publish(queue_message)
self._task_state.answer += delta_text self._task_state.answer += delta_text
yield self._message_cycle_manager._message_to_stream_response( yield self._message_cycle_manager.message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
) )
elif isinstance(event, QueueMessageReplaceEvent): elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation # published by moderation
yield self._message_cycle_manager._message_replace_to_stream_response( yield self._message_cycle_manager.message_replace_to_stream_response(
answer=event.text, reason=event.reason answer=event.text, reason=event.reason
) )
elif isinstance(event, QueueAdvancedChatMessageEndEvent): elif isinstance(event, QueueAdvancedChatMessageEndEvent):
@ -654,7 +648,7 @@ class AdvancedChatAppGenerateTaskPipeline:
) )
if output_moderation_answer: if output_moderation_answer:
self._task_state.answer = output_moderation_answer self._task_state.answer = output_moderation_answer
yield self._message_cycle_manager._message_replace_to_stream_response( yield self._message_cycle_manager.message_replace_to_stream_response(
answer=output_moderation_answer, answer=output_moderation_answer,
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
) )
@ -683,9 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline:
message = self._get_message(session=session) message = self._get_message(session=session)
message.answer = self._task_state.answer message.answer = self._task_state.answer
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = ( message.message_metadata = self._task_state.metadata.model_dump_json()
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message_files = [ message_files = [
MessageFile( MessageFile(
message_id=message.id, message_id=message.id,
@ -713,9 +705,9 @@ class AdvancedChatAppGenerateTaskPipeline:
message.answer_price_unit = usage.completion_price_unit message.answer_price_unit = usage.completion_price_unit
message.total_price = usage.total_price message.total_price = usage.total_price
message.currency = usage.currency message.currency = usage.currency
self._task_state.metadata["usage"] = jsonable_encoder(usage) self._task_state.metadata.usage = usage
else: else:
self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) self._task_state.metadata.usage = LLMUsage.empty_usage()
message_was_created.send( message_was_created.send(
message, message,
application_generate_entity=self._application_generate_entity, application_generate_entity=self._application_generate_entity,
@ -726,18 +718,16 @@ class AdvancedChatAppGenerateTaskPipeline:
Message end to stream response. Message end to stream response.
:return: :return:
""" """
extras = {} extras = self._task_state.metadata.model_dump()
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.copy()
if "annotation_reply" in extras["metadata"]: if self._task_state.metadata.annotation_reply:
del extras["metadata"]["annotation_reply"] del extras["annotation_reply"]
return MessageEndStreamResponse( return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
id=self._message_id, id=self._message_id,
files=self._recorded_files, files=self._recorded_files,
metadata=extras.get("metadata", {}), metadata=extras,
) )
def _handle_output_moderation_chunk(self, text: str) -> bool: def _handle_output_moderation_chunk(self, text: str) -> bool:

View File

@ -50,7 +50,6 @@ from core.app.entities.task_entities import (
WorkflowAppStreamResponse, WorkflowAppStreamResponse,
WorkflowFinishStreamResponse, WorkflowFinishStreamResponse,
WorkflowStartStreamResponse, WorkflowStartStreamResponse,
WorkflowTaskState,
) )
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
@ -130,9 +129,7 @@ class WorkflowAppGenerateTaskPipeline:
) )
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict self._workflow_features_dict = workflow.features_dict
self._task_state = WorkflowTaskState()
self._workflow_run_id = "" self._workflow_run_id = ""
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@ -543,7 +540,6 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher: if tts_publisher:
tts_publisher.publish(queue_message) tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response( yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector delta_text, from_variable_selector=event.from_variable_selector
) )

View File

@ -1,4 +1,4 @@
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from datetime import datetime from datetime import datetime
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import Any, Optional from typing import Any, Optional
@ -6,6 +6,7 @@ from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
@ -283,7 +284,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
""" """
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict] retriever_resources: Sequence[RetrievalSourceMetadata]
in_iteration_id: Optional[str] = None in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration""" """iteration id if node is in iteration"""
in_loop_id: Optional[str] = None in_loop_id: Optional[str] = None

View File

@ -2,20 +2,37 @@ from collections.abc import Mapping, Sequence
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class AnnotationReplyAccount(BaseModel):
id: str
name: str
class AnnotationReply(BaseModel):
id: str
account: AnnotationReplyAccount
class TaskStateMetadata(BaseModel):
annotation_reply: AnnotationReply | None = None
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list)
usage: LLMUsage | None = None
class TaskState(BaseModel): class TaskState(BaseModel):
""" """
TaskState entity TaskState entity
""" """
metadata: dict = {} metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata)
class EasyUITaskState(TaskState): class EasyUITaskState(TaskState):

View File

@ -1,4 +1,3 @@
import json
import logging import logging
import time import time
from collections.abc import Generator from collections.abc import Generator
@ -43,7 +42,7 @@ from core.app.entities.task_entities import (
StreamResponse, StreamResponse,
) )
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@ -51,7 +50,6 @@ from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
) )
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
@ -63,7 +61,7 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage): class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
""" """
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application. EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
""" """
@ -104,6 +102,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
) )
) )
self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity,
task_state=self._task_state,
)
self._conversation_name_generate_thread: Optional[Thread] = None self._conversation_name_generate_thread: Optional[Thread] = None
def process( def process(
@ -115,7 +118,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
]: ]:
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread # start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name( self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query or "" conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
) )
@ -136,9 +139,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if isinstance(stream_response, ErrorStreamResponse): if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse): elif isinstance(stream_response, MessageEndStreamResponse):
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} extras = {"usage": self._task_state.llm_result.usage.model_dump()}
if self._task_state.metadata: if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION.value: if self._conversation_mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse( response = CompletionAppBlockingResponse(
@ -277,7 +280,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
) )
if output_moderation_answer: if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer self._task_state.llm_result.message.content = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer) yield self._message_cycle_manager.message_replace_to_stream_response(
answer=output_moderation_answer
)
with Session(db.engine) as session: with Session(db.engine) as session:
# Save message # Save message
@ -286,9 +291,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_end_resp = self._message_end_to_stream_response() message_end_resp = self._message_end_to_stream_response()
yield message_end_resp yield message_end_resp
elif isinstance(event, QueueRetrieverResourcesEvent): elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event) self._message_cycle_manager.handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent): elif isinstance(event, QueueAnnotationReplyEvent):
annotation = self._handle_annotation_reply(event) annotation = self._message_cycle_manager.handle_annotation_reply(event)
if annotation: if annotation:
self._task_state.llm_result.message.content = annotation.content self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent): elif isinstance(event, QueueAgentThoughtEvent):
@ -296,7 +301,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if agent_thought_response is not None: if agent_thought_response is not None:
yield agent_thought_response yield agent_thought_response
elif isinstance(event, QueueMessageFileEvent): elif isinstance(event, QueueMessageFileEvent):
response = self._message_file_to_stream_response(event) response = self._message_cycle_manager.message_file_to_stream_response(event)
if response: if response:
yield response yield response
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent): elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
@ -318,7 +323,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._task_state.llm_result.message.content = current_content self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent): if isinstance(event, QueueLLMChunkEvent):
yield self._message_to_stream_response( yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text), answer=cast(str, delta_text),
message_id=self._message_id, message_id=self._message_id,
) )
@ -328,7 +333,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_id=self._message_id, message_id=self._message_id,
) )
elif isinstance(event, QueueMessageReplaceEvent): elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_replace_to_stream_response(answer=event.text) yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent): elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response() yield self._ping_stream_response()
else: else:
@ -372,9 +377,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message.provider_response_latency = time.perf_counter() - self._start_at message.provider_response_latency = time.perf_counter() - self._start_at
message.total_price = usage.total_price message.total_price = usage.total_price
message.currency = usage.currency message.currency = usage.currency
message.message_metadata = ( message.message_metadata = self._task_state.metadata.model_dump_json()
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
@ -423,16 +426,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
Message end to stream response. Message end to stream response.
:return: :return:
""" """
self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage) self._task_state.metadata.usage = self._task_state.llm_result.usage
metadata_dict = self._task_state.metadata.model_dump()
extras = {}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata
return MessageEndStreamResponse( return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
id=self._message_id, id=self._message_id,
metadata=extras.get("metadata", {}), metadata=metadata_dict,
) )
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:

View File

@ -17,6 +17,8 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent, QueueRetrieverResourcesEvent,
) )
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AnnotationReply,
AnnotationReplyAccount,
EasyUITaskState, EasyUITaskState,
MessageFileStreamResponse, MessageFileStreamResponse,
MessageReplaceStreamResponse, MessageReplaceStreamResponse,
@ -30,7 +32,7 @@ from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
class MessageCycleManage: class MessageCycleManager:
def __init__( def __init__(
self, self,
*, *,
@ -45,7 +47,7 @@ class MessageCycleManage:
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._task_state = task_state self._task_state = task_state
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
""" """
Generate conversation name. Generate conversation name.
:param conversation_id: conversation id :param conversation_id: conversation id
@ -102,7 +104,7 @@ class MessageCycleManage:
db.session.commit() db.session.commit()
db.session.close() db.session.close()
def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
""" """
Handle annotation reply. Handle annotation reply.
:param event: event :param event: event
@ -111,25 +113,28 @@ class MessageCycleManage:
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation: if annotation:
account = annotation.account account = annotation.account
self._task_state.metadata["annotation_reply"] = { self._task_state.metadata.annotation_reply = AnnotationReply(
"id": annotation.id, id=annotation.id,
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"}, account=AnnotationReplyAccount(
} id=annotation.account_id,
name=account.name if account else "Dify user",
),
)
return annotation return annotation
return None return None
def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
""" """
Handle retriever resources. Handle retriever resources.
:param event: event :param event: event
:return: :return:
""" """
if self._application_generate_entity.app_config.additional_features.show_retrieve_source: if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata["retriever_resources"] = event.retriever_resources self._task_state.metadata.retriever_resources = event.retriever_resources
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
""" """
Message file to stream response. Message file to stream response.
:param event: event :param event: event
@ -166,7 +171,7 @@ class MessageCycleManage:
return None return None
def _message_to_stream_response( def message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
) -> MessageStreamResponse: ) -> MessageStreamResponse:
""" """
@ -182,7 +187,7 @@ class MessageCycleManage:
from_variable_selector=from_variable_selector, from_variable_selector=from_variable_selector,
) )
def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
""" """
Message replace to stream response. Message replace to stream response.
:param answer: answer :param answer: answer

View File

@ -1,8 +1,10 @@
import logging import logging
from collections.abc import Sequence
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
@ -85,7 +87,8 @@ class DatasetIndexToolCallbackHandler:
db.session.commit() db.session.commit()
def return_retriever_resource_info(self, resource: list): # TODO(-LAN-): Improve type check
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
"""Handle return_retriever_resource_info.""" """Handle return_retriever_resource_info."""
self._queue_manager.publish( self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER

View File

@ -0,0 +1,23 @@
from typing import Any, Optional
from pydantic import BaseModel
class RetrievalSourceMetadata(BaseModel):
    """Citation metadata for a single retrieved source backing a model answer.

    Every field is optional because different retrieval paths (internal Dify
    datasets vs. external knowledge providers) fill in different subsets.
    """

    # 1-based rank among all retrieved sources, assigned after sorting by score.
    position: int | None = None
    # Identifiers of the dataset the source came from.
    dataset_id: str | None = None
    dataset_name: str | None = None
    # Identifiers of the document within that dataset (external sources may
    # fall back to the title when no real document id exists).
    document_id: str | None = None
    document_name: str | None = None
    # Origin of the document, e.g. "external" or the dataset's own source type.
    data_source_type: str | None = None
    # Segment (chunk) of the document that was matched; internal datasets only.
    segment_id: str | None = None
    # Which invocation surface triggered retrieval (e.g. "dev").
    retriever_from: str | None = None
    # Relevance score used to order sources.
    score: float | None = None
    # Debug-only segment statistics, populated for dev invocations.
    hit_count: int | None = None
    word_count: int | None = None
    segment_position: int | None = None
    index_node_hash: str | None = None
    # Text content of the matched chunk (may be a "question:... answer:..." pair).
    content: str | None = None
    # Page number within the source document — presumably for paginated
    # formats such as PDF; not set by the retrieval paths visible here.
    page: int | None = None
    # Arbitrary document-level metadata copied from the stored document.
    doc_metadata: dict[str, Any] | None = None
    # Human-readable title; external providers may supply it in place of a name.
    title: str | None = None

View File

@ -35,6 +35,7 @@ from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.constant.index_type import IndexType
@ -198,21 +199,21 @@ class DatasetRetrieval:
dify_documents = [item for item in all_documents if item.provider == "dify"] dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"] external_documents = [item for item in all_documents if item.provider == "external"]
document_context_list = [] document_context_list: list[DocumentContext] = []
retrieval_resource_list = [] retrieval_resource_list: list[RetrievalSourceMetadata] = []
# deal with external documents # deal with external documents
for item in external_documents: for item in external_documents:
document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score"))) document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
source = { source = RetrievalSourceMetadata(
"dataset_id": item.metadata.get("dataset_id"), dataset_id=item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"), dataset_name=item.metadata.get("dataset_name"),
"document_id": item.metadata.get("document_id") or item.metadata.get("title"), document_id=item.metadata.get("document_id") or item.metadata.get("title"),
"document_name": item.metadata.get("title"), document_name=item.metadata.get("title"),
"data_source_type": "external", data_source_type="external",
"retriever_from": invoke_from.to_source(), retriever_from=invoke_from.to_source(),
"score": item.metadata.get("score"), score=item.metadata.get("score"),
"content": item.page_content, content=item.page_content,
} )
retrieval_resource_list.append(source) retrieval_resource_list.append(source)
# deal with dify documents # deal with dify documents
if dify_documents: if dify_documents:
@ -248,32 +249,32 @@ class DatasetRetrieval:
.first() .first()
) )
if dataset and document: if dataset and document:
source = { source = RetrievalSourceMetadata(
"dataset_id": dataset.id, dataset_id=dataset.id,
"dataset_name": dataset.name, dataset_name=dataset.name,
"document_id": document.id, document_id=document.id,
"document_name": document.name, document_name=document.name,
"data_source_type": document.data_source_type, data_source_type=document.data_source_type,
"segment_id": segment.id, segment_id=segment.id,
"retriever_from": invoke_from.to_source(), retriever_from=invoke_from.to_source(),
"score": record.score or 0.0, score=record.score or 0.0,
"doc_metadata": document.doc_metadata, doc_metadata=document.doc_metadata,
} )
if invoke_from.to_source() == "dev": if invoke_from.to_source() == "dev":
source["hit_count"] = segment.hit_count source.hit_count = segment.hit_count
source["word_count"] = segment.word_count source.word_count = segment.word_count
source["segment_position"] = segment.position source.segment_position = segment.position
source["index_node_hash"] = segment.index_node_hash source.index_node_hash = segment.index_node_hash
if segment.answer: if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else: else:
source["content"] = segment.content source.content = segment.content
retrieval_resource_list.append(source) retrieval_resource_list.append(source)
if hit_callback and retrieval_resource_list: if hit_callback and retrieval_resource_list:
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True) retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
for position, item in enumerate(retrieval_resource_list, start=1): for position, item in enumerate(retrieval_resource_list, start=1):
item["position"] = position item.position = position
hit_callback.return_retriever_resource_info(retrieval_resource_list) hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list: if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)

View File

@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.models.document import Document as RagDocument from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -107,7 +108,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
else: else:
document_context_list.append(segment.get_sign_content()) document_context_list.append(segment.get_sign_content())
if self.return_resource: if self.return_resource:
context_list = [] context_list: list[RetrievalSourceMetadata] = []
resource_number = 1 resource_number = 1
for segment in sorted_segments: for segment in sorted_segments:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
@ -121,28 +122,28 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
.first() .first()
) )
if dataset and document: if dataset and document:
source = { source = RetrievalSourceMetadata(
"position": resource_number, position=resource_number,
"dataset_id": dataset.id, dataset_id=dataset.id,
"dataset_name": dataset.name, dataset_name=dataset.name,
"document_id": document.id, document_id=document.id,
"document_name": document.name, document_name=document.name,
"data_source_type": document.data_source_type, data_source_type=document.data_source_type,
"segment_id": segment.id, segment_id=segment.id,
"retriever_from": self.retriever_from, retriever_from=self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None), score=document_score_list.get(segment.index_node_id, None),
"doc_metadata": document.doc_metadata, doc_metadata=document.doc_metadata,
} )
if self.retriever_from == "dev": if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count source.hit_count = segment.hit_count
source["word_count"] = segment.word_count source.word_count = segment.word_count
source["segment_position"] = segment.position source.segment_position = segment.position
source["index_node_hash"] = segment.index_node_hash source.index_node_hash = segment.index_node_hash
if segment.answer: if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else: else:
source["content"] = segment.content source.content = segment.content
context_list.append(source) context_list.append(source)
resource_number += 1 resource_number += 1

View File

@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document as RetrievalDocument from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@ -14,7 +15,7 @@ from models.dataset import Dataset
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = { default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False, "reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -79,7 +80,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
else: else:
document_ids_filter = None document_ids_filter = None
if dataset.provider == "external": if dataset.provider == "external":
results = [] results: list[RetrievalDocument] = []
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
dataset_id=dataset.id, dataset_id=dataset.id,
@ -100,21 +101,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
document.metadata["dataset_name"] = dataset.name document.metadata["dataset_name"] = dataset.name
results.append(document) results.append(document)
# deal with external documents # deal with external documents
context_list = [] context_list: list[RetrievalSourceMetadata] = []
for position, item in enumerate(results, start=1): for position, item in enumerate(results, start=1):
if item.metadata is not None: if item.metadata is not None:
source = { source = RetrievalSourceMetadata(
"position": position, position=position,
"dataset_id": item.metadata.get("dataset_id"), dataset_id=item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"), dataset_name=item.metadata.get("dataset_name"),
"document_id": item.metadata.get("document_id") or item.metadata.get("title"), document_id=item.metadata.get("document_id") or item.metadata.get("title"),
"document_name": item.metadata.get("title"), document_name=item.metadata.get("title"),
"data_source_type": "external", data_source_type="external",
"retriever_from": self.retriever_from, retriever_from=self.retriever_from,
"score": item.metadata.get("score"), score=item.metadata.get("score"),
"title": item.metadata.get("title"), title=item.metadata.get("title"),
"content": item.page_content, content=item.page_content,
} )
context_list.append(source) context_list.append(source)
for hit_callback in self.hit_callbacks: for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list) hit_callback.return_retriever_resource_info(context_list)
@ -125,7 +126,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
return "" return ""
# get retrieval model , if the model is not setting , using default # get retrieval model , if the model is not setting , using default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
retrieval_resource_list = [] retrieval_resource_list: list[RetrievalSourceMetadata] = []
if dataset.indexing_technique == "economy": if dataset.indexing_technique == "economy":
# use keyword table query # use keyword table query
documents = RetrievalService.retrieve( documents = RetrievalService.retrieve(
@ -163,7 +164,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
for item in documents: for item in documents:
if item.metadata is not None and item.metadata.get("score"): if item.metadata is not None and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = [] document_context_list: list[DocumentContext] = []
records = RetrievalService.format_retrieval_documents(documents) records = RetrievalService.format_retrieval_documents(documents)
if records: if records:
for record in records: for record in records:
@ -197,37 +198,37 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
.first() .first()
) )
if dataset and document: if dataset and document:
source = { source = RetrievalSourceMetadata(
"dataset_id": dataset.id, dataset_id=dataset.id,
"dataset_name": dataset.name, dataset_name=dataset.name,
"document_id": document.id, # type: ignore document_id=document.id, # type: ignore
"document_name": document.name, # type: ignore document_name=document.name, # type: ignore
"data_source_type": document.data_source_type, # type: ignore data_source_type=document.data_source_type, # type: ignore
"segment_id": segment.id, segment_id=segment.id,
"retriever_from": self.retriever_from, retriever_from=self.retriever_from,
"score": record.score or 0.0, score=record.score or 0.0,
"doc_metadata": document.doc_metadata, # type: ignore doc_metadata=document.doc_metadata, # type: ignore
} )
if self.retriever_from == "dev": if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count source.hit_count = segment.hit_count
source["word_count"] = segment.word_count source.word_count = segment.word_count
source["segment_position"] = segment.position source.segment_position = segment.position
source["index_node_hash"] = segment.index_node_hash source.index_node_hash = segment.index_node_hash
if segment.answer: if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else: else:
source["content"] = segment.content source.content = segment.content
retrieval_resource_list.append(source) retrieval_resource_list.append(source)
if self.return_resource and retrieval_resource_list: if self.return_resource and retrieval_resource_list:
retrieval_resource_list = sorted( retrieval_resource_list = sorted(
retrieval_resource_list, retrieval_resource_list,
key=lambda x: x.get("score") or 0.0, key=lambda x: x.score or 0.0,
reverse=True, reverse=True,
) )
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
item["position"] = position # type: ignore item.position = position # type: ignore
for hit_callback in self.hit_callbacks: for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(retrieval_resource_list) hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list: if document_context_list:

View File

@ -1,9 +1,10 @@
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from datetime import datetime from datetime import datetime
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent):
class NodeRunRetrieverResourceEvent(BaseNodeEvent): class NodeRunRetrieverResourceEvent(BaseNodeEvent):
retriever_resources: list[dict] = Field(..., description="retriever resources") retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context") context: str = Field(..., description="context")

View File

@ -1,8 +1,10 @@
from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -17,7 +19,7 @@ class RunStreamChunkEvent(BaseModel):
class RunRetrieverResourceEvent(BaseModel): class RunRetrieverResourceEvent(BaseModel):
retriever_resources: list[dict] = Field(..., description="retriever resources") retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context") context: str = Field(..., description="context")

View File

@ -43,6 +43,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ModelProviderID from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.variables import ( from core.variables import (
ArrayAnySegment, ArrayAnySegment,
ArrayFileSegment, ArrayFileSegment,
@ -474,7 +475,7 @@ class LLMNode(BaseNode[LLMNodeData]):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
elif isinstance(context_value_variable, ArraySegment): elif isinstance(context_value_variable, ArraySegment):
context_str = "" context_str = ""
original_retriever_resource = [] original_retriever_resource: list[RetrievalSourceMetadata] = []
for item in context_value_variable.value: for item in context_value_variable.value:
if isinstance(item, str): if isinstance(item, str):
context_str += item + "\n" context_str += item + "\n"
@ -492,7 +493,7 @@ class LLMNode(BaseNode[LLMNodeData]):
retriever_resources=original_retriever_resource, context=context_str.strip() retriever_resources=original_retriever_resource, context=context_str.strip()
) )
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: def _convert_to_original_retriever_resource(self, context_dict: dict):
if ( if (
"metadata" in context_dict "metadata" in context_dict
and "_source" in context_dict["metadata"] and "_source" in context_dict["metadata"]
@ -500,24 +501,24 @@ class LLMNode(BaseNode[LLMNodeData]):
): ):
metadata = context_dict.get("metadata", {}) metadata = context_dict.get("metadata", {})
source = { source = RetrievalSourceMetadata(
"position": metadata.get("position"), position=metadata.get("position"),
"dataset_id": metadata.get("dataset_id"), dataset_id=metadata.get("dataset_id"),
"dataset_name": metadata.get("dataset_name"), dataset_name=metadata.get("dataset_name"),
"document_id": metadata.get("document_id"), document_id=metadata.get("document_id"),
"document_name": metadata.get("document_name"), document_name=metadata.get("document_name"),
"data_source_type": metadata.get("data_source_type"), data_source_type=metadata.get("data_source_type"),
"segment_id": metadata.get("segment_id"), segment_id=metadata.get("segment_id"),
"retriever_from": metadata.get("retriever_from"), retriever_from=metadata.get("retriever_from"),
"score": metadata.get("score"), score=metadata.get("score"),
"hit_count": metadata.get("segment_hit_count"), hit_count=metadata.get("segment_hit_count"),
"word_count": metadata.get("segment_word_count"), word_count=metadata.get("segment_word_count"),
"segment_position": metadata.get("segment_position"), segment_position=metadata.get("segment_position"),
"index_node_hash": metadata.get("segment_index_node_hash"), index_node_hash=metadata.get("segment_index_node_hash"),
"content": context_dict.get("content"), content=context_dict.get("content"),
"page": metadata.get("page"), page=metadata.get("page"),
"doc_metadata": metadata.get("doc_metadata"), doc_metadata=metadata.get("doc_metadata"),
} )
return source return source