feat: support agent log event

This commit is contained in:
Yeuoly 2024-12-12 23:46:26 +08:00
parent dedc1b0c3a
commit 9a6f120e5c
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
11 changed files with 181 additions and 13 deletions

View File

@ -13,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
) )
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent, QueueAdvancedChatMessageEndEvent,
QueueAgentLogEvent,
QueueAnnotationReplyEvent, QueueAnnotationReplyEvent,
QueueErrorEvent, QueueErrorEvent,
QueueIterationCompletedEvent, QueueIterationCompletedEvent,
@ -124,6 +125,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._task_state = WorkflowTaskState() self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {} self._wip_workflow_node_executions = {}
self._wip_workflow_agent_logs = {}
self._conversation_name_generate_thread = None self._conversation_name_generate_thread = None
self._recorded_files: list[Mapping[str, Any]] = [] self._recorded_files: list[Mapping[str, Any]] = []
@ -244,7 +246,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
else: else:
start_listener_time = time.time() start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e: except Exception:
logger.exception(f"Failed to listen audio message, task_id: {task_id}") logger.exception(f"Failed to listen audio message, task_id: {task_id}")
break break
if tts_publisher: if tts_publisher:
@ -493,6 +495,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._save_message(graph_runtime_state=graph_runtime_state) self._save_message(graph_runtime_state=graph_runtime_state)
yield self._message_end_to_stream_response() yield self._message_end_to_stream_response()
elif isinstance(event, QueueAgentLogEvent):
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
else: else:
continue continue

View File

@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity, WorkflowAppGenerateEntity,
) )
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueErrorEvent, QueueErrorEvent,
QueueIterationCompletedEvent, QueueIterationCompletedEvent,
QueueIterationNextEvent, QueueIterationNextEvent,
@ -106,6 +107,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._task_state = WorkflowTaskState() self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {} self._wip_workflow_node_executions = {}
self._wip_workflow_agent_logs = {}
self.total_tokens: int = 0 self.total_tokens: int = 0
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@ -216,7 +218,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
break break
else: else:
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e: except Exception:
logger.exception(f"Fails to get audio trunk, task_id: {task_id}") logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
break break
if tts_publisher: if tts_publisher:
@ -387,6 +389,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
yield self._text_chunk_to_stream_response( yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector delta_text, from_variable_selector=event.from_variable_selector
) )
elif isinstance(event, QueueAgentLogEvent):
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
else: else:
continue continue

View File

