Refactor/message cycle manage and knowledge retrieval (#20460)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-05-30 14:36:44 +08:00 committed by GitHub
parent 5a991295e0
commit a6ea15e63c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 222 additions and 181 deletions

View File

@ -1,4 +1,3 @@
import json
import logging
import time
from collections.abc import Generator, Mapping
@ -57,10 +56,9 @@ from core.app.entities.task_entities import (
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey
@ -141,7 +139,7 @@ class AdvancedChatAppGenerateTaskPipeline:
)
self._task_state = WorkflowTaskState()
self._message_cycle_manager = MessageCycleManage(
self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity, task_state=self._task_state
)
@ -162,7 +160,7 @@ class AdvancedChatAppGenerateTaskPipeline:
:return:
"""
# start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)
@ -605,22 +603,18 @@ class AdvancedChatAppGenerateTaskPipeline:
yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
self._message_cycle_manager._handle_retriever_resources(event)
self._message_cycle_manager.handle_retriever_resources(event)
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message.message_metadata = self._task_state.metadata.model_dump_json()
session.commit()
elif isinstance(event, QueueAnnotationReplyEvent):
self._message_cycle_manager._handle_annotation_reply(event)
self._message_cycle_manager.handle_annotation_reply(event)
with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session)
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message.message_metadata = self._task_state.metadata.model_dump_json()
session.commit()
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
@ -637,12 +631,12 @@ class AdvancedChatAppGenerateTaskPipeline:
tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
yield self._message_cycle_manager._message_to_stream_response(
yield self._message_cycle_manager.message_to_stream_response(
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_cycle_manager._message_replace_to_stream_response(
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=event.text, reason=event.reason
)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
@ -654,7 +648,7 @@ class AdvancedChatAppGenerateTaskPipeline:
)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_cycle_manager._message_replace_to_stream_response(
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=output_moderation_answer,
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
)
@ -683,9 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline:
message = self._get_message(session=session)
message.answer = self._task_state.answer
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message.message_metadata = self._task_state.metadata.model_dump_json()
message_files = [
MessageFile(
message_id=message.id,
@ -713,9 +705,9 @@ class AdvancedChatAppGenerateTaskPipeline:
message.answer_price_unit = usage.completion_price_unit
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.metadata["usage"] = jsonable_encoder(usage)
self._task_state.metadata.usage = usage
else:
self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage())
self._task_state.metadata.usage = LLMUsage.empty_usage()
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
@ -726,18 +718,16 @@ class AdvancedChatAppGenerateTaskPipeline:
Message end to stream response.
:return:
"""
extras = {}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.copy()
extras = self._task_state.metadata.model_dump()
if "annotation_reply" in extras["metadata"]:
del extras["metadata"]["annotation_reply"]
if self._task_state.metadata.annotation_reply:
del extras["annotation_reply"]
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
files=self._recorded_files,
metadata=extras.get("metadata", {}),
metadata=extras,
)
def _handle_output_moderation_chunk(self, text: str) -> bool:

View File

@ -50,7 +50,6 @@ from core.app.entities.task_entities import (
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
@ -130,9 +129,7 @@ class WorkflowAppGenerateTaskPipeline:
)
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._task_state = WorkflowTaskState()
self._workflow_run_id = ""
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@ -543,7 +540,6 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)

View File

@ -1,4 +1,4 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Optional
@ -6,6 +6,7 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
@ -283,7 +284,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict]
retriever_resources: Sequence[RetrievalSourceMetadata]
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None

View File

@ -2,20 +2,37 @@ from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class AnnotationReplyAccount(BaseModel):
id: str
name: str
class AnnotationReply(BaseModel):
id: str
account: AnnotationReplyAccount
class TaskStateMetadata(BaseModel):
annotation_reply: AnnotationReply | None = None
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list)
usage: LLMUsage | None = None
class TaskState(BaseModel):
"""
TaskState entity
"""
metadata: dict = {}
metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata)
class EasyUITaskState(TaskState):

View File

