mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-07-24 05:54:26 +08:00
Refactor/message cycle manage and knowledge retrieval (#20460)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
5a991295e0
commit
a6ea15e63c
@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
@ -57,10 +56,9 @@ from core.app.entities.task_entities import (
|
|||||||
WorkflowTaskState,
|
WorkflowTaskState,
|
||||||
)
|
)
|
||||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||||
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
|
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
@ -141,7 +139,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._task_state = WorkflowTaskState()
|
self._task_state = WorkflowTaskState()
|
||||||
self._message_cycle_manager = MessageCycleManage(
|
self._message_cycle_manager = MessageCycleManager(
|
||||||
application_generate_entity=application_generate_entity, task_state=self._task_state
|
application_generate_entity=application_generate_entity, task_state=self._task_state
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -162,7 +160,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# start generate conversation name thread
|
# start generate conversation name thread
|
||||||
self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
|
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
|
||||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
conversation_id=self._conversation_id, query=self._application_generate_entity.query
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -605,22 +603,18 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
yield self._message_end_to_stream_response()
|
yield self._message_end_to_stream_response()
|
||||||
break
|
break
|
||||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||||
self._message_cycle_manager._handle_retriever_resources(event)
|
self._message_cycle_manager.handle_retriever_resources(event)
|
||||||
|
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
message = self._get_message(session=session)
|
message = self._get_message(session=session)
|
||||||
message.message_metadata = (
|
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
|
||||||
)
|
|
||||||
session.commit()
|
session.commit()
|
||||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||||
self._message_cycle_manager._handle_annotation_reply(event)
|
self._message_cycle_manager.handle_annotation_reply(event)
|
||||||
|
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
message = self._get_message(session=session)
|
message = self._get_message(session=session)
|
||||||
message.message_metadata = (
|
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
|
||||||
)
|
|
||||||
session.commit()
|
session.commit()
|
||||||
elif isinstance(event, QueueTextChunkEvent):
|
elif isinstance(event, QueueTextChunkEvent):
|
||||||
delta_text = event.text
|
delta_text = event.text
|
||||||
@ -637,12 +631,12 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
tts_publisher.publish(queue_message)
|
tts_publisher.publish(queue_message)
|
||||||
|
|
||||||
self._task_state.answer += delta_text
|
self._task_state.answer += delta_text
|
||||||
yield self._message_cycle_manager._message_to_stream_response(
|
yield self._message_cycle_manager.message_to_stream_response(
|
||||||
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
|
||||||
)
|
)
|
||||||
elif isinstance(event, QueueMessageReplaceEvent):
|
elif isinstance(event, QueueMessageReplaceEvent):
|
||||||
# published by moderation
|
# published by moderation
|
||||||
yield self._message_cycle_manager._message_replace_to_stream_response(
|
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||||
answer=event.text, reason=event.reason
|
answer=event.text, reason=event.reason
|
||||||
)
|
)
|
||||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||||
@ -654,7 +648,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
)
|
)
|
||||||
if output_moderation_answer:
|
if output_moderation_answer:
|
||||||
self._task_state.answer = output_moderation_answer
|
self._task_state.answer = output_moderation_answer
|
||||||
yield self._message_cycle_manager._message_replace_to_stream_response(
|
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||||
answer=output_moderation_answer,
|
answer=output_moderation_answer,
|
||||||
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
|
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
|
||||||
)
|
)
|
||||||
@ -683,9 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
message = self._get_message(session=session)
|
message = self._get_message(session=session)
|
||||||
message.answer = self._task_state.answer
|
message.answer = self._task_state.answer
|
||||||
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
|
||||||
message.message_metadata = (
|
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
|
||||||
)
|
|
||||||
message_files = [
|
message_files = [
|
||||||
MessageFile(
|
MessageFile(
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
@ -713,9 +705,9 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
message.answer_price_unit = usage.completion_price_unit
|
message.answer_price_unit = usage.completion_price_unit
|
||||||
message.total_price = usage.total_price
|
message.total_price = usage.total_price
|
||||||
message.currency = usage.currency
|
message.currency = usage.currency
|
||||||
self._task_state.metadata["usage"] = jsonable_encoder(usage)
|
self._task_state.metadata.usage = usage
|
||||||
else:
|
else:
|
||||||
self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage())
|
self._task_state.metadata.usage = LLMUsage.empty_usage()
|
||||||
message_was_created.send(
|
message_was_created.send(
|
||||||
message,
|
message,
|
||||||
application_generate_entity=self._application_generate_entity,
|
application_generate_entity=self._application_generate_entity,
|
||||||
@ -726,18 +718,16 @@ class AdvancedChatAppGenerateTaskPipeline:
|
|||||||
Message end to stream response.
|
Message end to stream response.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
extras = {}
|
extras = self._task_state.metadata.model_dump()
|
||||||
if self._task_state.metadata:
|
|
||||||
extras["metadata"] = self._task_state.metadata.copy()
|
|
||||||
|
|
||||||
if "annotation_reply" in extras["metadata"]:
|
if self._task_state.metadata.annotation_reply:
|
||||||
del extras["metadata"]["annotation_reply"]
|
del extras["annotation_reply"]
|
||||||
|
|
||||||
return MessageEndStreamResponse(
|
return MessageEndStreamResponse(
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
id=self._message_id,
|
id=self._message_id,
|
||||||
files=self._recorded_files,
|
files=self._recorded_files,
|
||||||
metadata=extras.get("metadata", {}),
|
metadata=extras,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
||||||
|
@ -50,7 +50,6 @@ from core.app.entities.task_entities import (
|
|||||||
WorkflowAppStreamResponse,
|
WorkflowAppStreamResponse,
|
||||||
WorkflowFinishStreamResponse,
|
WorkflowFinishStreamResponse,
|
||||||
WorkflowStartStreamResponse,
|
WorkflowStartStreamResponse,
|
||||||
WorkflowTaskState,
|
|
||||||
)
|
)
|
||||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||||
@ -130,9 +129,7 @@ class WorkflowAppGenerateTaskPipeline:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._application_generate_entity = application_generate_entity
|
self._application_generate_entity = application_generate_entity
|
||||||
self._workflow_id = workflow.id
|
|
||||||
self._workflow_features_dict = workflow.features_dict
|
self._workflow_features_dict = workflow.features_dict
|
||||||
self._task_state = WorkflowTaskState()
|
|
||||||
self._workflow_run_id = ""
|
self._workflow_run_id = ""
|
||||||
|
|
||||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||||
@ -543,7 +540,6 @@ class WorkflowAppGenerateTaskPipeline:
|
|||||||
if tts_publisher:
|
if tts_publisher:
|
||||||
tts_publisher.publish(queue_message)
|
tts_publisher.publish(queue_message)
|
||||||
|
|
||||||
self._task_state.answer += delta_text
|
|
||||||
yield self._text_chunk_to_stream_response(
|
yield self._text_chunk_to_stream_response(
|
||||||
delta_text, from_variable_selector=event.from_variable_selector
|
delta_text, from_variable_selector=event.from_variable_selector
|
||||||
)
|
)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping, Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum, StrEnum
|
from enum import Enum, StrEnum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
@ -6,6 +6,7 @@ from typing import Any, Optional
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||||
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
@ -283,7 +284,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
|
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
|
||||||
retriever_resources: list[dict]
|
retriever_resources: Sequence[RetrievalSourceMetadata]
|
||||||
in_iteration_id: Optional[str] = None
|
in_iteration_id: Optional[str] = None
|
||||||
"""iteration id if node is in iteration"""
|
"""iteration id if node is in iteration"""
|
||||||
in_loop_id: Optional[str] = None
|
in_loop_id: Optional[str] = None
|
||||||
|
@ -2,20 +2,37 @@ from collections.abc import Mapping, Sequence
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationReplyAccount(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationReply(BaseModel):
|
||||||
|
id: str
|
||||||
|
account: AnnotationReplyAccount
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStateMetadata(BaseModel):
|
||||||
|
annotation_reply: AnnotationReply | None = None
|
||||||
|
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list)
|
||||||
|
usage: LLMUsage | None = None
|
||||||
|
|
||||||
|
|
||||||
class TaskState(BaseModel):
|
class TaskState(BaseModel):
|
||||||
"""
|
"""
|
||||||
TaskState entity
|
TaskState entity
|
||||||
"""
|
"""
|
||||||
|
|
||||||
metadata: dict = {}
|
metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata)
|
||||||
|
|
||||||
|
|
||||||
class EasyUITaskState(TaskState):
|
class EasyUITaskState(TaskState):
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
@ -43,7 +42,7 @@ from core.app.entities.task_entities import (
|
|||||||
StreamResponse,
|
StreamResponse,
|
||||||
)
|
)
|
||||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||||
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||||
@ -51,7 +50,6 @@ from core.model_runtime.entities.message_entities import (
|
|||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
)
|
)
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
|
||||||
from core.ops.entities.trace_entity import TraceTaskName
|
from core.ops.entities.trace_entity import TraceTaskName
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||||
@ -63,7 +61,7 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage):
|
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||||
"""
|
"""
|
||||||
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||||
"""
|
"""
|
||||||
@ -104,6 +102,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._message_cycle_manager = MessageCycleManager(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
task_state=self._task_state,
|
||||||
|
)
|
||||||
|
|
||||||
self._conversation_name_generate_thread: Optional[Thread] = None
|
self._conversation_name_generate_thread: Optional[Thread] = None
|
||||||
|
|
||||||
def process(
|
def process(
|
||||||
@ -115,7 +118,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||||||
]:
|
]:
|
||||||
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
||||||
# start generate conversation name thread
|
# start generate conversation name thread
|
||||||
self._conversation_name_generate_thread = self._generate_conversation_name(
|
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
|
||||||
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
|
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -136,9 +139,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||||||
if isinstance(stream_response, ErrorStreamResponse):
|
if isinstance(stream_response, ErrorStreamResponse):
|
||||||
raise stream_response.err
|
raise stream_response.err
|
||||||
elif isinstance(stream_response, MessageEndStreamResponse):
|
elif isinstance(stream_response, MessageEndStreamResponse):
|
||||||
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
|
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
|
||||||
if self._task_state.metadata:
|
if self._task_state.metadata:
|
||||||
extras["metadata"] = self._task_state.metadata
|
extras["metadata"] = self._task_state.metadata.model_dump()
|
||||||
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
|
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
|
||||||
if self._conversation_mode == AppMode.COMPLETION.value:
|
if self._conversation_mode == AppMode.COMPLETION.value:
|
||||||
response = CompletionAppBlockingResponse(
|
response = CompletionAppBlockingResponse(
|
||||||
@ -277,7 +280,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||||||
)
|
)
|
||||||
if output_moderation_answer:
|
if output_moderation_answer:
|
||||||
self._task_state.llm_result.message.content = output_moderation_answer
|
self._task_state.llm_result.message.content = output_moderation_answer
|
||||||
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
yield self._message_cycle_manager.message_replace_to_stream_response(
|
||||||
|
answer=output_moderation_answer
|
||||||
|
)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
# Save message
|
# Save message
|
||||||
@ -286,9 +291,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||||||
message_end_resp = self._message_end_to_stream_response()
|
message_end_resp = self._message_end_to_stream_response()
|
||||||
yield message_end_resp
|
yield message_end_resp
|
||||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||||
self._handle_retriever_resources(event)
|
self._message_cycle_manager.handle_retriever_resources(event)
|
||||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||||
annotation = self._handle_annotation_reply(event)
|
annotation = self._message_cycle_manager.handle_annotation_reply(event)
|
||||||
if annotation:
|
if annotation:
|
||||||
self._task_state.llm_result.message.content = annotation.content
|
self._task_state.llm_result.message.content = annotation.content
|
||||||
elif isinstance(event, QueueAgentThoughtEvent):
|
elif isinstance(event, QueueAgentThoughtEvent):
|
||||||
@ -296,7 +301,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||||||
if agent_thought_response is not None:
|
if agent_thought_response is not None:
|
||||||
yield agent_thought_response
|
yield agent_thought_response
|
||||||
elif isinstance(event, QueueMessageFileEvent):
|
elif isinstance(event, QueueMessageFileEvent):
|
||||||
response = self._message_file_to_stream_response(event)
|
response = self._message_cycle_manager.message_file_to_stream_response(event)
|
||||||
if response:
|
if response:
|
||||||
yield response
|
yield response
|
||||||
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
|
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
|
||||||
@ -318,7 +323,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||||||
self._task_state.llm_result.message.content = current_content
|
self._task_state.llm_result.message.content = current_content
|
||||||
|
|
||||||
if isinstance(event, QueueLLMChunkEvent):
|
if isinstance(event, QueueLLMChunkEvent):
|
||||||
yield self._message_to_stream_response(
|
yield self._message_cycle_manager.message_to_stream_response(
|
||||||
answer=cast(str, delta_text),
|
answer=cast(str, delta_text),
|
||||||
message_id=self._message_id,
|
message_id=self._message_id,
|
||||||
)
|
)
|
||||||
@ -328,7 +333,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||||||
message_id=self._message_id,
|
message_id=self._message_id,
|
||||||
)
|
)
|
||||||
elif isinstance(event, QueueMessageReplaceEvent):
|
elif isinstance(event, QueueMessageReplaceEvent):
|
||||||
yield self._message_replace_to_stream_response(answer=event.text)
|
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
|
||||||
elif isinstance(event, QueuePingEvent):
|
elif isinstance(event, QueuePingEvent):
|
||||||
yield self._ping_stream_response()
|
yield self._ping_stream_response()
|
||||||
else:
|
else:
|
||||||
@ -372,9 +377,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||||||
message.provider_response_latency = time.perf_counter() - self._start_at
|
message.provider_response_latency = time.perf_counter() - self._start_at
|
||||||
message.total_price = usage.total_price
|
message.total_price = usage.total_price
|
||||||
message.currency = usage.currency
|
message.currency = usage.currency
|
||||||
message.message_metadata = (
|
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if trace_manager:
|
if trace_manager:
|
||||||
trace_manager.add_trace_task(
|
trace_manager.add_trace_task(
|
||||||
@ -423,16 +426,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
|||||||
Message end to stream response.
|
Message end to stream response.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)
|
self._task_state.metadata.usage = self._task_state.llm_result.usage
|
||||||
|
metadata_dict = self._task_state.metadata.model_dump()
|
||||||
extras = {}
|
|
||||||
if self._task_state.metadata:
|
|
||||||
extras["metadata"] = self._task_state.metadata
|
|
||||||
|
|
||||||
return MessageEndStreamResponse(
|
return MessageEndStreamResponse(
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
id=self._message_id,
|
id=self._message_id,
|
||||||
metadata=extras.get("metadata", {}),
|
metadata=metadata_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
|
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
|
||||||
|
@ -17,6 +17,8 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueRetrieverResourcesEvent,
|
QueueRetrieverResourcesEvent,
|
||||||
)
|
)
|
||||||
from core.app.entities.task_entities import (
|
from core.app.entities.task_entities import (
|
||||||
|
AnnotationReply,
|
||||||
|
AnnotationReplyAccount,
|
||||||
EasyUITaskState,
|
EasyUITaskState,
|
||||||
MessageFileStreamResponse,
|
MessageFileStreamResponse,
|
||||||
MessageReplaceStreamResponse,
|
MessageReplaceStreamResponse,
|
||||||
@ -30,7 +32,7 @@ from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
|||||||
from services.annotation_service import AppAnnotationService
|
from services.annotation_service import AppAnnotationService
|
||||||
|
|
||||||
|
|
||||||
class MessageCycleManage:
|
class MessageCycleManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@ -45,7 +47,7 @@ class MessageCycleManage:
|
|||||||
self._application_generate_entity = application_generate_entity
|
self._application_generate_entity = application_generate_entity
|
||||||
self._task_state = task_state
|
self._task_state = task_state
|
||||||
|
|
||||||
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||||
"""
|
"""
|
||||||
Generate conversation name.
|
Generate conversation name.
|
||||||
:param conversation_id: conversation id
|
:param conversation_id: conversation id
|
||||||
@ -102,7 +104,7 @@ class MessageCycleManage:
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
|
def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
|
||||||
"""
|
"""
|
||||||
Handle annotation reply.
|
Handle annotation reply.
|
||||||
:param event: event
|
:param event: event
|
||||||
@ -111,25 +113,28 @@ class MessageCycleManage:
|
|||||||
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
||||||
if annotation:
|
if annotation:
|
||||||
account = annotation.account
|
account = annotation.account
|
||||||
self._task_state.metadata["annotation_reply"] = {
|
self._task_state.metadata.annotation_reply = AnnotationReply(
|
||||||
"id": annotation.id,
|
id=annotation.id,
|
||||||
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
|
account=AnnotationReplyAccount(
|
||||||
}
|
id=annotation.account_id,
|
||||||
|
name=account.name if account else "Dify user",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
return annotation
|
return annotation
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
|
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
|
||||||
"""
|
"""
|
||||||
Handle retriever resources.
|
Handle retriever resources.
|
||||||
:param event: event
|
:param event: event
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
|
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
|
||||||
self._task_state.metadata["retriever_resources"] = event.retriever_resources
|
self._task_state.metadata.retriever_resources = event.retriever_resources
|
||||||
|
|
||||||
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
|
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
|
||||||
"""
|
"""
|
||||||
Message file to stream response.
|
Message file to stream response.
|
||||||
:param event: event
|
:param event: event
|
||||||
@ -166,7 +171,7 @@ class MessageCycleManage:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _message_to_stream_response(
|
def message_to_stream_response(
|
||||||
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
|
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
|
||||||
) -> MessageStreamResponse:
|
) -> MessageStreamResponse:
|
||||||
"""
|
"""
|
||||||
@ -182,7 +187,7 @@ class MessageCycleManage:
|
|||||||
from_variable_selector=from_variable_selector,
|
from_variable_selector=from_variable_selector,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
|
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
|
||||||
"""
|
"""
|
||||||
Message replace to stream response.
|
Message replace to stream response.
|
||||||
:param answer: answer
|
:param answer: answer
|
@ -1,8 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||||
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.index_type import IndexType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -85,7 +87,8 @@ class DatasetIndexToolCallbackHandler:
|
|||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
def return_retriever_resource_info(self, resource: list):
|
# TODO(-LAN-): Improve type check
|
||||||
|
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
|
||||||
"""Handle return_retriever_resource_info."""
|
"""Handle return_retriever_resource_info."""
|
||||||
self._queue_manager.publish(
|
self._queue_manager.publish(
|
||||||
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
|
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
|
||||||
|
23
api/core/rag/entities/citation_metadata.py
Normal file
23
api/core/rag/entities/citation_metadata.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalSourceMetadata(BaseModel):
|
||||||
|
position: Optional[int] = None
|
||||||
|
dataset_id: Optional[str] = None
|
||||||
|
dataset_name: Optional[str] = None
|
||||||
|
document_id: Optional[str] = None
|
||||||
|
document_name: Optional[str] = None
|
||||||
|
data_source_type: Optional[str] = None
|
||||||
|
segment_id: Optional[str] = None
|
||||||
|
retriever_from: Optional[str] = None
|
||||||
|
score: Optional[float] = None
|
||||||
|
hit_count: Optional[int] = None
|
||||||
|
word_count: Optional[int] = None
|
||||||
|
segment_position: Optional[int] = None
|
||||||
|
index_node_hash: Optional[str] = None
|
||||||
|
content: Optional[str] = None
|
||||||
|
page: Optional[int] = None
|
||||||
|
doc_metadata: Optional[dict[str, Any]] = None
|
||||||
|
title: Optional[str] = None
|
@ -35,6 +35,7 @@ from core.prompt.simple_prompt_transform import ModelMode
|
|||||||
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.rag.entities.context_entities import DocumentContext
|
from core.rag.entities.context_entities import DocumentContext
|
||||||
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
from core.rag.entities.metadata_entities import Condition, MetadataCondition
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.index_type import IndexType
|
||||||
@ -198,21 +199,21 @@ class DatasetRetrieval:
|
|||||||
|
|
||||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||||
document_context_list = []
|
document_context_list: list[DocumentContext] = []
|
||||||
retrieval_resource_list = []
|
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
||||||
# deal with external documents
|
# deal with external documents
|
||||||
for item in external_documents:
|
for item in external_documents:
|
||||||
document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
|
document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
|
||||||
source = {
|
source = RetrievalSourceMetadata(
|
||||||
"dataset_id": item.metadata.get("dataset_id"),
|
dataset_id=item.metadata.get("dataset_id"),
|
||||||
"dataset_name": item.metadata.get("dataset_name"),
|
dataset_name=item.metadata.get("dataset_name"),
|
||||||
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
|
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
|
||||||
"document_name": item.metadata.get("title"),
|
document_name=item.metadata.get("title"),
|
||||||
"data_source_type": "external",
|
data_source_type="external",
|
||||||
"retriever_from": invoke_from.to_source(),
|
retriever_from=invoke_from.to_source(),
|
||||||
"score": item.metadata.get("score"),
|
score=item.metadata.get("score"),
|
||||||
"content": item.page_content,
|
content=item.page_content,
|
||||||
}
|
)
|
||||||
retrieval_resource_list.append(source)
|
retrieval_resource_list.append(source)
|
||||||
# deal with dify documents
|
# deal with dify documents
|
||||||
if dify_documents:
|
if dify_documents:
|
||||||
@ -248,32 +249,32 @@ class DatasetRetrieval:
|
|||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if dataset and document:
|
if dataset and document:
|
||||||
source = {
|
source = RetrievalSourceMetadata(
|
||||||
"dataset_id": dataset.id,
|
dataset_id=dataset.id,
|
||||||
"dataset_name": dataset.name,
|
dataset_name=dataset.name,
|
||||||
"document_id": document.id,
|
document_id=document.id,
|
||||||
"document_name": document.name,
|
document_name=document.name,
|
||||||
"data_source_type": document.data_source_type,
|
data_source_type=document.data_source_type,
|
||||||
"segment_id": segment.id,
|
segment_id=segment.id,
|
||||||
"retriever_from": invoke_from.to_source(),
|
retriever_from=invoke_from.to_source(),
|
||||||
"score": record.score or 0.0,
|
score=record.score or 0.0,
|
||||||
"doc_metadata": document.doc_metadata,
|
doc_metadata=document.doc_metadata,
|
||||||
}
|
)
|
||||||
|
|
||||||
if invoke_from.to_source() == "dev":
|
if invoke_from.to_source() == "dev":
|
||||||
source["hit_count"] = segment.hit_count
|
source.hit_count = segment.hit_count
|
||||||
source["word_count"] = segment.word_count
|
source.word_count = segment.word_count
|
||||||
source["segment_position"] = segment.position
|
source.segment_position = segment.position
|
||||||
source["index_node_hash"] = segment.index_node_hash
|
source.index_node_hash = segment.index_node_hash
|
||||||
if segment.answer:
|
if segment.answer:
|
||||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||||
else:
|
else:
|
||||||
source["content"] = segment.content
|
source.content = segment.content
|
||||||
retrieval_resource_list.append(source)
|
retrieval_resource_list.append(source)
|
||||||
if hit_callback and retrieval_resource_list:
|
if hit_callback and retrieval_resource_list:
|
||||||
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True)
|
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||||
for position, item in enumerate(retrieval_resource_list, start=1):
|
for position, item in enumerate(retrieval_resource_list, start=1):
|
||||||
item["position"] = position
|
item.position = position
|
||||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||||
if document_context_list:
|
if document_context_list:
|
||||||
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
|
||||||
|
@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
|
|||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelManager
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.rag.models.document import Document as RagDocument
|
from core.rag.models.document import Document as RagDocument
|
||||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
@ -107,7 +108,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
else:
|
else:
|
||||||
document_context_list.append(segment.get_sign_content())
|
document_context_list.append(segment.get_sign_content())
|
||||||
if self.return_resource:
|
if self.return_resource:
|
||||||
context_list = []
|
context_list: list[RetrievalSourceMetadata] = []
|
||||||
resource_number = 1
|
resource_number = 1
|
||||||
for segment in sorted_segments:
|
for segment in sorted_segments:
|
||||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||||
@ -121,28 +122,28 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if dataset and document:
|
if dataset and document:
|
||||||
source = {
|
source = RetrievalSourceMetadata(
|
||||||
"position": resource_number,
|
position=resource_number,
|
||||||
"dataset_id": dataset.id,
|
dataset_id=dataset.id,
|
||||||
"dataset_name": dataset.name,
|
dataset_name=dataset.name,
|
||||||
"document_id": document.id,
|
document_id=document.id,
|
||||||
"document_name": document.name,
|
document_name=document.name,
|
||||||
"data_source_type": document.data_source_type,
|
data_source_type=document.data_source_type,
|
||||||
"segment_id": segment.id,
|
segment_id=segment.id,
|
||||||
"retriever_from": self.retriever_from,
|
retriever_from=self.retriever_from,
|
||||||
"score": document_score_list.get(segment.index_node_id, None),
|
score=document_score_list.get(segment.index_node_id, None),
|
||||||
"doc_metadata": document.doc_metadata,
|
doc_metadata=document.doc_metadata,
|
||||||
}
|
)
|
||||||
|
|
||||||
if self.retriever_from == "dev":
|
if self.retriever_from == "dev":
|
||||||
source["hit_count"] = segment.hit_count
|
source.hit_count = segment.hit_count
|
||||||
source["word_count"] = segment.word_count
|
source.word_count = segment.word_count
|
||||||
source["segment_position"] = segment.position
|
source.segment_position = segment.position
|
||||||
source["index_node_hash"] = segment.index_node_hash
|
source.index_node_hash = segment.index_node_hash
|
||||||
if segment.answer:
|
if segment.answer:
|
||||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||||
else:
|
else:
|
||||||
source["content"] = segment.content
|
source.content = segment.content
|
||||||
context_list.append(source)
|
context_list.append(source)
|
||||||
resource_number += 1
|
resource_number += 1
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||||
from core.rag.datasource.retrieval_service import RetrievalService
|
from core.rag.datasource.retrieval_service import RetrievalService
|
||||||
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.rag.entities.context_entities import DocumentContext
|
from core.rag.entities.context_entities import DocumentContext
|
||||||
from core.rag.models.document import Document as RetrievalDocument
|
from core.rag.models.document import Document as RetrievalDocument
|
||||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||||
@ -14,7 +15,7 @@ from models.dataset import Dataset
|
|||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
from services.external_knowledge_service import ExternalDatasetService
|
from services.external_knowledge_service import ExternalDatasetService
|
||||||
|
|
||||||
default_retrieval_model = {
|
default_retrieval_model: dict[str, Any] = {
|
||||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||||
"reranking_enable": False,
|
"reranking_enable": False,
|
||||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||||
@ -79,7 +80,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
else:
|
else:
|
||||||
document_ids_filter = None
|
document_ids_filter = None
|
||||||
if dataset.provider == "external":
|
if dataset.provider == "external":
|
||||||
results = []
|
results: list[RetrievalDocument] = []
|
||||||
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||||
tenant_id=dataset.tenant_id,
|
tenant_id=dataset.tenant_id,
|
||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
@ -100,21 +101,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
document.metadata["dataset_name"] = dataset.name
|
document.metadata["dataset_name"] = dataset.name
|
||||||
results.append(document)
|
results.append(document)
|
||||||
# deal with external documents
|
# deal with external documents
|
||||||
context_list = []
|
context_list: list[RetrievalSourceMetadata] = []
|
||||||
for position, item in enumerate(results, start=1):
|
for position, item in enumerate(results, start=1):
|
||||||
if item.metadata is not None:
|
if item.metadata is not None:
|
||||||
source = {
|
source = RetrievalSourceMetadata(
|
||||||
"position": position,
|
position=position,
|
||||||
"dataset_id": item.metadata.get("dataset_id"),
|
dataset_id=item.metadata.get("dataset_id"),
|
||||||
"dataset_name": item.metadata.get("dataset_name"),
|
dataset_name=item.metadata.get("dataset_name"),
|
||||||
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
|
document_id=item.metadata.get("document_id") or item.metadata.get("title"),
|
||||||
"document_name": item.metadata.get("title"),
|
document_name=item.metadata.get("title"),
|
||||||
"data_source_type": "external",
|
data_source_type="external",
|
||||||
"retriever_from": self.retriever_from,
|
retriever_from=self.retriever_from,
|
||||||
"score": item.metadata.get("score"),
|
score=item.metadata.get("score"),
|
||||||
"title": item.metadata.get("title"),
|
title=item.metadata.get("title"),
|
||||||
"content": item.page_content,
|
content=item.page_content,
|
||||||
}
|
)
|
||||||
context_list.append(source)
|
context_list.append(source)
|
||||||
for hit_callback in self.hit_callbacks:
|
for hit_callback in self.hit_callbacks:
|
||||||
hit_callback.return_retriever_resource_info(context_list)
|
hit_callback.return_retriever_resource_info(context_list)
|
||||||
@ -125,7 +126,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
return ""
|
return ""
|
||||||
# get retrieval model , if the model is not setting , using default
|
# get retrieval model , if the model is not setting , using default
|
||||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||||
retrieval_resource_list = []
|
retrieval_resource_list: list[RetrievalSourceMetadata] = []
|
||||||
if dataset.indexing_technique == "economy":
|
if dataset.indexing_technique == "economy":
|
||||||
# use keyword table query
|
# use keyword table query
|
||||||
documents = RetrievalService.retrieve(
|
documents = RetrievalService.retrieve(
|
||||||
@ -163,7 +164,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
for item in documents:
|
for item in documents:
|
||||||
if item.metadata is not None and item.metadata.get("score"):
|
if item.metadata is not None and item.metadata.get("score"):
|
||||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||||
document_context_list = []
|
document_context_list: list[DocumentContext] = []
|
||||||
records = RetrievalService.format_retrieval_documents(documents)
|
records = RetrievalService.format_retrieval_documents(documents)
|
||||||
if records:
|
if records:
|
||||||
for record in records:
|
for record in records:
|
||||||
@ -197,37 +198,37 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if dataset and document:
|
if dataset and document:
|
||||||
source = {
|
source = RetrievalSourceMetadata(
|
||||||
"dataset_id": dataset.id,
|
dataset_id=dataset.id,
|
||||||
"dataset_name": dataset.name,
|
dataset_name=dataset.name,
|
||||||
"document_id": document.id, # type: ignore
|
document_id=document.id, # type: ignore
|
||||||
"document_name": document.name, # type: ignore
|
document_name=document.name, # type: ignore
|
||||||
"data_source_type": document.data_source_type, # type: ignore
|
data_source_type=document.data_source_type, # type: ignore
|
||||||
"segment_id": segment.id,
|
segment_id=segment.id,
|
||||||
"retriever_from": self.retriever_from,
|
retriever_from=self.retriever_from,
|
||||||
"score": record.score or 0.0,
|
score=record.score or 0.0,
|
||||||
"doc_metadata": document.doc_metadata, # type: ignore
|
doc_metadata=document.doc_metadata, # type: ignore
|
||||||
}
|
)
|
||||||
|
|
||||||
if self.retriever_from == "dev":
|
if self.retriever_from == "dev":
|
||||||
source["hit_count"] = segment.hit_count
|
source.hit_count = segment.hit_count
|
||||||
source["word_count"] = segment.word_count
|
source.word_count = segment.word_count
|
||||||
source["segment_position"] = segment.position
|
source.segment_position = segment.position
|
||||||
source["index_node_hash"] = segment.index_node_hash
|
source.index_node_hash = segment.index_node_hash
|
||||||
if segment.answer:
|
if segment.answer:
|
||||||
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
|
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
|
||||||
else:
|
else:
|
||||||
source["content"] = segment.content
|
source.content = segment.content
|
||||||
retrieval_resource_list.append(source)
|
retrieval_resource_list.append(source)
|
||||||
|
|
||||||
if self.return_resource and retrieval_resource_list:
|
if self.return_resource and retrieval_resource_list:
|
||||||
retrieval_resource_list = sorted(
|
retrieval_resource_list = sorted(
|
||||||
retrieval_resource_list,
|
retrieval_resource_list,
|
||||||
key=lambda x: x.get("score") or 0.0,
|
key=lambda x: x.score or 0.0,
|
||||||
reverse=True,
|
reverse=True,
|
||||||
)
|
)
|
||||||
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
|
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
|
||||||
item["position"] = position # type: ignore
|
item.position = position # type: ignore
|
||||||
for hit_callback in self.hit_callbacks:
|
for hit_callback in self.hit_callbacks:
|
||||||
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
hit_callback.return_retriever_resource_info(retrieval_resource_list)
|
||||||
if document_context_list:
|
if document_context_list:
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping, Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
from core.workflow.entities.node_entities import AgentNodeStrategyInit
|
||||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent):
|
|||||||
|
|
||||||
|
|
||||||
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
|
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
|
||||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||||
context: str = Field(..., description="context")
|
context: str = Field(..., description="context")
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
@ -17,7 +19,7 @@ class RunStreamChunkEvent(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class RunRetrieverResourceEvent(BaseModel):
|
class RunRetrieverResourceEvent(BaseModel):
|
||||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||||
context: str = Field(..., description="context")
|
context: str = Field(..., description="context")
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,6 +43,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||||||
from core.plugin.entities.plugin import ModelProviderID
|
from core.plugin.entities.plugin import ModelProviderID
|
||||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||||
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.variables import (
|
from core.variables import (
|
||||||
ArrayAnySegment,
|
ArrayAnySegment,
|
||||||
ArrayFileSegment,
|
ArrayFileSegment,
|
||||||
@ -474,7 +475,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
|
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
|
||||||
elif isinstance(context_value_variable, ArraySegment):
|
elif isinstance(context_value_variable, ArraySegment):
|
||||||
context_str = ""
|
context_str = ""
|
||||||
original_retriever_resource = []
|
original_retriever_resource: list[RetrievalSourceMetadata] = []
|
||||||
for item in context_value_variable.value:
|
for item in context_value_variable.value:
|
||||||
if isinstance(item, str):
|
if isinstance(item, str):
|
||||||
context_str += item + "\n"
|
context_str += item + "\n"
|
||||||
@ -492,7 +493,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
retriever_resources=original_retriever_resource, context=context_str.strip()
|
retriever_resources=original_retriever_resource, context=context_str.strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
|
def _convert_to_original_retriever_resource(self, context_dict: dict):
|
||||||
if (
|
if (
|
||||||
"metadata" in context_dict
|
"metadata" in context_dict
|
||||||
and "_source" in context_dict["metadata"]
|
and "_source" in context_dict["metadata"]
|
||||||
@ -500,24 +501,24 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
):
|
):
|
||||||
metadata = context_dict.get("metadata", {})
|
metadata = context_dict.get("metadata", {})
|
||||||
|
|
||||||
source = {
|
source = RetrievalSourceMetadata(
|
||||||
"position": metadata.get("position"),
|
position=metadata.get("position"),
|
||||||
"dataset_id": metadata.get("dataset_id"),
|
dataset_id=metadata.get("dataset_id"),
|
||||||
"dataset_name": metadata.get("dataset_name"),
|
dataset_name=metadata.get("dataset_name"),
|
||||||
"document_id": metadata.get("document_id"),
|
document_id=metadata.get("document_id"),
|
||||||
"document_name": metadata.get("document_name"),
|
document_name=metadata.get("document_name"),
|
||||||
"data_source_type": metadata.get("data_source_type"),
|
data_source_type=metadata.get("data_source_type"),
|
||||||
"segment_id": metadata.get("segment_id"),
|
segment_id=metadata.get("segment_id"),
|
||||||
"retriever_from": metadata.get("retriever_from"),
|
retriever_from=metadata.get("retriever_from"),
|
||||||
"score": metadata.get("score"),
|
score=metadata.get("score"),
|
||||||
"hit_count": metadata.get("segment_hit_count"),
|
hit_count=metadata.get("segment_hit_count"),
|
||||||
"word_count": metadata.get("segment_word_count"),
|
word_count=metadata.get("segment_word_count"),
|
||||||
"segment_position": metadata.get("segment_position"),
|
segment_position=metadata.get("segment_position"),
|
||||||
"index_node_hash": metadata.get("segment_index_node_hash"),
|
index_node_hash=metadata.get("segment_index_node_hash"),
|
||||||
"content": context_dict.get("content"),
|
content=context_dict.get("content"),
|
||||||
"page": metadata.get("page"),
|
page=metadata.get("page"),
|
||||||
"doc_metadata": metadata.get("doc_metadata"),
|
doc_metadata=metadata.get("doc_metadata"),
|
||||||
}
|
)
|
||||||
|
|
||||||
return source
|
return source
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user