@ -5,6 +5,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,
QueueAgentLogEvent,
QueueIterationCompletedEvent, QueueIterationCompletedEvent,
QueueIterationNextEvent, QueueIterationNextEvent,
QueueIterationStartEvent, QueueIterationStartEvent,
@ -23,6 +24,7 @@ from core.app.entities.queue_entities import (
) )
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.event import (
AgentLogEvent,
GraphEngineEvent, GraphEngineEvent,
GraphRunFailedEvent, GraphRunFailedEvent,
GraphRunStartedEvent, GraphRunStartedEvent,
@ -295,6 +297,17 @@ class WorkflowBasedAppRunner(AppRunner):
retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
) )
) )
elif isinstance(event, AgentLogEvent):
self._publish_event(
QueueAgentLogEvent(
id=event.id,
node_execution_id=event.node_execution_id,
parent_id=event.parent_id,
error=event.error,
status=event.status,
data=event.data,
)
)
elif isinstance(event, ParallelBranchRunStartedEvent): elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event( self._publish_event(
QueueParallelBranchRunStartedEvent( QueueParallelBranchRunStartedEvent(

View File

@ -1,6 +1,6 @@
from datetime import datetime from datetime import datetime
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import Any, Optional from typing import Any, Mapping, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -38,6 +38,7 @@ class QueueEvent(StrEnum):
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started" PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded" PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed" PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
AGENT_LOG = "agent_log"
ERROR = "error" ERROR = "error"
PING = "ping" PING = "ping"
STOP = "stop" STOP = "stop"
@ -300,6 +301,20 @@ class QueueNodeSucceededEvent(AppQueueEvent):
iteration_duration_map: Optional[dict[str, float]] = None iteration_duration_map: Optional[dict[str, float]] = None
class QueueAgentLogEvent(AppQueueEvent):
"""
QueueAgentLogEvent entity
"""
event: QueueEvent = QueueEvent.AGENT_LOG
id: str
node_execution_id: str
parent_id: str | None
error: str | None
status: str
data: Mapping[str, Any]
class QueueNodeInIterationFailedEvent(AppQueueEvent): class QueueNodeInIterationFailedEvent(AppQueueEvent):
""" """
QueueNodeInIterationFailedEvent entity QueueNodeInIterationFailedEvent entity

View File

@ -59,6 +59,7 @@ class StreamEvent(Enum):
ITERATION_COMPLETED = "iteration_completed" ITERATION_COMPLETED = "iteration_completed"
TEXT_CHUNK = "text_chunk" TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace" TEXT_REPLACE = "text_replace"
AGENT_LOG = "agent_log"
class StreamResponse(BaseModel): class StreamResponse(BaseModel):
@ -625,3 +626,24 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
workflow_run_id: str workflow_run_id: str
data: Data data: Data
class AgentLogStreamResponse(StreamResponse):
"""
AgentLogStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
node_execution_id: str
id: str
parent_id: str | None
error: str | None
status: str
data: Mapping[str, Any]
event: StreamEvent = StreamEvent.AGENT_LOG
data: Data

View File

@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueIterationCompletedEvent, QueueIterationCompletedEvent,
QueueIterationNextEvent, QueueIterationNextEvent,
QueueIterationStartEvent, QueueIterationStartEvent,
@ -21,6 +22,7 @@ from core.app.entities.queue_entities import (
QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunSucceededEvent,
) )
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AgentLogStreamResponse,
IterationNodeCompletedStreamResponse, IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse, IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse, IterationNodeStartStreamResponse,
@ -63,6 +65,7 @@ class WorkflowCycleManage:
_task_state: WorkflowTaskState _task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any] _workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution] _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_wip_workflow_agent_logs: dict[str, list[AgentLogStreamResponse.Data]]
def _handle_workflow_run_start(self) -> WorkflowRun: def _handle_workflow_run_start(self) -> WorkflowRun:
max_sequence = ( max_sequence = (
@ -283,9 +286,16 @@ class WorkflowCycleManage:
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = WorkflowEntry.handle_special_values(event.outputs)
execution_metadata = ( execution_metadata_dict = event.execution_metadata
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None if self._wip_workflow_agent_logs.get(event.node_execution_id):
) if not execution_metadata_dict:
execution_metadata_dict = {}
execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get(
event.node_execution_id, []
)
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
finished_at = datetime.now(UTC).replace(tzinfo=None) finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds() elapsed_time = (finished_at - event.start_at).total_seconds()
@ -332,9 +342,16 @@ class WorkflowCycleManage:
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(UTC).replace(tzinfo=None) finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds() elapsed_time = (finished_at - event.start_at).total_seconds()
execution_metadata = ( execution_metadata_dict = event.execution_metadata
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None if self._wip_workflow_agent_logs.get(event.node_execution_id):
) if not execution_metadata_dict:
execution_metadata_dict = {}
execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get(
event.node_execution_id, []
)
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{ {
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value, WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
@ -746,3 +763,52 @@ class WorkflowCycleManage:
raise Exception(f"Workflow node execution not found: {node_execution_id}") raise Exception(f"Workflow node execution not found: {node_execution_id}")
return workflow_node_execution return workflow_node_execution
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""
Handle agent log
:param task_id: task id
:param event: agent log event
:return:
"""
node_execution = self._wip_workflow_node_executions.get(event.node_execution_id)
if not node_execution:
raise Exception(f"Workflow node execution not found: {event.node_execution_id}")
node_execution_id = node_execution.id
original_agent_logs = self._wip_workflow_agent_logs.get(node_execution_id, [])
# try to find the log with the same id
for log in original_agent_logs:
if log.id == event.id:
# update the log
log.status = event.status
log.error = event.error
log.data = event.data
break
else:
# append the log
original_agent_logs.append(
AgentLogStreamResponse.Data(
id=event.id,
parent_id=event.parent_id,
node_execution_id=node_execution_id,
error=event.error,
status=event.status,
data=event.data,
)
)
self._wip_workflow_agent_logs[node_execution_id] = original_agent_logs
return AgentLogStreamResponse(
task_id=task_id,
data=AgentLogStreamResponse.Data(
node_execution_id=node_execution_id,
id=event.id,
parent_id=event.parent_id,
error=event.error,
status=event.status,
data=event.data,
),
)

View File

@ -1,7 +1,7 @@
import base64 import base64
import enum import enum
from enum import Enum from enum import Enum
from typing import Any, Optional, Union from typing import Any, Mapping, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator
@ -150,6 +150,18 @@ class ToolInvokeMessage(BaseModel):
raise ValueError(f"The variable name '{value}' is reserved.") raise ValueError(f"The variable name '{value}' is reserved.")
return value return value
class LogMessage(BaseModel):
class LogStatus(Enum):
START = "start"
ERROR = "error"
SUCCESS = "success"
id: str
parent_id: Optional[str] = Field(default=None, description="Leave empty for root log")
error: Optional[str] = Field(default=None, description="The error message")
status: LogStatus = Field(..., description="The status of the log")
data: Mapping[str, Any] = Field(..., description="Detailed log data")
class MessageType(Enum): class MessageType(Enum):
TEXT = "text" TEXT = "text"
IMAGE = "image" IMAGE = "image"
@ -160,12 +172,13 @@ class ToolInvokeMessage(BaseModel):
BINARY_LINK = "binary_link" BINARY_LINK = "binary_link"
VARIABLE = "variable" VARIABLE = "variable"
FILE = "file" FILE = "file"
LOG = "log"
type: MessageType = MessageType.TEXT type: MessageType = MessageType.TEXT
""" """
plain text, image url or link url plain text, image url or link url
""" """
message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | None message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | LogMessage | None
meta: dict[str, Any] | None = None meta: dict[str, Any] | None = None
@field_validator("message", mode="before") @field_validator("message", mode="before")

View File

@ -17,6 +17,7 @@ class NodeRunMetadataKey(StrEnum):
TOTAL_PRICE = "total_price" TOTAL_PRICE = "total_price"
CURRENCY = "currency" CURRENCY = "currency"
TOOL_INFO = "tool_info" TOOL_INFO = "tool_info"
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id" ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index" ITERATION_INDEX = "iteration_index"
PARALLEL_ID = "parallel_id" PARALLEL_ID = "parallel_id"

View File

@ -170,4 +170,22 @@ class IterationRunFailedEvent(BaseIterationEvent):
error: str = Field(..., description="failed reason") error: str = Field(..., description="failed reason")
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent ###########################################
# Agent Events
###########################################
class BaseAgentEvent(GraphEngineEvent):
pass
class AgentLogEvent(BaseAgentEvent):
id: str = Field(..., description="id")
node_execution_id: str = Field(..., description="node execution id")
parent_id: str | None = Field(..., description="parent id")
error: str | None = Field(..., description="error")
status: str = Field(..., description="status")
data: Mapping[str, Any] = Field(..., description="data")
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent

View File

@ -152,7 +152,7 @@ class GraphEngine:
elif isinstance(item, NodeRunSucceededEvent): elif isinstance(item, NodeRunSucceededEvent):
if item.node_type == NodeType.END: if item.node_type == NodeType.END:
self.graph_runtime_state.outputs = ( self.graph_runtime_state.outputs = (
item.route_node_state.node_run_result.outputs dict(item.route_node_state.node_run_result.outputs)
if item.route_node_state.node_run_result if item.route_node_state.node_run_result
and item.route_node_state.node_run_result.outputs and item.route_node_state.node_run_result.outputs
else {} else {}

View File

@ -16,6 +16,7 @@ from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
@ -55,6 +56,17 @@ class ToolNode(BaseNode[ToolNodeData]):
"plugin_unique_identifier": node_data.plugin_unique_identifier, "plugin_unique_identifier": node_data.plugin_unique_identifier,
} }
yield AgentLogEvent(
id=self.node_id,
node_execution_id=self.id,
parent_id=None,
error=None,
status="running",
data={
"tool_info": tool_info,
},
)
# get tool runtime # get tool runtime
try: try:
tool_runtime = ToolManager.get_workflow_tool_runtime( tool_runtime = ToolManager.get_workflow_tool_runtime(