@ -1,4 +1,3 @@
import json
import logging
import time
from collections.abc import Generator
@ -43,7 +42,7 @@ from core.app.entities.task_entities import (
StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@ -51,7 +50,6 @@ from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@ -63,7 +61,7 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage):
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
"""
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
@ -104,6 +102,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
)
)
self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity,
task_state=self._task_state,
)
self._conversation_name_generate_thread: Optional[Thread] = None
def process(
@ -115,7 +118,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
]:
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
)
@ -136,9 +139,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata
extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse(
@ -277,7 +280,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=output_moderation_answer
)
with Session(db.engine) as session:
# Save message
@ -286,9 +291,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_end_resp = self._message_end_to_stream_response()
yield message_end_resp
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
self._message_cycle_manager.handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent):
annotation = self._handle_annotation_reply(event)
annotation = self._message_cycle_manager.handle_annotation_reply(event)
if annotation:
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
@ -296,7 +301,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if agent_thought_response is not None:
yield agent_thought_response
elif isinstance(event, QueueMessageFileEvent):
response = self._message_file_to_stream_response(event)
response = self._message_cycle_manager.message_file_to_stream_response(event)
if response:
yield response
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
@ -318,7 +323,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
yield self._message_to_stream_response(
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
)
@ -328,7 +333,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_id=self._message_id,
)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_replace_to_stream_response(answer=event.text)
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
else:
@ -372,9 +377,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message.provider_response_latency = time.perf_counter() - self._start_at
message.total_price = usage.total_price
message.currency = usage.currency
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message.message_metadata = self._task_state.metadata.model_dump_json()
if trace_manager:
trace_manager.add_trace_task(
@ -423,16 +426,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
Message end to stream response.
:return:
"""
self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)
extras = {}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata
self._task_state.metadata.usage = self._task_state.llm_result.usage
metadata_dict = self._task_state.metadata.model_dump()
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
metadata=extras.get("metadata", {}),
metadata=metadata_dict,
)
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:

View File

