mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-19 06:45:57 +08:00
feat: support agent log event
This commit is contained in:
parent
dedc1b0c3a
commit
9a6f120e5c
@ -13,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAdvancedChatMessageEndEvent,
|
||||
QueueAgentLogEvent,
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueErrorEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
@ -124,6 +125,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._wip_workflow_node_executions = {}
|
||||
self._wip_workflow_agent_logs = {}
|
||||
|
||||
self._conversation_name_generate_thread = None
|
||||
self._recorded_files: list[Mapping[str, Any]] = []
|
||||
@ -244,7 +246,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
else:
|
||||
start_listener_time = time.time()
|
||||
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception(f"Failed to listen audio message, task_id: {task_id}")
|
||||
break
|
||||
if tts_publisher:
|
||||
@ -493,6 +495,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
self._save_message(graph_runtime_state=graph_runtime_state)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
elif isinstance(event, QueueAgentLogEvent):
|
||||
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
@ -106,6 +107,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._wip_workflow_node_executions = {}
|
||||
self._wip_workflow_agent_logs = {}
|
||||
self.total_tokens: int = 0
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
@ -216,7 +218,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
break
|
||||
else:
|
||||
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
|
||||
break
|
||||
if tts_publisher:
|
||||
@ -387,6 +389,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
yield self._text_chunk_to_stream_response(
|
||||
delta_text, from_variable_selector=event.from_variable_selector
|
||||
)
|
||||
elif isinstance(event, QueueAgentLogEvent):
|
||||
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
@ -5,6 +5,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
@ -23,6 +24,7 @@ from core.app.entities.queue_entities import (
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
AgentLogEvent,
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
@ -295,6 +297,17 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, AgentLogEvent):
|
||||
self._publish_event(
|
||||
QueueAgentLogEvent(
|
||||
id=event.id,
|
||||
node_execution_id=event.node_execution_id,
|
||||
parent_id=event.parent_id,
|
||||
error=event.error,
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunStartedEvent(
|
||||
|
@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -38,6 +38,7 @@ class QueueEvent(StrEnum):
|
||||
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
|
||||
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
|
||||
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
|
||||
AGENT_LOG = "agent_log"
|
||||
ERROR = "error"
|
||||
PING = "ping"
|
||||
STOP = "stop"
|
||||
@ -300,6 +301,20 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
iteration_duration_map: Optional[dict[str, float]] = None
|
||||
|
||||
|
||||
class QueueAgentLogEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueAgentLogEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.AGENT_LOG
|
||||
id: str
|
||||
node_execution_id: str
|
||||
parent_id: str | None
|
||||
error: str | None
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
|
||||
|
||||
class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeInIterationFailedEvent entity
|
||||
|
@ -59,6 +59,7 @@ class StreamEvent(Enum):
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_REPLACE = "text_replace"
|
||||
AGENT_LOG = "agent_log"
|
||||
|
||||
|
||||
class StreamResponse(BaseModel):
|
||||
@ -625,3 +626,24 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
||||
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class AgentLogStreamResponse(StreamResponse):
|
||||
"""
|
||||
AgentLogStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
node_execution_id: str
|
||||
id: str
|
||||
parent_id: str | None
|
||||
error: str | None
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
|
||||
event: StreamEvent = StreamEvent.AGENT_LOG
|
||||
data: Data
|
||||
|
@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
@ -21,6 +22,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AgentLogStreamResponse,
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
@ -63,6 +65,7 @@ class WorkflowCycleManage:
|
||||
_task_state: WorkflowTaskState
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
_wip_workflow_agent_logs: dict[str, list[AgentLogStreamResponse.Data]]
|
||||
|
||||
def _handle_workflow_run_start(self) -> WorkflowRun:
|
||||
max_sequence = (
|
||||
@ -283,9 +286,16 @@ class WorkflowCycleManage:
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
execution_metadata_dict = event.execution_metadata
|
||||
if self._wip_workflow_agent_logs.get(event.node_execution_id):
|
||||
if not execution_metadata_dict:
|
||||
execution_metadata_dict = {}
|
||||
|
||||
execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get(
|
||||
event.node_execution_id, []
|
||||
)
|
||||
|
||||
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
@ -332,9 +342,16 @@ class WorkflowCycleManage:
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
execution_metadata_dict = event.execution_metadata
|
||||
if self._wip_workflow_agent_logs.get(event.node_execution_id):
|
||||
if not execution_metadata_dict:
|
||||
execution_metadata_dict = {}
|
||||
|
||||
execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get(
|
||||
event.node_execution_id, []
|
||||
)
|
||||
|
||||
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
|
||||
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
|
||||
{
|
||||
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
|
||||
@ -746,3 +763,52 @@ class WorkflowCycleManage:
|
||||
raise Exception(f"Workflow node execution not found: {node_execution_id}")
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
||||
"""
|
||||
Handle agent log
|
||||
:param task_id: task id
|
||||
:param event: agent log event
|
||||
:return:
|
||||
"""
|
||||
node_execution = self._wip_workflow_node_executions.get(event.node_execution_id)
|
||||
if not node_execution:
|
||||
raise Exception(f"Workflow node execution not found: {event.node_execution_id}")
|
||||
|
||||
node_execution_id = node_execution.id
|
||||
original_agent_logs = self._wip_workflow_agent_logs.get(node_execution_id, [])
|
||||
|
||||
# try to find the log with the same id
|
||||
for log in original_agent_logs:
|
||||
if log.id == event.id:
|
||||
# update the log
|
||||
log.status = event.status
|
||||
log.error = event.error
|
||||
log.data = event.data
|
||||
break
|
||||
else:
|
||||
# append the log
|
||||
original_agent_logs.append(
|
||||
AgentLogStreamResponse.Data(
|
||||
id=event.id,
|
||||
parent_id=event.parent_id,
|
||||
node_execution_id=node_execution_id,
|
||||
error=event.error,
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
)
|
||||
)
|
||||
|
||||
self._wip_workflow_agent_logs[node_execution_id] = original_agent_logs
|
||||
|
||||
return AgentLogStreamResponse(
|
||||
task_id=task_id,
|
||||
data=AgentLogStreamResponse.Data(
|
||||
node_execution_id=node_execution_id,
|
||||
id=event.id,
|
||||
parent_id=event.parent_id,
|
||||
error=event.error,
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
),
|
||||
)
|
||||
|
@ -1,7 +1,7 @@
|
||||
import base64
|
||||
import enum
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Mapping, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator
|
||||
|
||||
@ -150,6 +150,18 @@ class ToolInvokeMessage(BaseModel):
|
||||
raise ValueError(f"The variable name '{value}' is reserved.")
|
||||
return value
|
||||
|
||||
class LogMessage(BaseModel):
|
||||
class LogStatus(Enum):
|
||||
START = "start"
|
||||
ERROR = "error"
|
||||
SUCCESS = "success"
|
||||
|
||||
id: str
|
||||
parent_id: Optional[str] = Field(default=None, description="Leave empty for root log")
|
||||
error: Optional[str] = Field(default=None, description="The error message")
|
||||
status: LogStatus = Field(..., description="The status of the log")
|
||||
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
||||
|
||||
class MessageType(Enum):
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
@ -160,12 +172,13 @@ class ToolInvokeMessage(BaseModel):
|
||||
BINARY_LINK = "binary_link"
|
||||
VARIABLE = "variable"
|
||||
FILE = "file"
|
||||
LOG = "log"
|
||||
|
||||
type: MessageType = MessageType.TEXT
|
||||
"""
|
||||
plain text, image url or link url
|
||||
"""
|
||||
message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | None
|
||||
message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | LogMessage | None
|
||||
meta: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("message", mode="before")
|
||||
|
@ -17,6 +17,7 @@ class NodeRunMetadataKey(StrEnum):
|
||||
TOTAL_PRICE = "total_price"
|
||||
CURRENCY = "currency"
|
||||
TOOL_INFO = "tool_info"
|
||||
AGENT_LOG = "agent_log"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
PARALLEL_ID = "parallel_id"
|
||||
|
@ -170,4 +170,22 @@ class IterationRunFailedEvent(BaseIterationEvent):
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent
|
||||
###########################################
|
||||
# Agent Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseAgentEvent(GraphEngineEvent):
|
||||
pass
|
||||
|
||||
|
||||
class AgentLogEvent(BaseAgentEvent):
|
||||
id: str = Field(..., description="id")
|
||||
node_execution_id: str = Field(..., description="node execution id")
|
||||
parent_id: str | None = Field(..., description="parent id")
|
||||
error: str | None = Field(..., description="error")
|
||||
status: str = Field(..., description="status")
|
||||
data: Mapping[str, Any] = Field(..., description="data")
|
||||
|
||||
|
||||
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent
|
||||
|
@ -152,7 +152,7 @@ class GraphEngine:
|
||||
elif isinstance(item, NodeRunSucceededEvent):
|
||||
if item.node_type == NodeType.END:
|
||||
self.graph_runtime_state.outputs = (
|
||||
item.route_node_state.node_run_result.outputs
|
||||
dict(item.route_node_state.node_run_result.outputs)
|
||||
if item.route_node_state.node_run_result
|
||||
and item.route_node_state.node_run_result.outputs
|
||||
else {}
|
||||
|
@ -16,6 +16,7 @@ from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import AgentLogEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
@ -55,6 +56,17 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
"plugin_unique_identifier": node_data.plugin_unique_identifier,
|
||||
}
|
||||
|
||||
yield AgentLogEvent(
|
||||
id=self.node_id,
|
||||
node_execution_id=self.id,
|
||||
parent_id=None,
|
||||
error=None,
|
||||
status="running",
|
||||
data={
|
||||
"tool_info": tool_info,
|
||||
},
|
||||
)
|
||||
|
||||
# get tool runtime
|
||||
try:
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
|
Loading…
x
Reference in New Issue
Block a user