feat: support agent log event

This commit is contained in:
Yeuoly 2024-12-12 23:46:26 +08:00
parent dedc1b0c3a
commit 9a6f120e5c
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
11 changed files with 181 additions and 13 deletions

View File

@ -13,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.entities.queue_entities import (
QueueAdvancedChatMessageEndEvent,
QueueAgentLogEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
@ -124,6 +125,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
self._wip_workflow_agent_logs = {}
self._conversation_name_generate_thread = None
self._recorded_files: list[Mapping[str, Any]] = []
@ -244,7 +246,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e:
except Exception:
logger.exception(f"Failed to listen audio message, task_id: {task_id}")
break
if tts_publisher:
@ -493,6 +495,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._save_message(graph_runtime_state=graph_runtime_state)
yield self._message_end_to_stream_response()
elif isinstance(event, QueueAgentLogEvent):
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
else:
continue

View File

@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
@ -106,6 +107,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
self._wip_workflow_agent_logs = {}
self.total_tokens: int = 0
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@ -216,7 +218,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
break
else:
yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
except Exception as e:
except Exception:
logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
break
if tts_publisher:
@ -387,6 +389,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)
elif isinstance(event, QueueAgentLogEvent):
yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
else:
continue

View File

@ -5,6 +5,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@ -23,6 +24,7 @@ from core.app.entities.queue_entities import (
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
AgentLogEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
@ -295,6 +297,17 @@ class WorkflowBasedAppRunner(AppRunner):
retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, AgentLogEvent):
self._publish_event(
QueueAgentLogEvent(
id=event.id,
node_execution_id=event.node_execution_id,
parent_id=event.parent_id,
error=event.error,
status=event.status,
data=event.data,
)
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event(
QueueParallelBranchRunStartedEvent(

View File

@ -1,6 +1,6 @@
from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Optional
from typing import Any, Mapping, Optional
from pydantic import BaseModel
@ -38,6 +38,7 @@ class QueueEvent(StrEnum):
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
AGENT_LOG = "agent_log"
ERROR = "error"
PING = "ping"
STOP = "stop"
@ -300,6 +301,20 @@ class QueueNodeSucceededEvent(AppQueueEvent):
iteration_duration_map: Optional[dict[str, float]] = None
class QueueAgentLogEvent(AppQueueEvent):
"""
QueueAgentLogEvent entity
"""
event: QueueEvent = QueueEvent.AGENT_LOG
id: str
node_execution_id: str
parent_id: str | None
error: str | None
status: str
data: Mapping[str, Any]
class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity

View File

@ -59,6 +59,7 @@ class StreamEvent(Enum):
ITERATION_COMPLETED = "iteration_completed"
TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace"
AGENT_LOG = "agent_log"
class StreamResponse(BaseModel):
@ -625,3 +626,24 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
workflow_run_id: str
data: Data
class AgentLogStreamResponse(StreamResponse):
"""
AgentLogStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
node_execution_id: str
id: str
parent_id: str | None
error: str | None
status: str
data: Mapping[str, Any]
event: StreamEvent = StreamEvent.AGENT_LOG
data: Data

View File

@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@ -21,6 +22,7 @@ from core.app.entities.queue_entities import (
QueueParallelBranchRunSucceededEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
@ -63,6 +65,7 @@ class WorkflowCycleManage:
_task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_wip_workflow_agent_logs: dict[str, list[AgentLogStreamResponse.Data]]
def _handle_workflow_run_start(self) -> WorkflowRun:
max_sequence = (
@ -283,9 +286,16 @@ class WorkflowCycleManage:
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
execution_metadata_dict = event.execution_metadata
if self._wip_workflow_agent_logs.get(event.node_execution_id):
if not execution_metadata_dict:
execution_metadata_dict = {}
execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get(
event.node_execution_id, []
)
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
@ -332,9 +342,16 @@ class WorkflowCycleManage:
outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
execution_metadata_dict = event.execution_metadata
if self._wip_workflow_agent_logs.get(event.node_execution_id):
if not execution_metadata_dict:
execution_metadata_dict = {}
execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get(
event.node_execution_id, []
)
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
@ -746,3 +763,52 @@ class WorkflowCycleManage:
raise Exception(f"Workflow node execution not found: {node_execution_id}")
return workflow_node_execution
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""
Handle agent log
:param task_id: task id
:param event: agent log event
:return:
"""
node_execution = self._wip_workflow_node_executions.get(event.node_execution_id)
if not node_execution:
raise Exception(f"Workflow node execution not found: {event.node_execution_id}")
node_execution_id = node_execution.id
original_agent_logs = self._wip_workflow_agent_logs.get(node_execution_id, [])
# try to find the log with the same id
for log in original_agent_logs:
if log.id == event.id:
# update the log
log.status = event.status
log.error = event.error
log.data = event.data
break
else:
# append the log
original_agent_logs.append(
AgentLogStreamResponse.Data(
id=event.id,
parent_id=event.parent_id,
node_execution_id=node_execution_id,
error=event.error,
status=event.status,
data=event.data,
)
)
self._wip_workflow_agent_logs[node_execution_id] = original_agent_logs
return AgentLogStreamResponse(
task_id=task_id,
data=AgentLogStreamResponse.Data(
node_execution_id=node_execution_id,
id=event.id,
parent_id=event.parent_id,
error=event.error,
status=event.status,
data=event.data,
),
)

View File

@ -1,7 +1,7 @@
import base64
import enum
from enum import Enum
from typing import Any, Optional, Union
from typing import Any, Mapping, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator
@ -150,6 +150,18 @@ class ToolInvokeMessage(BaseModel):
raise ValueError(f"The variable name '{value}' is reserved.")
return value
class LogMessage(BaseModel):
class LogStatus(Enum):
START = "start"
ERROR = "error"
SUCCESS = "success"
id: str
parent_id: Optional[str] = Field(default=None, description="Leave empty for root log")
error: Optional[str] = Field(default=None, description="The error message")
status: LogStatus = Field(..., description="The status of the log")
data: Mapping[str, Any] = Field(..., description="Detailed log data")
class MessageType(Enum):
TEXT = "text"
IMAGE = "image"
@ -160,12 +172,13 @@ class ToolInvokeMessage(BaseModel):
BINARY_LINK = "binary_link"
VARIABLE = "variable"
FILE = "file"
LOG = "log"
type: MessageType = MessageType.TEXT
"""
plain text, image url or link url
"""
message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | None
message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | LogMessage | None
meta: dict[str, Any] | None = None
@field_validator("message", mode="before")

View File

@ -17,6 +17,7 @@ class NodeRunMetadataKey(StrEnum):
TOTAL_PRICE = "total_price"
CURRENCY = "currency"
TOOL_INFO = "tool_info"
AGENT_LOG = "agent_log"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
PARALLEL_ID = "parallel_id"

View File

@ -170,4 +170,22 @@ class IterationRunFailedEvent(BaseIterationEvent):
error: str = Field(..., description="failed reason")
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent
###########################################
# Agent Events
###########################################
class BaseAgentEvent(GraphEngineEvent):
pass
class AgentLogEvent(BaseAgentEvent):
id: str = Field(..., description="id")
node_execution_id: str = Field(..., description="node execution id")
parent_id: str | None = Field(..., description="parent id")
error: str | None = Field(..., description="error")
status: str = Field(..., description="status")
data: Mapping[str, Any] = Field(..., description="data")
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent

View File

@ -152,7 +152,7 @@ class GraphEngine:
elif isinstance(item, NodeRunSucceededEvent):
if item.node_type == NodeType.END:
self.graph_runtime_state.outputs = (
item.route_node_state.node_run_result.outputs
dict(item.route_node_state.node_run_result.outputs)
if item.route_node_state.node_run_result
and item.route_node_state.node_run_result.outputs
else {}

View File

@ -16,6 +16,7 @@ from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
@ -55,6 +56,17 @@ class ToolNode(BaseNode[ToolNodeData]):
"plugin_unique_identifier": node_data.plugin_unique_identifier,
}
yield AgentLogEvent(
id=self.node_id,
node_execution_id=self.id,
parent_id=None,
error=None,
status="running",
data={
"tool_info": tool_info,
},
)
# get tool runtime
try:
tool_runtime = ToolManager.get_workflow_tool_runtime(