@ -17,6 +17,8 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent,
)
from core.app.entities.task_entities import (
AnnotationReply,
AnnotationReplyAccount,
EasyUITaskState,
MessageFileStreamResponse,
MessageReplaceStreamResponse,
@ -30,7 +32,7 @@ from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
from services.annotation_service import AppAnnotationService
class MessageCycleManage:
class MessageCycleManager:
def __init__(
self,
*,
@ -45,7 +47,7 @@ class MessageCycleManage:
self._application_generate_entity = application_generate_entity
self._task_state = task_state
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
"""
Generate conversation name.
:param conversation_id: conversation id
@ -102,7 +104,7 @@ class MessageCycleManage:
db.session.commit()
db.session.close()
def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
"""
Handle annotation reply.
:param event: event
@ -111,25 +113,28 @@ class MessageCycleManage:
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation:
account = annotation.account
self._task_state.metadata["annotation_reply"] = {
"id": annotation.id,
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
}
self._task_state.metadata.annotation_reply = AnnotationReply(
id=annotation.id,
account=AnnotationReplyAccount(
id=annotation.account_id,
name=account.name if account else "Dify user",
),
)
return annotation
return None
def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
"""
Handle retriever resources.
:param event: event
:return:
"""
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata["retriever_resources"] = event.retriever_resources
self._task_state.metadata.retriever_resources = event.retriever_resources
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
"""
Message file to stream response.
:param event: event
@ -166,7 +171,7 @@ class MessageCycleManage:
return None
def _message_to_stream_response(
def message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
) -> MessageStreamResponse:
"""
@ -182,7 +187,7 @@ class MessageCycleManage:
from_variable_selector=from_variable_selector,
)
def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
"""
Message replace to stream response.
:param answer: answer

View File

@ -1,8 +1,10 @@
import logging
from collections.abc import Sequence
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from extensions.ext_database import db
@ -85,7 +87,8 @@ class DatasetIndexToolCallbackHandler:
db.session.commit()
def return_retriever_resource_info(self, resource: list):
# TODO(-LAN-): Improve type check
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
"""Handle return_retriever_resource_info."""
self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER

View File

@ -0,0 +1,23 @@
from typing import Any, Optional
from pydantic import BaseModel
class RetrievalSourceMetadata(BaseModel):
position: Optional[int] = None
dataset_id: Optional[str] = None
dataset_name: Optional[str] = None
document_id: Optional[str] = None
document_name: Optional[str] = None
data_source_type: Optional[str] = None
segment_id: Optional[str] = None
retriever_from: Optional[str] = None
score: Optional[float] = None
hit_count: Optional[int] = None
word_count: Optional[int] = None
segment_position: Optional[int] = None
index_node_hash: Optional[str] = None
content: Optional[str] = None
page: Optional[int] = None
doc_metadata: Optional[dict[str, Any]] = None
title: Optional[str] = None

View File

@ -35,6 +35,7 @@ from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
@ -198,21 +199,21 @@ class DatasetRetrieval:
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
document_context_list = []
retrieval_resource_list = []
document_context_list: list[DocumentContext] = []
retrieval_resource_list: list[RetrievalSourceMetadata] = []
# deal with external documents
for item in external_documents:
document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
source = {
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": invoke_from.to_source(),
"score": item.metadata.get("score"),
"content": item.page_content,
}
source = RetrievalSourceMetadata(
dataset_id=item.metadata.get("dataset_id"),
dataset_name=item.metadata.get("dataset_name"),
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
document_name=item.metadata.get("title"),
data_source_type="external",
retriever_from=invoke_from.to_source(),
score=item.metadata.get("score"),
content=item.page_content,
)
retrieval_resource_list.append(source)
# deal with dify documents
if dify_documents:
@ -248,32 +249,32 @@ class DatasetRetrieval:
.first()
)
if dataset and document:
source = {
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": invoke_from.to_source(),
"score": record.score or 0.0,
"doc_metadata": document.doc_metadata,
}
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from=invoke_from.to_source(),
score=record.score or 0.0,
doc_metadata=document.doc_metadata,
)
if invoke_from.to_source() == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
source.hit_count = segment.hit_count
source.word_count = segment.word_count
source.segment_position = segment.position
source.index_node_hash = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
source.content = segment.content
retrieval_resource_list.append(source)
if hit_callback and retrieval_resource_list:
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True)
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
for position, item in enumerate(retrieval_resource_list, start=1):
item["position"] = position
item.position = position
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)

View File

@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -107,7 +108,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
else:
document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list = []
context_list: list[RetrievalSourceMetadata] = []
resource_number = 1
for segment in sorted_segments:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
@ -121,28 +122,28 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
.first()
)
if dataset and document:
source = {
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None),
"doc_metadata": document.doc_metadata,
}
source = RetrievalSourceMetadata(
position=resource_number,
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from=self.retriever_from,
score=document_score_list.get(segment.index_node_id, None),
doc_metadata=document.doc_metadata,
)
if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
source.hit_count = segment.hit_count
source.word_count = segment.word_count
source.segment_position = segment.position
source.index_node_hash = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
source.content = segment.content
context_list.append(source)
resource_number += 1

View File

@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@ -14,7 +15,7 @@ from models.dataset import Dataset
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -79,7 +80,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
else:
document_ids_filter = None
if dataset.provider == "external":
results = []
results: list[RetrievalDocument] = []
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
@ -100,21 +101,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
document.metadata["dataset_name"] = dataset.name
results.append(document)
# deal with external documents
context_list = []
context_list: list[RetrievalSourceMetadata] = []
for position, item in enumerate(results, start=1):
if item.metadata is not None:
source = {
"position": position,
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": self.retriever_from,
"score": item.metadata.get("score"),
"title": item.metadata.get("title"),
"content": item.page_content,
}
source = RetrievalSourceMetadata(
position=position,
dataset_id=item.metadata.get("dataset_id"),
dataset_name=item.metadata.get("dataset_name"),
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
document_name=item.metadata.get("title"),
data_source_type="external",
retriever_from=self.retriever_from,
score=item.metadata.get("score"),
title=item.metadata.get("title"),
content=item.page_content,
)
context_list.append(source)
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
@ -125,7 +126,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
return ""
# get retrieval model , if the model is not setting , using default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
retrieval_resource_list = []
retrieval_resource_list: list[RetrievalSourceMetadata] = []
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
@ -163,7 +164,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
for item in documents:
if item.metadata is not None and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
document_context_list: list[DocumentContext] = []
records = RetrievalService.format_retrieval_documents(documents)
if records:
for record in records:
@ -197,37 +198,37 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
.first()
)
if dataset and document:
source = {
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id, # type: ignore
"document_name": document.name, # type: ignore
"data_source_type": document.data_source_type, # type: ignore
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": record.score or 0.0,
"doc_metadata": document.doc_metadata, # type: ignore
}
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id, # type: ignore
document_name=document.name, # type: ignore
data_source_type=document.data_source_type, # type: ignore
segment_id=segment.id,
retriever_from=self.retriever_from,
score=record.score or 0.0,
doc_metadata=document.doc_metadata, # type: ignore
)
if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
source.hit_count = segment.hit_count
source.word_count = segment.word_count
source.segment_position = segment.position
source.index_node_hash = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
source.content = segment.content
retrieval_resource_list.append(source)
if self.return_resource and retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list,
key=lambda x: x.get("score") or 0.0,
key=lambda x: x.score or 0.0,
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
item["position"] = position # type: ignore
item.position = position # type: ignore
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:

View File

@ -1,9 +1,10 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel, Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType
@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent):
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
retriever_resources: list[dict] = Field(..., description="retriever resources")
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")

View File

@ -1,8 +1,10 @@
from collections.abc import Sequence
from datetime import datetime
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -17,7 +19,7 @@ class RunStreamChunkEvent(BaseModel):
class RunRetrieverResourceEvent(BaseModel):
retriever_resources: list[dict] = Field(..., description="retriever resources")
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")

View File

@ -43,6 +43,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.variables import (
ArrayAnySegment,
ArrayFileSegment,
@ -474,7 +475,7 @@ class LLMNode(BaseNode[LLMNodeData]):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
elif isinstance(context_value_variable, ArraySegment):
context_str = ""
original_retriever_resource = []
original_retriever_resource: list[RetrievalSourceMetadata] = []
for item in context_value_variable.value:
if isinstance(item, str):
context_str += item + "\n"
@ -492,7 +493,7 @@ class LLMNode(BaseNode[LLMNodeData]):
retriever_resources=original_retriever_resource, context=context_str.strip()
)
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
def _convert_to_original_retriever_resource(self, context_dict: dict):
if (
"metadata" in context_dict
and "_source" in context_dict["metadata"]
@ -500,24 +501,24 @@ class LLMNode(BaseNode[LLMNodeData]):
):
metadata = context_dict.get("metadata", {})
source = {
"position": metadata.get("position"),
"dataset_id": metadata.get("dataset_id"),
"dataset_name": metadata.get("dataset_name"),
"document_id": metadata.get("document_id"),
"document_name": metadata.get("document_name"),
"data_source_type": metadata.get("data_source_type"),
"segment_id": metadata.get("segment_id"),
"retriever_from": metadata.get("retriever_from"),
"score": metadata.get("score"),
"hit_count": metadata.get("segment_hit_count"),
"word_count": metadata.get("segment_word_count"),
"segment_position": metadata.get("segment_position"),
"index_node_hash": metadata.get("segment_index_node_hash"),
"content": context_dict.get("content"),
"page": metadata.get("page"),
"doc_metadata": metadata.get("doc_metadata"),
}
source = RetrievalSourceMetadata(
position=metadata.get("position"),
dataset_id=metadata.get("dataset_id"),
dataset_name=metadata.get("dataset_name"),
document_id=metadata.get("document_id"),
document_name=metadata.get("document_name"),
data_source_type=metadata.get("data_source_type"),
segment_id=metadata.get("segment_id"),
retriever_from=metadata.get("retriever_from"),
score=metadata.get("score"),
hit_count=metadata.get("segment_hit_count"),
word_count=metadata.get("segment_word_count"),
segment_position=metadata.get("segment_position"),
index_node_hash=metadata.get("segment_index_node_hash"),
content=context_dict.get("content"),
page=metadata.get("page"),
doc_metadata=metadata.get("doc_metadata"),
)
return source