mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 07:29:02 +08:00

Feat: continue on error (#11458)

Co-authored-by: Novice Lee <novicelee@NovicedeMacBook-Pro.local>
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
parent bec5451f12
commit 79a710ce98
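The diff below threads a "continue on error" path through the workflow engine: a node may declare an error_strategy; a failing node with a strategy emits NodeRunExceptionEvent instead of aborting the run, and a run that finishes with handled exceptions ends as GraphRunPartialSucceededEvent / WorkflowRunStatus.PARTIAL_SUCCESSED carrying an exceptions_count. As a reading aid, here is a minimal sketch (not part of the commit) of how the two strategies shape a failed node's outputs. The helper name and the "default-value" literal are assumptions for illustration; the diff itself only shows the "fail-branch" literal and the ErrorStrategy enum members.

# Illustrative sketch only; it mirrors the output shape produced by
# GraphEngine._handle_continue_on_error in the diff below.
from typing import Any, Optional


def sketch_continue_on_error_outputs(
    error: str,
    error_type: Optional[str],
    error_strategy: Optional[str],       # assumed literals: "default-value" | "fail-branch"
    default_value_dict: dict[str, Any],  # the node's configured defaults (DEFAULT_VALUE strategy)
) -> Optional[dict[str, Any]]:
    if error_strategy == "default-value":
        # The node "succeeds" with its defaults; error details ride along as outputs.
        return {**default_value_dict, "error_message": error, "error_type": error_type}
    if error_strategy == "fail-branch":
        # The node is routed down its fail-branch edge; only error details are exposed.
        return {"error_message": error, "error_type": error_type}
    return None  # no strategy configured: the failure still fails the run


# Example: an HTTP node with a default body continues with its fallback output.
print(sketch_continue_on_error_outputs("timeout", "HTTPError", "default-value", {"body": ""}))

Either way the engine also records the error under [node_id].error_message / [node_id].error_type in the variable pool and appends it to handle_exceptions, whose length is what later turns the run into a partial success.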
@@ -19,6 +19,7 @@ from core.app.entities.queue_entities import (
     QueueIterationNextEvent,
     QueueIterationStartEvent,
     QueueMessageReplaceEvent,
+    QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
     QueueNodeStartedEvent,
@@ -31,6 +32,7 @@ from core.app.entities.queue_entities import (
     QueueStopEvent,
     QueueTextChunkEvent,
     QueueWorkflowFailedEvent,
+    QueueWorkflowPartialSuccessEvent,
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
 )
@@ -317,7 +319,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
                 if response:
                     yield response
-            elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
+            elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
                 workflow_node_execution = self._handle_workflow_node_execution_failed(event)
 
                 response = self._workflow_node_finish_to_stream_response(
@@ -384,6 +386,29 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                 )
 
+                self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
+            elif isinstance(event, QueueWorkflowPartialSuccessEvent):
+                if not workflow_run:
+                    raise Exception("Workflow run not initialized.")
+
+                if not graph_runtime_state:
+                    raise Exception("Graph runtime state not initialized.")
+
+                workflow_run = self._handle_workflow_run_partial_success(
+                    workflow_run=workflow_run,
+                    start_at=graph_runtime_state.start_at,
+                    total_tokens=graph_runtime_state.total_tokens,
+                    total_steps=graph_runtime_state.node_run_steps,
+                    outputs=event.outputs,
+                    exceptions_count=event.exceptions_count,
+                    conversation_id=None,
+                    trace_manager=trace_manager,
+                )
+
+                yield self._workflow_finish_to_stream_response(
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                )
+
                 self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
             elif isinstance(event, QueueWorkflowFailedEvent):
                 if not workflow_run:
@@ -401,6 +426,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                     error=event.error,
                     conversation_id=self._conversation.id,
                     trace_manager=trace_manager,
+                    exceptions_count=event.exceptions_count,
                 )
 
                 yield self._workflow_finish_to_stream_response(
@@ -6,6 +6,7 @@ from core.app.entities.queue_entities import (
     QueueMessageEndEvent,
     QueueStopEvent,
     QueueWorkflowFailedEvent,
+    QueueWorkflowPartialSuccessEvent,
     QueueWorkflowSucceededEvent,
     WorkflowQueueMessage,
 )
@@ -34,7 +35,8 @@ class WorkflowAppQueueManager(AppQueueManager):
             | QueueErrorEvent
             | QueueMessageEndEvent
             | QueueWorkflowSucceededEvent
-            | QueueWorkflowFailedEvent,
+            | QueueWorkflowFailedEvent
+            | QueueWorkflowPartialSuccessEvent,
         ):
             self.stop_listen()
 
@@ -15,6 +15,7 @@ from core.app.entities.queue_entities import (
     QueueIterationCompletedEvent,
     QueueIterationNextEvent,
     QueueIterationStartEvent,
+    QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
     QueueNodeStartedEvent,
@@ -26,6 +27,7 @@ from core.app.entities.queue_entities import (
     QueueStopEvent,
     QueueTextChunkEvent,
     QueueWorkflowFailedEvent,
+    QueueWorkflowPartialSuccessEvent,
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
 )
@@ -276,7 +278,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
 
                 if response:
                     yield response
-            elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
+            elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
                 workflow_node_execution = self._handle_workflow_node_execution_failed(event)
 
                 response = self._workflow_node_finish_to_stream_response(
@@ -345,22 +347,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 yield self._workflow_finish_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                 )
-            elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
+            elif isinstance(event, QueueWorkflowPartialSuccessEvent):
                 if not workflow_run:
                     raise Exception("Workflow run not initialized.")
 
                 if not graph_runtime_state:
                     raise Exception("Graph runtime state not initialized.")
 
-                workflow_run = self._handle_workflow_run_failed(
+                workflow_run = self._handle_workflow_run_partial_success(
                     workflow_run=workflow_run,
                     start_at=graph_runtime_state.start_at,
                     total_tokens=graph_runtime_state.total_tokens,
                     total_steps=graph_runtime_state.node_run_steps,
-                    status=WorkflowRunStatus.FAILED
-                    if isinstance(event, QueueWorkflowFailedEvent)
-                    else WorkflowRunStatus.STOPPED,
-                    error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
+                    outputs=event.outputs,
+                    exceptions_count=event.exceptions_count,
                     conversation_id=None,
                     trace_manager=trace_manager,
                 )
@@ -368,6 +368,60 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 # save workflow app log
                 self._save_workflow_app_log(workflow_run)
 
+                yield self._workflow_finish_to_stream_response(
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                )
+            elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
+                if not workflow_run:
+                    raise Exception("Workflow run not initialized.")
+
+                if not graph_runtime_state:
+                    raise Exception("Graph runtime state not initialized.")
+                handle_args = {
+                    "workflow_run": workflow_run,
+                    "start_at": graph_runtime_state.start_at,
+                    "total_tokens": graph_runtime_state.total_tokens,
+                    "total_steps": graph_runtime_state.node_run_steps,
+                    "status": WorkflowRunStatus.FAILED
+                    if isinstance(event, QueueWorkflowFailedEvent)
+                    else WorkflowRunStatus.STOPPED,
+                    "error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
+                    "conversation_id": None,
+                    "trace_manager": trace_manager,
+                    "exceptions_count": event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
+                }
+                workflow_run = self._handle_workflow_run_failed(**handle_args)
+
+                # save workflow app log
+                self._save_workflow_app_log(workflow_run)
+
+                yield self._workflow_finish_to_stream_response(
+                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
+                )
+            elif isinstance(event, QueueWorkflowPartialSuccessEvent):
+                if not workflow_run:
+                    raise Exception("Workflow run not initialized.")
+
+                if not graph_runtime_state:
+                    raise Exception("Graph runtime state not initialized.")
+                handle_args = {
+                    "workflow_run": workflow_run,
+                    "start_at": graph_runtime_state.start_at,
+                    "total_tokens": graph_runtime_state.total_tokens,
+                    "total_steps": graph_runtime_state.node_run_steps,
+                    "status": WorkflowRunStatus.FAILED
+                    if isinstance(event, QueueWorkflowFailedEvent)
+                    else WorkflowRunStatus.STOPPED,
+                    "error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
+                    "conversation_id": None,
+                    "trace_manager": trace_manager,
+                    "exceptions_count": event.exceptions_count,
+                }
+                workflow_run = self._handle_workflow_run_partial_success(**handle_args)
+
+                # save workflow app log
+                self._save_workflow_app_log(workflow_run)
+
                 yield self._workflow_finish_to_stream_response(
                     task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                 )
@@ -8,6 +8,7 @@ from core.app.entities.queue_entities import (
     QueueIterationCompletedEvent,
     QueueIterationNextEvent,
     QueueIterationStartEvent,
+    QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
     QueueNodeStartedEvent,
@@ -18,6 +19,7 @@ from core.app.entities.queue_entities import (
     QueueRetrieverResourcesEvent,
     QueueTextChunkEvent,
     QueueWorkflowFailedEvent,
+    QueueWorkflowPartialSuccessEvent,
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
 )
@@ -25,6 +27,7 @@ from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine.entities.event import (
     GraphEngineEvent,
     GraphRunFailedEvent,
+    GraphRunPartialSucceededEvent,
     GraphRunStartedEvent,
     GraphRunSucceededEvent,
     IterationRunFailedEvent,
@@ -32,6 +35,7 @@ from core.workflow.graph_engine.entities.event import (
     IterationRunStartedEvent,
     IterationRunSucceededEvent,
     NodeInIterationFailedEvent,
+    NodeRunExceptionEvent,
     NodeRunFailedEvent,
     NodeRunRetrieverResourceEvent,
     NodeRunStartedEvent,
@@ -176,8 +180,12 @@ class WorkflowBasedAppRunner(AppRunner):
             )
         elif isinstance(event, GraphRunSucceededEvent):
             self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
+        elif isinstance(event, GraphRunPartialSucceededEvent):
+            self._publish_event(
+                QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
+            )
         elif isinstance(event, GraphRunFailedEvent):
-            self._publish_event(QueueWorkflowFailedEvent(error=event.error))
+            self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
         elif isinstance(event, NodeRunStartedEvent):
             self._publish_event(
                 QueueNodeStartedEvent(
@@ -253,6 +261,36 @@ class WorkflowBasedAppRunner(AppRunner):
                     in_iteration_id=event.in_iteration_id,
                 )
             )
+        elif isinstance(event, NodeRunExceptionEvent):
+            self._publish_event(
+                QueueNodeExceptionEvent(
+                    node_execution_id=event.id,
+                    node_id=event.node_id,
+                    node_type=event.node_type,
+                    node_data=event.node_data,
+                    parallel_id=event.parallel_id,
+                    parallel_start_node_id=event.parallel_start_node_id,
+                    parent_parallel_id=event.parent_parallel_id,
+                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                    start_at=event.route_node_state.start_at,
+                    inputs=event.route_node_state.node_run_result.inputs
+                    if event.route_node_state.node_run_result
+                    else {},
+                    process_data=event.route_node_state.node_run_result.process_data
+                    if event.route_node_state.node_run_result
+                    else {},
+                    outputs=event.route_node_state.node_run_result.outputs
+                    if event.route_node_state.node_run_result
+                    else {},
+                    error=event.route_node_state.node_run_result.error
+                    if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
+                    else "Unknown error",
+                    execution_metadata=event.route_node_state.node_run_result.metadata
+                    if event.route_node_state.node_run_result
+                    else {},
+                    in_iteration_id=event.in_iteration_id,
+                )
+            )
         elif isinstance(event, NodeInIterationFailedEvent):
             self._publish_event(
                 QueueNodeInIterationFailedEvent(
@@ -25,12 +25,14 @@ class QueueEvent(StrEnum):
     WORKFLOW_STARTED = "workflow_started"
     WORKFLOW_SUCCEEDED = "workflow_succeeded"
     WORKFLOW_FAILED = "workflow_failed"
+    WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded"
     ITERATION_START = "iteration_start"
     ITERATION_NEXT = "iteration_next"
     ITERATION_COMPLETED = "iteration_completed"
     NODE_STARTED = "node_started"
     NODE_SUCCEEDED = "node_succeeded"
     NODE_FAILED = "node_failed"
+    NODE_EXCEPTION = "node_exception"
     RETRIEVER_RESOURCES = "retriever_resources"
     ANNOTATION_REPLY = "annotation_reply"
     AGENT_THOUGHT = "agent_thought"
@@ -237,6 +239,17 @@ class QueueWorkflowFailedEvent(AppQueueEvent):
 
     event: QueueEvent = QueueEvent.WORKFLOW_FAILED
     error: str
+    exceptions_count: int
+
+
+class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
+    """
+    QueueWorkflowFailedEvent entity
+    """
+
+    event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
+    exceptions_count: int
+    outputs: Optional[dict[str, Any]] = None
 
 
 class QueueNodeStartedEvent(AppQueueEvent):
@@ -331,6 +344,37 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
     error: str
 
 
+class QueueNodeExceptionEvent(AppQueueEvent):
+    """
+    QueueNodeExceptionEvent entity
+    """
+
+    event: QueueEvent = QueueEvent.NODE_EXCEPTION
+
+    node_execution_id: str
+    node_id: str
+    node_type: NodeType
+    node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+    start_at: datetime
+
+    inputs: Optional[dict[str, Any]] = None
+    process_data: Optional[dict[str, Any]] = None
+    outputs: Optional[dict[str, Any]] = None
+    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+
+    error: str
+
+
 class QueueNodeFailedEvent(AppQueueEvent):
     """
     QueueNodeFailedEvent entity
@@ -213,6 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
         created_by: Optional[dict] = None
         created_at: int
         finished_at: int
+        exceptions_count: Optional[int] = 0
         files: Optional[Sequence[Mapping[str, Any]]] = []
 
     event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
@@ -12,6 +12,7 @@ from core.app.entities.queue_entities import (
     QueueIterationCompletedEvent,
     QueueIterationNextEvent,
     QueueIterationStartEvent,
+    QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
     QueueNodeStartedEvent,
@@ -164,6 +165,55 @@ class WorkflowCycleManage:
 
         return workflow_run
 
+    def _handle_workflow_run_partial_success(
+        self,
+        workflow_run: WorkflowRun,
+        start_at: float,
+        total_tokens: int,
+        total_steps: int,
+        outputs: Mapping[str, Any] | None = None,
+        exceptions_count: int = 0,
+        conversation_id: Optional[str] = None,
+        trace_manager: Optional[TraceQueueManager] = None,
+    ) -> WorkflowRun:
+        """
+        Workflow run success
+        :param workflow_run: workflow run
+        :param start_at: start time
+        :param total_tokens: total tokens
+        :param total_steps: total steps
+        :param outputs: outputs
+        :param conversation_id: conversation id
+        :return:
+        """
+        workflow_run = self._refetch_workflow_run(workflow_run.id)
+
+        outputs = WorkflowEntry.handle_special_values(outputs)
+
+        workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
+        workflow_run.outputs = json.dumps(outputs or {})
+        workflow_run.elapsed_time = time.perf_counter() - start_at
+        workflow_run.total_tokens = total_tokens
+        workflow_run.total_steps = total_steps
+        workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
+        workflow_run.exceptions_count = exceptions_count
+        db.session.commit()
+        db.session.refresh(workflow_run)
+
+        if trace_manager:
+            trace_manager.add_trace_task(
+                TraceTask(
+                    TraceTaskName.WORKFLOW_TRACE,
+                    workflow_run=workflow_run,
+                    conversation_id=conversation_id,
+                    user_id=trace_manager.user_id,
+                )
+            )
+
+        db.session.close()
+
+        return workflow_run
+
     def _handle_workflow_run_failed(
         self,
         workflow_run: WorkflowRun,
@@ -174,6 +224,7 @@ class WorkflowCycleManage:
         error: str,
         conversation_id: Optional[str] = None,
         trace_manager: Optional[TraceQueueManager] = None,
+        exceptions_count: int = 0,
     ) -> WorkflowRun:
         """
         Workflow run failed
@@ -193,7 +244,7 @@ class WorkflowCycleManage:
         workflow_run.total_tokens = total_tokens
         workflow_run.total_steps = total_steps
         workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
+        workflow_run.exceptions_count = exceptions_count
         db.session.commit()
 
         running_workflow_node_executions = (
@@ -318,7 +369,7 @@ class WorkflowCycleManage:
         return workflow_node_execution
 
     def _handle_workflow_node_execution_failed(
-        self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
+        self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent
    ) -> WorkflowNodeExecution:
         """
         Workflow node execution failed
@@ -337,7 +388,11 @@ class WorkflowCycleManage:
         )
         db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
             {
-                WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
+                WorkflowNodeExecution.status: (
+                    WorkflowNodeExecutionStatus.FAILED.value
+                    if not isinstance(event, QueueNodeExceptionEvent)
+                    else WorkflowNodeExecutionStatus.EXCEPTION.value
+                ),
                 WorkflowNodeExecution.error: event.error,
                 WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
                 WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
@@ -351,8 +406,11 @@ class WorkflowCycleManage:
         db.session.commit()
         db.session.close()
         process_data = WorkflowEntry.handle_special_values(event.process_data)
-        workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
+        workflow_node_execution.status = (
+            WorkflowNodeExecutionStatus.FAILED.value
+            if not isinstance(event, QueueNodeExceptionEvent)
+            else WorkflowNodeExecutionStatus.EXCEPTION.value
+        )
         workflow_node_execution.error = event.error
         workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
         workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
@@ -433,6 +491,7 @@ class WorkflowCycleManage:
                 created_at=int(workflow_run.created_at.timestamp()),
                 finished_at=int(workflow_run.finished_at.timestamp()),
                 files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
+                exceptions_count=workflow_run.exceptions_count,
             ),
         )
 
@@ -483,7 +542,10 @@ class WorkflowCycleManage:
 
     def _workflow_node_finish_to_stream_response(
         self,
-        event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
+        event: QueueNodeSucceededEvent
+        | QueueNodeFailedEvent
+        | QueueNodeInIterationFailedEvent
+        | QueueNodeExceptionEvent,
         task_id: str,
         workflow_node_execution: WorkflowNodeExecution,
     ) -> Optional[NodeFinishStreamResponse]:
@@ -24,6 +24,12 @@ BACKOFF_FACTOR = 0.5
 STATUS_FORCELIST = [429, 500, 502, 503, 504]
 
 
+class MaxRetriesExceededError(Exception):
+    """Raised when the maximum number of retries is exceeded."""
+
+    pass
+
+
 def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
     if "allow_redirects" in kwargs:
         allow_redirects = kwargs.pop("allow_redirects")
@@ -64,7 +70,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
         if retries <= max_retries:
             time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
 
-    raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}")
+    raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
 
 
 def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
@@ -4,6 +4,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
 from core.workflow.graph_engine.entities.event import (
     GraphEngineEvent,
     GraphRunFailedEvent,
+    GraphRunPartialSucceededEvent,
     GraphRunStartedEvent,
     GraphRunSucceededEvent,
     IterationRunFailedEvent,
@@ -39,6 +40,8 @@ class WorkflowLoggingCallback(WorkflowCallback):
             self.print_text("\n[GraphRunStartedEvent]", color="pink")
         elif isinstance(event, GraphRunSucceededEvent):
             self.print_text("\n[GraphRunSucceededEvent]", color="green")
+        elif isinstance(event, GraphRunPartialSucceededEvent):
+            self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink")
         elif isinstance(event, GraphRunFailedEvent):
             self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
         elif isinstance(event, NodeRunStartedEvent):
@@ -25,6 +25,7 @@ class NodeRunMetadataKey(StrEnum):
     PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
     PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
     ITERATION_DURATION_MAP = "iteration_duration_map"  # single iteration duration if iteration node runs
+    ERROR_STRATEGY = "error_strategy"  # node in continue on error mode return the field
 
 
 class NodeRunResult(BaseModel):
@@ -43,3 +44,4 @@ class NodeRunResult(BaseModel):
     edge_source_handle: Optional[str] = None  # source handle id of node with multiple branches
 
     error: Optional[str] = None  # error message if status is failed
+    error_type: Optional[str] = None  # error type if status is failed
@@ -33,6 +33,12 @@ class GraphRunSucceededEvent(BaseGraphEvent):
 
 class GraphRunFailedEvent(BaseGraphEvent):
     error: str = Field(..., description="failed reason")
+    exceptions_count: Optional[int] = Field(description="exception count", default=0)
+
+
+class GraphRunPartialSucceededEvent(BaseGraphEvent):
+    exceptions_count: int = Field(..., description="exception count")
+    outputs: Optional[dict[str, Any]] = None
 
 
 ###########################################
@@ -83,6 +89,10 @@ class NodeRunFailedEvent(BaseNodeEvent):
     error: str = Field(..., description="error")
 
 
+class NodeRunExceptionEvent(BaseNodeEvent):
+    error: str = Field(..., description="error")
+
+
 class NodeInIterationFailedEvent(BaseNodeEvent):
     error: str = Field(..., description="error")
 
@@ -64,13 +64,21 @@ class Graph(BaseModel):
         edge_configs = graph_config.get("edges")
         if edge_configs is None:
             edge_configs = []
+        # node configs
+        node_configs = graph_config.get("nodes")
+        if not node_configs:
+            raise ValueError("Graph must have at least one node")
 
         edge_configs = cast(list, edge_configs)
+        node_configs = cast(list, node_configs)
 
         # reorganize edges mapping
         edge_mapping: dict[str, list[GraphEdge]] = {}
         reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
         target_edge_ids = set()
+        fail_branch_source_node_id = [
+            node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch"
+        ]
         for edge_config in edge_configs:
             source_node_id = edge_config.get("source")
             if not source_node_id:
@@ -90,8 +98,16 @@ class Graph(BaseModel):
 
             # parse run condition
             run_condition = None
-            if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source":
-                run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle"))
+            if edge_config.get("sourceHandle"):
+                if (
+                    edge_config.get("source") in fail_branch_source_node_id
+                    and edge_config.get("sourceHandle") != "fail-branch"
+                ):
+                    run_condition = RunCondition(type="branch_identify", branch_identify="success-branch")
+                elif edge_config.get("sourceHandle") != "source":
+                    run_condition = RunCondition(
+                        type="branch_identify", branch_identify=edge_config.get("sourceHandle")
+                    )
 
             graph_edge = GraphEdge(
                 source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
@@ -100,13 +116,6 @@ class Graph(BaseModel):
             edge_mapping[source_node_id].append(graph_edge)
             reverse_edge_mapping[target_node_id].append(graph_edge)
 
-        # node configs
-        node_configs = graph_config.get("nodes")
-        if not node_configs:
-            raise ValueError("Graph must have at least one node")
-
-        node_configs = cast(list, node_configs)
-
         # fetch nodes that have no predecessor node
         root_node_configs = []
         all_node_id_config_mapping: dict[str, dict] = {}
@@ -15,6 +15,7 @@ class RouteNodeState(BaseModel):
         SUCCESS = "success"
         FAILED = "failed"
         PAUSED = "paused"
+        EXCEPTION = "exception"
 
     id: str = Field(default_factory=lambda: str(uuid.uuid4()))
     """node state id"""
@@ -51,7 +52,11 @@ class RouteNodeState(BaseModel):
 
         :param run_result: run result
         """
-        if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}:
+        if self.status in {
+            RouteNodeState.Status.SUCCESS,
+            RouteNodeState.Status.FAILED,
+            RouteNodeState.Status.EXCEPTION,
+        }:
             raise Exception(f"Route state {self.id} already finished")
 
         if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
@@ -59,6 +64,9 @@ class RouteNodeState(BaseModel):
         elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
             self.status = RouteNodeState.Status.FAILED
             self.failed_reason = run_result.error
+        elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
+            self.status = RouteNodeState.Status.EXCEPTION
+            self.failed_reason = run_result.error
         else:
             raise Exception(f"Invalid route status {run_result.status}")
 
@ -5,21 +5,23 @@ import uuid
|
|||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from concurrent.futures import ThreadPoolExecutor, wait
|
from concurrent.futures import ThreadPoolExecutor, wait
|
||||||
from copy import copy, deepcopy
|
from copy import copy, deepcopy
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
|
|
||||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
BaseIterationEvent,
|
BaseIterationEvent,
|
||||||
GraphEngineEvent,
|
GraphEngineEvent,
|
||||||
GraphRunFailedEvent,
|
GraphRunFailedEvent,
|
||||||
|
GraphRunPartialSucceededEvent,
|
||||||
GraphRunStartedEvent,
|
GraphRunStartedEvent,
|
||||||
GraphRunSucceededEvent,
|
GraphRunSucceededEvent,
|
||||||
|
NodeRunExceptionEvent,
|
||||||
NodeRunFailedEvent,
|
NodeRunFailedEvent,
|
||||||
NodeRunRetrieverResourceEvent,
|
NodeRunRetrieverResourceEvent,
|
||||||
NodeRunStartedEvent,
|
NodeRunStartedEvent,
|
||||||
@ -36,7 +38,9 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
|
|||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||||
from core.workflow.nodes.base import BaseNode
|
from core.workflow.nodes.base import BaseNode
|
||||||
|
from core.workflow.nodes.base.entities import BaseNodeData
|
||||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||||
|
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
|
||||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -128,6 +132,7 @@ class GraphEngine:
|
|||||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||||
# trigger graph run start event
|
# trigger graph run start event
|
||||||
yield GraphRunStartedEvent()
|
yield GraphRunStartedEvent()
|
||||||
|
handle_exceptions = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self.init_params.workflow_type == WorkflowType.CHAT:
|
if self.init_params.workflow_type == WorkflowType.CHAT:
|
||||||
@ -140,13 +145,17 @@ class GraphEngine:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# run graph
|
# run graph
|
||||||
generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id))
|
generator = stream_processor.process(
|
||||||
|
self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions)
|
||||||
|
)
|
||||||
for item in generator:
|
for item in generator:
|
||||||
try:
|
try:
|
||||||
yield item
|
yield item
|
||||||
if isinstance(item, NodeRunFailedEvent):
|
if isinstance(item, NodeRunFailedEvent):
|
||||||
yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.")
|
yield GraphRunFailedEvent(
|
||||||
|
error=item.route_node_state.failed_reason or "Unknown error.",
|
||||||
|
exceptions_count=len(handle_exceptions),
|
||||||
|
)
|
||||||
return
|
return
|
||||||
elif isinstance(item, NodeRunSucceededEvent):
|
elif isinstance(item, NodeRunSucceededEvent):
|
||||||
if item.node_type == NodeType.END:
|
if item.node_type == NodeType.END:
|
||||||
@ -172,19 +181,24 @@ class GraphEngine:
|
|||||||
].strip()
|
].strip()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Graph run failed")
|
logger.exception("Graph run failed")
|
||||||
yield GraphRunFailedEvent(error=str(e))
|
yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
|
||||||
return
|
return
|
||||||
|
# count exceptions to determine partial success
|
||||||
|
if len(handle_exceptions) > 0:
|
||||||
|
yield GraphRunPartialSucceededEvent(
|
||||||
|
exceptions_count=len(handle_exceptions), outputs=self.graph_runtime_state.outputs
|
||||||
|
)
|
||||||
|
else:
|
||||||
# trigger graph run success event
|
# trigger graph run success event
|
||||||
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
|
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
|
||||||
self._release_thread()
|
self._release_thread()
|
||||||
except GraphRunFailedError as e:
|
except GraphRunFailedError as e:
|
||||||
yield GraphRunFailedEvent(error=e.error)
|
yield GraphRunFailedEvent(error=e.error, exceptions_count=len(handle_exceptions))
|
||||||
self._release_thread()
|
self._release_thread()
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Unknown Error when graph running")
|
logger.exception("Unknown Error when graph running")
|
||||||
yield GraphRunFailedEvent(error=str(e))
|
yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
|
||||||
self._release_thread()
|
self._release_thread()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@ -198,6 +212,7 @@ class GraphEngine:
|
|||||||
in_parallel_id: Optional[str] = None,
|
in_parallel_id: Optional[str] = None,
|
||||||
parent_parallel_id: Optional[str] = None,
|
parent_parallel_id: Optional[str] = None,
|
||||||
parent_parallel_start_node_id: Optional[str] = None,
|
parent_parallel_start_node_id: Optional[str] = None,
|
||||||
|
handle_exceptions: list[str] = [],
|
||||||
) -> Generator[GraphEngineEvent, None, None]:
|
) -> Generator[GraphEngineEvent, None, None]:
|
||||||
parallel_start_node_id = None
|
parallel_start_node_id = None
|
||||||
if in_parallel_id:
|
if in_parallel_id:
|
||||||
@ -242,7 +257,7 @@ class GraphEngine:
|
|||||||
previous_node_id=previous_node_id,
|
previous_node_id=previous_node_id,
|
||||||
thread_pool_id=self.thread_pool_id,
|
thread_pool_id=self.thread_pool_id,
|
||||||
)
|
)
|
||||||
|
node_instance = cast(BaseNode[BaseNodeData], node_instance)
|
||||||
try:
|
try:
|
||||||
# run node
|
# run node
|
||||||
generator = self._run_node(
|
generator = self._run_node(
|
||||||
@ -252,6 +267,7 @@ class GraphEngine:
|
|||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
parent_parallel_id=parent_parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
|
handle_exceptions=handle_exceptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
for item in generator:
|
for item in generator:
|
||||||
@ -301,7 +317,12 @@ class GraphEngine:
|
|||||||
|
|
||||||
if len(edge_mappings) == 1:
|
if len(edge_mappings) == 1:
|
||||||
edge = edge_mappings[0]
|
edge = edge_mappings[0]
|
||||||
|
if (
|
||||||
|
previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
|
||||||
|
and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
|
||||||
|
and edge.run_condition is None
|
||||||
|
):
|
||||||
|
break
|
||||||
if edge.run_condition:
|
if edge.run_condition:
|
||||||
result = ConditionManager.get_condition_handler(
|
result = ConditionManager.get_condition_handler(
|
||||||
init_params=self.init_params,
|
init_params=self.init_params,
|
||||||
@ -334,7 +355,7 @@ class GraphEngine:
|
|||||||
if len(sub_edge_mappings) == 0:
|
if len(sub_edge_mappings) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
edge = sub_edge_mappings[0]
|
edge = cast(GraphEdge, sub_edge_mappings[0])
|
||||||
|
|
||||||
result = ConditionManager.get_condition_handler(
|
result = ConditionManager.get_condition_handler(
|
||||||
init_params=self.init_params,
|
init_params=self.init_params,
|
||||||
@ -355,6 +376,7 @@ class GraphEngine:
|
|||||||
edge_mappings=sub_edge_mappings,
|
edge_mappings=sub_edge_mappings,
|
||||||
in_parallel_id=in_parallel_id,
|
in_parallel_id=in_parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
|
handle_exceptions=handle_exceptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
for item in parallel_generator:
|
for item in parallel_generator:
|
||||||
@ -369,11 +391,18 @@ class GraphEngine:
|
|||||||
break
|
break
|
||||||
|
|
||||||
next_node_id = final_node_id
|
next_node_id = final_node_id
|
||||||
|
elif (
|
||||||
|
node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
|
||||||
|
and node_instance.should_continue_on_error
|
||||||
|
and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
|
||||||
|
):
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
parallel_generator = self._run_parallel_branches(
|
parallel_generator = self._run_parallel_branches(
|
||||||
edge_mappings=edge_mappings,
|
edge_mappings=edge_mappings,
|
||||||
in_parallel_id=in_parallel_id,
|
in_parallel_id=in_parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
|
handle_exceptions=handle_exceptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
for item in parallel_generator:
|
for item in parallel_generator:
|
||||||
@ -395,6 +424,7 @@ class GraphEngine:
|
|||||||
edge_mappings: list[GraphEdge],
|
edge_mappings: list[GraphEdge],
|
||||||
in_parallel_id: Optional[str] = None,
|
in_parallel_id: Optional[str] = None,
|
||||||
parallel_start_node_id: Optional[str] = None,
|
parallel_start_node_id: Optional[str] = None,
|
||||||
|
handle_exceptions: list[str] = [],
|
||||||
) -> Generator[GraphEngineEvent | str, None, None]:
|
) -> Generator[GraphEngineEvent | str, None, None]:
|
||||||
# if nodes has no run conditions, parallel run all nodes
|
# if nodes has no run conditions, parallel run all nodes
|
||||||
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
|
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
|
||||||
@ -438,6 +468,7 @@ class GraphEngine:
|
|||||||
"parallel_start_node_id": edge.target_node_id,
|
"parallel_start_node_id": edge.target_node_id,
|
||||||
"parent_parallel_id": in_parallel_id,
|
"parent_parallel_id": in_parallel_id,
|
||||||
"parent_parallel_start_node_id": parallel_start_node_id,
|
"parent_parallel_start_node_id": parallel_start_node_id,
|
||||||
|
"handle_exceptions": handle_exceptions,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -481,6 +512,7 @@ class GraphEngine:
|
|||||||
parallel_start_node_id: str,
|
parallel_start_node_id: str,
|
||||||
parent_parallel_id: Optional[str] = None,
|
parent_parallel_id: Optional[str] = None,
|
||||||
parent_parallel_start_node_id: Optional[str] = None,
|
parent_parallel_start_node_id: Optional[str] = None,
|
||||||
|
handle_exceptions: list[str] = [],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Run parallel nodes
|
Run parallel nodes
|
||||||
@ -502,6 +534,7 @@ class GraphEngine:
|
|||||||
in_parallel_id=parallel_id,
|
in_parallel_id=parallel_id,
|
||||||
parent_parallel_id=parent_parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
|
handle_exceptions=handle_exceptions,
|
||||||
)
|
)
|
||||||
|
|
||||||
for item in generator:
|
for item in generator:
|
||||||
@ -548,6 +581,7 @@ class GraphEngine:
|
|||||||
parallel_start_node_id: Optional[str] = None,
|
parallel_start_node_id: Optional[str] = None,
|
||||||
parent_parallel_id: Optional[str] = None,
|
parent_parallel_id: Optional[str] = None,
|
||||||
parent_parallel_start_node_id: Optional[str] = None,
|
parent_parallel_start_node_id: Optional[str] = None,
|
||||||
|
handle_exceptions: list[str] = [],
|
||||||
) -> Generator[GraphEngineEvent, None, None]:
|
) -> Generator[GraphEngineEvent, None, None]:
|
||||||
"""
|
"""
|
||||||
Run node
|
Run node
|
||||||
@ -587,6 +621,37 @@ class GraphEngine:
|
|||||||
route_node_state.set_finished(run_result=run_result)
|
route_node_state.set_finished(run_result=run_result)
|
||||||
|
|
||||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||||
|
if node_instance.should_continue_on_error:
|
||||||
|
# if run failed, handle error
|
||||||
|
run_result = self._handle_continue_on_error(
|
||||||
|
node_instance,
|
||||||
|
item.run_result,
|
||||||
|
self.graph_runtime_state.variable_pool,
|
||||||
|
handle_exceptions=handle_exceptions,
|
||||||
|
)
|
||||||
|
route_node_state.node_run_result = run_result
|
||||||
|
route_node_state.status = RouteNodeState.Status.EXCEPTION
|
||||||
|
if run_result.outputs:
|
||||||
|
for variable_key, variable_value in run_result.outputs.items():
|
||||||
|
# append variables to variable pool recursively
|
||||||
|
self._append_variables_recursively(
|
||||||
|
node_id=node_instance.node_id,
|
||||||
|
variable_key_list=[variable_key],
|
||||||
|
variable_value=variable_value,
|
||||||
|
)
|
||||||
|
yield NodeRunExceptionEvent(
|
||||||
|
error=run_result.error or "System Error",
|
||||||
|
id=node_instance.id,
|
||||||
|
node_id=node_instance.node_id,
|
||||||
|
node_type=node_instance.node_type,
|
||||||
|
node_data=node_instance.node_data,
|
||||||
|
route_node_state=route_node_state,
|
||||||
|
parallel_id=parallel_id,
|
||||||
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
|
parent_parallel_id=parent_parallel_id,
|
||||||
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
yield NodeRunFailedEvent(
|
yield NodeRunFailedEvent(
|
||||||
error=route_node_state.failed_reason or "Unknown error.",
|
error=route_node_state.failed_reason or "Unknown error.",
|
||||||
id=node_instance.id,
|
id=node_instance.id,
|
||||||
@ -599,7 +664,12 @@ class GraphEngine:
|
|||||||
parent_parallel_id=parent_parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||||
|
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
|
||||||
|
node_instance.node_id
|
||||||
|
):
|
||||||
|
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
||||||
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||||
# plus state total_tokens
|
# plus state total_tokens
|
||||||
self.graph_runtime_state.total_tokens += int(
|
self.graph_runtime_state.total_tokens += int(
|
||||||
@@ -735,6 +805,56 @@ class GraphEngine:
         new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
         return new_instance

+    def _handle_continue_on_error(
+        self,
+        node_instance: BaseNode[BaseNodeData],
+        error_result: NodeRunResult,
+        variable_pool: VariablePool,
+        handle_exceptions: list[str] = [],
+    ) -> NodeRunResult:
+        """
+        Handle continue-on-error when self._should_continue_on_error is True.
+
+        :param error_result (NodeRunResult): error run result
+        :param variable_pool (VariablePool): variable pool
+        :return: exception run result
+        """
+        # add error message and error type to variable pool
+        variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
+        variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
+        # add error message to handle_exceptions
+        handle_exceptions.append(error_result.error)
+        node_error_args = {
+            "status": WorkflowNodeExecutionStatus.EXCEPTION,
+            "error": error_result.error,
+            "inputs": error_result.inputs,
+            "metadata": {
+                NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
+            },
+        }
+
+        if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
+            return NodeRunResult(
+                **node_error_args,
+                outputs={
+                    **node_instance.node_data.default_value_dict,
+                    "error_message": error_result.error,
+                    "error_type": error_result.error_type,
+                },
+            )
+        elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH:
+            if self.graph.edge_mapping.get(node_instance.node_id):
+                node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED
+            return NodeRunResult(
+                **node_error_args,
+                outputs={
+                    "error_message": error_result.error,
+                    "error_type": error_result.error_type,
+                },
+            )
+        return error_result


 class GraphRunFailedError(Exception):
     def __init__(self, error: str):
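Note: the two strategies above yield differently shaped results. A minimal standalone sketch of the same dispatch (simplified names, not the Dify API; assumes Python 3.11+ for StrEnum):

from enum import StrEnum

class ErrorStrategy(StrEnum):
    FAIL_BRANCH = "fail-branch"
    DEFAULT_VALUE = "default-value"

def build_exception_outputs(
    strategy: ErrorStrategy, error: str, error_type: str, defaults: dict | None = None
) -> dict:
    # DEFAULT_VALUE merges the node's configured defaults with the error details;
    # FAIL_BRANCH carries only the error details and reroutes via the fail-branch edge.
    outputs = {"error_message": error, "error_type": error_type}
    if strategy is ErrorStrategy.DEFAULT_VALUE:
        outputs = {**(defaults or {}), **outputs}
    return outputs

assert build_exception_outputs(
    ErrorStrategy.DEFAULT_VALUE, "boom", "ValueError", {"result": 0}
)["result"] == 0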
@@ -6,7 +6,7 @@ from core.workflow.nodes.answer.entities import (
     TextGenerateRouteChunk,
     VarGenerateRouteChunk,
 )
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
 from core.workflow.utils.variable_template_parser import VariableTemplateParser


@@ -148,13 +148,18 @@ class AnswerStreamGeneratorRouter:
             for edge in reverse_edges:
                 source_node_id = edge.source_node_id
                 source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
-                if source_node_type in {
+                source_node_data = node_id_config_mapping[source_node_id].get("data", {})
+                if (
+                    source_node_type
+                    in {
                         NodeType.ANSWER,
                         NodeType.IF_ELSE,
                         NodeType.QUESTION_CLASSIFIER,
                         NodeType.ITERATION,
                         NodeType.VARIABLE_ASSIGNER,
-                }:
+                    }
+                    or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH
+                ):
                     answer_dependencies[answer_node_id].append(source_node_id)
                 else:
                     cls._recursive_fetch_answer_dependencies(
@@ -6,6 +6,7 @@ from core.file import FILE_MODEL_IDENTITY, File
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine.entities.event import (
     GraphEngineEvent,
+    NodeRunExceptionEvent,
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
@@ -50,7 +51,7 @@ class AnswerStreamProcessor(StreamProcessor):

                 for _ in stream_out_answer_node_ids:
                     yield event
-        elif isinstance(event, NodeRunSucceededEvent):
+        elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
             yield event
             if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
                 # update self.route_position after all stream event finished
@@ -1,14 +1,124 @@
+import json
 from abc import ABC
-from typing import Optional
+from enum import StrEnum
+from typing import Any, Optional, Union

-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
+
+from core.workflow.nodes.base.exc import DefaultValueTypeError
+from core.workflow.nodes.enums import ErrorStrategy
+
+
+class DefaultValueType(StrEnum):
+    STRING = "string"
+    NUMBER = "number"
+    OBJECT = "object"
+    ARRAY_NUMBER = "array[number]"
+    ARRAY_STRING = "array[string]"
+    ARRAY_OBJECT = "array[object]"
+    ARRAY_FILES = "array[file]"
+
+
+NumberType = Union[int, float]
+
+
+class DefaultValue(BaseModel):
+    value: Any
+    type: DefaultValueType
+    key: str
+
+    @staticmethod
+    def _parse_json(value: str) -> Any:
+        """Unified JSON parsing handler"""
+        try:
+            return json.loads(value)
+        except json.JSONDecodeError:
+            raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
+
+    @staticmethod
+    def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
+        """Unified array type validation"""
+        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
+
+    @staticmethod
+    def _convert_number(value: str) -> float:
+        """Unified number conversion handler"""
+        try:
+            return float(value)
+        except ValueError:
+            raise DefaultValueTypeError(f"Cannot convert to number: {value}")
+
+    @model_validator(mode="after")
+    def validate_value_type(self) -> "DefaultValue":
+        if self.type is None:
+            raise DefaultValueTypeError("type field is required")
+
+        # Type validation configuration
+        type_validators = {
+            DefaultValueType.STRING: {
+                "type": str,
+                "converter": lambda x: x,
+            },
+            DefaultValueType.NUMBER: {
+                "type": NumberType,
+                "converter": self._convert_number,
+            },
+            DefaultValueType.OBJECT: {
+                "type": dict,
+                "converter": self._parse_json,
+            },
+            DefaultValueType.ARRAY_NUMBER: {
+                "type": list,
+                "element_type": NumberType,
+                "converter": self._parse_json,
+            },
+            DefaultValueType.ARRAY_STRING: {
+                "type": list,
+                "element_type": str,
+                "converter": self._parse_json,
+            },
+            DefaultValueType.ARRAY_OBJECT: {
+                "type": list,
+                "element_type": dict,
+                "converter": self._parse_json,
+            },
+        }
+
+        validator = type_validators.get(self.type)
+        if not validator:
+            if self.type == DefaultValueType.ARRAY_FILES:
+                # Handle files type
+                return self
+            raise DefaultValueTypeError(f"Unsupported type: {self.type}")
+
+        # Handle string input cases
+        if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
+            self.value = validator["converter"](self.value)
+
+        # Validate base type
+        if not isinstance(self.value, validator["type"]):
+            raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
+
+        # Validate array element types
+        if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
+            raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
+
+        return self
+
+
 class BaseNodeData(ABC, BaseModel):
     title: str
     desc: Optional[str] = None
+    error_strategy: Optional[ErrorStrategy] = None
+    default_value: Optional[list[DefaultValue]] = None
     version: str = "1"

+    @property
+    def default_value_dict(self):
+        if self.default_value:
+            return {item.key: item.value for item in self.default_value}
+        return {}
+

 class BaseIterationNodeData(BaseNodeData):
     start_node_id: Optional[str] = None
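Note: a minimal standalone sketch of the coercion idea used by DefaultValue above (assumes pydantic v2 and Python 3.11+; names are illustrative, not the committed class):

import json
from enum import StrEnum
from typing import Any

from pydantic import BaseModel, model_validator


class ValueType(StrEnum):
    STRING = "string"
    NUMBER = "number"
    OBJECT = "object"


class Default(BaseModel):
    key: str
    type: ValueType
    value: Any

    @model_validator(mode="after")
    def coerce(self) -> "Default":
        # String inputs for non-string types are converted, mirroring the converter table above.
        if isinstance(self.value, str) and self.type is not ValueType.STRING:
            self.value = float(self.value) if self.type is ValueType.NUMBER else json.loads(self.value)
        return self


assert Default(key="result", type=ValueType.NUMBER, value="42").value == 42.0
assert Default(key="obj", type=ValueType.OBJECT, value='{"a": 1}').value == {"a": 1}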
 api/core/workflow/nodes/base/exc.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+class BaseNodeError(Exception):
+    """Base class for node errors."""
+
+    pass
+
+
+class DefaultValueTypeError(BaseNodeError):
+    """Raised when the default value type is invalid."""
+
+    pass
@@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast

 from core.workflow.entities.node_entities import NodeRunResult
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
 from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
 from models.workflow import WorkflowNodeExecutionStatus

@@ -72,10 +72,7 @@ class BaseNode(Generic[GenericNodeData]):
             result = self._run()
         except Exception as e:
             logger.exception(f"Node {self.node_id} failed to run")
-            result = NodeRunResult(
-                status=WorkflowNodeExecutionStatus.FAILED,
-                error=str(e),
-            )
+            result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError")

         if isinstance(result, NodeRunResult):
             yield RunCompletedEvent(run_result=result)
@@ -137,3 +134,12 @@ class BaseNode(Generic[GenericNodeData]):
         :return:
         """
         return self._node_type
+
+    @property
+    def should_continue_on_error(self) -> bool:
+        """Judge whether the node should continue on error.
+
+        Returns:
+            bool: whether execution should continue on error
+        """
+        return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
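Note: the property combines two gates, a configured strategy and an allow-listed node type. A standalone sketch of the same check (illustrative only):

from enum import StrEnum


class NodeType(StrEnum):
    LLM = "llm"
    CODE = "code"
    TOOL = "tool"
    HTTP_REQUEST = "http-request"
    IF_ELSE = "if-else"


CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]


def should_continue_on_error(error_strategy: str | None, node_type: NodeType) -> bool:
    # Both conditions must hold: a strategy is configured and the node type opts in.
    return error_strategy is not None and node_type in CONTINUE_ON_ERROR_NODE_TYPE


assert should_continue_on_error("fail-branch", NodeType.HTTP_REQUEST)
assert not should_continue_on_error(None, NodeType.HTTP_REQUEST)
assert not should_continue_on_error("fail-branch", NodeType.IF_ELSE)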
@@ -61,7 +61,9 @@ class CodeNode(BaseNode[CodeNodeData]):
             # Transform result
             result = self._transform_result(result, self.node_data.outputs)
         except (CodeExecutionError, CodeNodeError) as e:
-            return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
+            )

         return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
@@ -22,3 +22,16 @@ class NodeType(StrEnum):
     VARIABLE_ASSIGNER = "assigner"
     DOCUMENT_EXTRACTOR = "document-extractor"
     LIST_OPERATOR = "list-operator"
+
+
+class ErrorStrategy(StrEnum):
+    FAIL_BRANCH = "fail-branch"
+    DEFAULT_VALUE = "default-value"
+
+
+class FailBranchSourceHandle(StrEnum):
+    FAILED = "fail-branch"
+    SUCCESS = "success-branch"
+
+
+CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
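Note: because these are StrEnum types, members compare equal to their raw string values, which is why the answer-stream router above can compare a plain config string against ErrorStrategy.FAIL_BRANCH:

from enum import StrEnum


class ErrorStrategy(StrEnum):
    FAIL_BRANCH = "fail-branch"
    DEFAULT_VALUE = "default-value"


assert ErrorStrategy.FAIL_BRANCH == "fail-branch"
assert "default-value" == ErrorStrategy.DEFAULT_VALUE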
@@ -21,6 +21,7 @@ from .entities import (
 from .exc import (
     AuthorizationConfigError,
     FileFetchError,
+    HttpRequestNodeError,
     InvalidHttpMethodError,
     ResponseSizeError,
 )
@@ -208,8 +209,10 @@ class Executor:
             "follow_redirects": True,
         }
         # request_args = {k: v for k, v in request_args.items() if v is not None}
-        response = getattr(ssrf_proxy, self.method)(**request_args)
+        try:
+            response = getattr(ssrf_proxy, self.method)(**request_args)
+        except ssrf_proxy.MaxRetriesExceededError as e:
+            raise HttpRequestNodeError(str(e))
         return response

     def invoke(self) -> Response:
@@ -65,6 +65,21 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):

             response = http_executor.invoke()
             files = self.extract_files(url=http_executor.url, response=response)
+            if not response.response.is_success and self.should_continue_on_error:
+                return NodeRunResult(
+                    status=WorkflowNodeExecutionStatus.FAILED,
+                    outputs={
+                        "status_code": response.status_code,
+                        "body": response.text if not files else "",
+                        "headers": response.headers,
+                        "files": files,
+                    },
+                    process_data={
+                        "request": http_executor.to_log(),
+                    },
+                    error=f"Request failed with status code {response.status_code}",
+                    error_type="HTTPResponseCodeError",
+                )
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.SUCCEEDED,
                 outputs={
@@ -83,6 +98,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
                 status=WorkflowNodeExecutionStatus.FAILED,
                 error=str(e),
                 process_data=process_data,
+                error_type=type(e).__name__,
             )

     @staticmethod
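Note: without continue-on-error, a non-2xx response still completes the HTTP node successfully (the status code is part of the outputs); only when the node opts in is it surfaced as FAILED so the engine can downgrade it per strategy. A sketch of that decision (illustrative, not the Dify API):

def classify_http_result(status_code: int, continue_on_error: bool) -> str:
    # Mirrors the branch above: a non-2xx response becomes a FAILED result only
    # when the node opts into continue-on-error.
    is_success = 200 <= status_code < 300
    return "failed" if not is_success and continue_on_error else "succeeded"

assert classify_http_result(503, True) == "failed"
assert classify_http_result(503, False) == "succeeded"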
@@ -193,6 +193,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                     error=str(e),
                     inputs=node_inputs,
                     process_data=process_data,
+                    error_type=type(e).__name__,
                 )
             )
             return
@@ -139,7 +139,7 @@ class QuestionClassifierNode(LLMNode):
             "usage": jsonable_encoder(usage),
             "finish_reason": finish_reason,
         }
-        outputs = {"class_name": category_name}
+        outputs = {"class_name": category_name, "class_id": category_id}

         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -56,6 +56,7 @@ class ToolNode(BaseNode[ToolNodeData]):
                     NodeRunMetadataKey.TOOL_INFO: tool_info,
                 },
                 error=f"Failed to get tool runtime: {str(e)}",
+                error_type=type(e).__name__,
             )

         # get parameters
@@ -89,6 +90,7 @@ class ToolNode(BaseNode[ToolNodeData]):
                     NodeRunMetadataKey.TOOL_INFO: tool_info,
                 },
                 error=f"Failed to invoke tool: {str(e)}",
+                error_type=type(e).__name__,
             )

         # convert tool messages
@@ -14,6 +14,7 @@ workflow_run_for_log_fields = {
     "total_steps": fields.Integer,
     "created_at": TimestampField,
     "finished_at": TimestampField,
+    "exceptions_count": fields.Integer,
 }

 workflow_run_for_list_fields = {
@@ -27,6 +28,7 @@ workflow_run_for_list_fields = {
     "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
     "created_at": TimestampField,
     "finished_at": TimestampField,
+    "exceptions_count": fields.Integer,
 }

 advanced_chat_workflow_run_for_list_fields = {
@@ -42,6 +44,7 @@ advanced_chat_workflow_run_for_list_fields = {
     "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
     "created_at": TimestampField,
     "finished_at": TimestampField,
+    "exceptions_count": fields.Integer,
 }

 advanced_chat_workflow_run_pagination_fields = {
@@ -73,6 +76,7 @@ workflow_run_detail_fields = {
     "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
     "created_at": TimestampField,
     "finished_at": TimestampField,
+    "exceptions_count": fields.Integer,
 }

 workflow_run_node_execution_fields = {
@@ -0,0 +1,33 @@
+"""add exceptions_count field to WorkflowRun model
+
+Revision ID: cf8f4fc45278
+Revises: 01d6889832f7
+Create Date: 2024-11-28 05:53:21.576178
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = 'cf8f4fc45278'
+down_revision = '01d6889832f7'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
+        batch_op.drop_column('exceptions_count')
+
+    # ### end Alembic commands ###
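Note: in a typical Dify deployment this migration is applied with Flask-Migrate (flask db upgrade, run from the api directory). The new column is nullable with a server default of 0, so existing workflow_runs rows need no backfill.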
@@ -325,6 +325,7 @@ class WorkflowRunStatus(StrEnum):
     SUCCEEDED = "succeeded"
     FAILED = "failed"
     STOPPED = "stopped"
+    PARTIAL_SUCCESSED = "partial-succeeded"

     @classmethod
     def value_of(cls, value: str) -> "WorkflowRunStatus":
@@ -395,7 +396,7 @@ class WorkflowRun(db.Model):
     version = db.Column(db.String(255), nullable=False)
     graph = db.Column(db.Text)
     inputs = db.Column(db.Text)
-    status = db.Column(db.String(255), nullable=False)  # running, succeeded, failed, stopped
+    status = db.Column(db.String(255), nullable=False)  # running, succeeded, failed, stopped, partial-succeeded
     outputs: Mapped[str] = mapped_column(sa.Text, default="{}")
     error = db.Column(db.Text)
     elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
@@ -405,6 +406,7 @@ class WorkflowRun(db.Model):
     created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
     finished_at = db.Column(db.DateTime)
+    exceptions_count = db.Column(db.Integer, server_default=db.text("0"))

     @property
     def created_by_account(self):
@@ -464,6 +466,7 @@ class WorkflowRun(db.Model):
             "created_by": self.created_by,
             "created_at": self.created_at,
             "finished_at": self.finished_at,
+            "exceptions_count": self.exceptions_count,
         }

     @classmethod
@@ -489,6 +492,7 @@ class WorkflowRun(db.Model):
             created_by=data.get("created_by"),
             created_at=data.get("created_at"),
             finished_at=data.get("finished_at"),
+            exceptions_count=data.get("exceptions_count"),
         )


@@ -522,6 +526,7 @@ class WorkflowNodeExecutionStatus(Enum):
     RUNNING = "running"
     SUCCEEDED = "succeeded"
     FAILED = "failed"
+    EXCEPTION = "exception"

     @classmethod
     def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
@@ -2,7 +2,7 @@ import json
 import time
 from collections.abc import Sequence
 from datetime import UTC, datetime
-from typing import Optional
+from typing import Optional, cast

 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@@ -11,6 +11,9 @@ from core.variables import Variable
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.errors import WorkflowNodeRunFailedError
 from core.workflow.nodes import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData
+from core.workflow.nodes.base.node import BaseNode
+from core.workflow.nodes.enums import ErrorStrategy
 from core.workflow.nodes.event import RunCompletedEvent
 from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
 from core.workflow.workflow_entry import WorkflowEntry
@@ -225,7 +228,7 @@ class WorkflowService:
                 user_inputs=user_inputs,
                 user_id=account.id,
             )
+            node_instance = cast(BaseNode[BaseNodeData], node_instance)
             node_run_result: NodeRunResult | None = None
             for event in generator:
                 if isinstance(event, RunCompletedEvent):
@@ -237,8 +240,35 @@ class WorkflowService:

             if not node_run_result:
                 raise ValueError("Node run failed with no run result")
-            run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False
+            # single step debug mode error handling return
+            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
+                node_error_args = {
+                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
+                    "error": node_run_result.error,
+                    "inputs": node_run_result.inputs,
+                    "metadata": {"error_strategy": node_instance.node_data.error_strategy},
+                }
+                if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
+                    node_run_result = NodeRunResult(
+                        **node_error_args,
+                        outputs={
+                            **node_instance.node_data.default_value_dict,
+                            "error_message": node_run_result.error,
+                            "error_type": node_run_result.error_type,
+                        },
+                    )
+                else:
+                    node_run_result = NodeRunResult(
+                        **node_error_args,
+                        outputs={
+                            "error_message": node_run_result.error,
+                            "error_type": node_run_result.error_type,
+                        },
+                    )
+            run_succeeded = node_run_result.status in (
+                WorkflowNodeExecutionStatus.SUCCEEDED,
+                WorkflowNodeExecutionStatus.EXCEPTION,
+            )
             error = node_run_result.error if not run_succeeded else None
         except WorkflowNodeRunFailedError as e:
             node_instance = e.node_instance
@@ -260,7 +290,6 @@ class WorkflowService:
         workflow_node_execution.created_by = account.id
         workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
         workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
-
         if run_succeeded and node_run_result:
             # create workflow node execution
             inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
@@ -277,7 +306,11 @@ class WorkflowService:
             workflow_node_execution.execution_metadata = (
                 json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
             )
-            workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
+            if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
+                workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
+            elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
+                workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
+                workflow_node_execution.error = node_run_result.error
         else:
             # create workflow node execution
             workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
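Note: this service-layer branch mirrors the engine's _handle_continue_on_error so that single-step debugging persists the same statuses as full runs. The mapping it implements, in brief (illustrative sketch):

def persisted_status(run_status: str, continue_on_error: bool) -> str:
    # FAILED results from nodes that opt into continue-on-error are stored as
    # "exception" rather than "failed", keeping partial successes queryable.
    if run_status == "failed" and continue_on_error:
        return "exception"
    return run_status

assert persisted_status("failed", True) == "exception"
assert persisted_status("failed", False) == "failed"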
@@ -0,0 +1,502 @@
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.event import (
+    GraphRunPartialSucceededEvent,
+    GraphRunSucceededEvent,
+    NodeRunExceptionEvent,
+    NodeRunStreamChunkEvent,
+)
+from core.workflow.graph_engine.entities.graph import Graph
+from core.workflow.graph_engine.graph_engine import GraphEngine
+from models.enums import UserFrom
+from models.workflow import WorkflowType
+
+
+class ContinueOnErrorTestHelper:
+    @staticmethod
+    def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None):
+        """Helper method to create a code node configuration"""
+        node = {
+            "id": "node",
+            "data": {
+                "outputs": {"result": {"type": "number"}},
+                "error_strategy": error_strategy,
+                "title": "code",
+                "variables": [],
+                "code_language": "python3",
+                "code": "\n".join([line[4:] for line in code.split("\n")]),
+                "type": "code",
+            },
+        }
+        if default_value:
+            node["data"]["default_value"] = default_value
+        return node
+
+    @staticmethod
+    def get_http_node(
+        error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False
+    ):
+        """Helper method to create a http node configuration"""
+        authorization = (
+            {
+                "type": "api-key",
+                "config": {
+                    "type": "basic",
+                    "api_key": "ak-xxx",
+                    "header": "api-key",
+                },
+            }
+            if authorization_success
+            else {
+                "type": "api-key",
+                # missing config field
+            }
+        )
+        node = {
+            "id": "node",
+            "data": {
+                "title": "http",
+                "desc": "",
+                "method": "get",
+                "url": "http://example.com",
+                "authorization": authorization,
+                "headers": "X-Header:123",
+                "params": "A:b",
+                "body": None,
+                "type": "http-request",
+                "error_strategy": error_strategy,
+            },
+        }
+        if default_value:
+            node["data"]["default_value"] = default_value
+        return node
+
+    @staticmethod
+    def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
+        """Helper method to create a http node configuration"""
+        node = {
+            "id": "node",
+            "data": {
+                "type": "http-request",
+                "title": "HTTP Request",
+                "desc": "",
+                "variables": [],
+                "method": "get",
+                "url": "https://api.github.com/issues",
+                "authorization": {"type": "no-auth", "config": None},
+                "headers": "",
+                "params": "",
+                "body": {"type": "none", "data": []},
+                "timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0},
+                "error_strategy": error_strategy,
+            },
+        }
+        if default_value:
+            node["data"]["default_value"] = default_value
+        return node
+
+    @staticmethod
+    def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
+        """Helper method to create a tool node configuration"""
+        node = {
+            "id": "node",
+            "data": {
+                "title": "a",
+                "desc": "a",
+                "provider_id": "maths",
+                "provider_type": "builtin",
+                "provider_name": "maths",
+                "tool_name": "eval_expression",
+                "tool_label": "eval_expression",
+                "tool_configurations": {},
+                "tool_parameters": {
+                    "expression": {
+                        "type": "variable",
+                        "value": ["1", "123", "args1"],
+                    }
+                },
+                "type": "tool",
+                "error_strategy": error_strategy,
+            },
+        }
+        if default_value:
+            node["data"]["default_value"] = default_value
+        return node
+
+    @staticmethod
+    def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
+        """Helper method to create a llm node configuration"""
+        node = {
+            "id": "node",
+            "data": {
+                "title": "123",
+                "type": "llm",
+                "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
+                "prompt_template": [
+                    {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
+                    {"role": "user", "text": "{{#sys.query#}}"},
+                ],
+                "memory": None,
+                "context": {"enabled": False},
+                "vision": {"enabled": False},
+                "error_strategy": error_strategy,
+            },
+        }
+        if default_value:
+            node["data"]["default_value"] = default_value
+        return node
+
+    @staticmethod
+    def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
+        """Helper method to create a graph engine instance for testing"""
+        graph = Graph.init(graph_config=graph_config)
+        variable_pool = {
+            "system_variables": {
+                SystemVariableKey.QUERY: "clear",
+                SystemVariableKey.FILES: [],
+                SystemVariableKey.CONVERSATION_ID: "abababa",
+                SystemVariableKey.USER_ID: "aaa",
+            },
+            "user_inputs": user_inputs or {"uid": "takato"},
+        }
+
+        return GraphEngine(
+            tenant_id="111",
+            app_id="222",
+            workflow_type=WorkflowType.CHAT,
+            workflow_id="333",
+            graph_config=graph_config,
+            user_id="444",
+            user_from=UserFrom.ACCOUNT,
+            invoke_from=InvokeFrom.WEB_APP,
+            call_depth=0,
+            graph=graph,
+            variable_pool=variable_pool,
+            max_execution_steps=500,
+            max_execution_time=1200,
+        )
+
+
+DEFAULT_VALUE_EDGE = [
+    {
+        "id": "start-source-node-target",
+        "source": "start",
+        "target": "node",
+        "sourceHandle": "source",
+    },
+    {
+        "id": "node-source-answer-target",
+        "source": "node",
+        "target": "answer",
+        "sourceHandle": "source",
+    },
+]
+
+FAIL_BRANCH_EDGES = [
+    {
+        "id": "start-source-node-target",
+        "source": "start",
+        "target": "node",
+        "sourceHandle": "source",
+    },
+    {
+        "id": "node-true-success-target",
+        "source": "node",
+        "target": "success",
+        "sourceHandle": "source",
+    },
+    {
+        "id": "node-false-error-target",
+        "source": "node",
+        "target": "error",
+        "sourceHandle": "fail-branch",
+    },
+]
+
+
+def test_code_default_value_continue_on_error():
+    error_code = """
+    def main() -> dict:
+        return {
+            "result": 1 / 0,
+        }
+    """
+
+    graph_config = {
+        "edges": DEFAULT_VALUE_EDGE,
+        "nodes": [
+            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
+            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
+            ContinueOnErrorTestHelper.get_code_node(
+                error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}]
+            ),
+        ],
+    }
+
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+    events = list(graph_engine.run())
+    assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
+    assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events)
+    assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
+
+
+def test_code_fail_branch_continue_on_error():
+    error_code = """
+    def main() -> dict:
+        return {
+            "result": 1 / 0,
+        }
+    """
+
+    graph_config = {
+        "edges": FAIL_BRANCH_EDGES,
+        "nodes": [
+            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
+            {
+                "data": {"title": "success", "type": "answer", "answer": "node node run successfully"},
+                "id": "success",
+            },
+            {
+                "data": {"title": "error", "type": "answer", "answer": "node node run failed"},
+                "id": "error",
+            },
+            ContinueOnErrorTestHelper.get_code_node(error_code),
+        ],
+    }
+
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+    events = list(graph_engine.run())
+    assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
+    assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
+    assert any(
+        isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events
+    )
+
+
+def test_http_node_default_value_continue_on_error():
+    """Test HTTP node with default value error strategy"""
+    graph_config = {
+        "edges": DEFAULT_VALUE_EDGE,
+        "nodes": [
+            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
+            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"},
+            ContinueOnErrorTestHelper.get_http_node(
+                "default-value", [{"key": "response", "type": "string", "value": "http node got error response"}]
+            ),
+        ],
+    }
+
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+    events = list(graph_engine.run())
+
+    assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
+    assert any(
+        isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http node got error response"}
+        for e in events
+    )
+    assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
+
+
+def test_http_node_fail_branch_continue_on_error():
+    """Test HTTP node with fail-branch error strategy"""
+    graph_config = {
+        "edges": FAIL_BRANCH_EDGES,
+        "nodes": [
+            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
+            {
+                "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
+                "id": "success",
+            },
+            {
+                "data": {"title": "error", "type": "answer", "answer": "HTTP request failed"},
+                "id": "error",
+            },
+            ContinueOnErrorTestHelper.get_http_node(),
+        ],
+    }
+
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+    events = list(graph_engine.run())
+
+    assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
+    assert any(
+        isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events
+    )
+    assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
+
+
+def test_tool_node_default_value_continue_on_error():
+    """Test tool node with default value error strategy"""
+    graph_config = {
+        "edges": DEFAULT_VALUE_EDGE,
+        "nodes": [
+            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
+            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
+            ContinueOnErrorTestHelper.get_tool_node(
+                "default-value", [{"key": "result", "type": "string", "value": "default tool result"}]
+            ),
+        ],
+    }
+
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+    events = list(graph_engine.run())
+
+    assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
+    assert any(
+        isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events
+    )
+    assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
+
+
+def test_tool_node_fail_branch_continue_on_error():
+    """Test tool node with fail-branch error strategy"""
+    graph_config = {
+        "edges": FAIL_BRANCH_EDGES,
+        "nodes": [
+            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
+            {
+                "data": {"title": "success", "type": "answer", "answer": "tool execute successful"},
+                "id": "success",
+            },
+            {
+                "data": {"title": "error", "type": "answer", "answer": "tool execute failed"},
+                "id": "error",
+            },
+            ContinueOnErrorTestHelper.get_tool_node(),
+        ],
+    }
+
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+    events = list(graph_engine.run())
+
+    assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
+    assert any(
+        isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events
+    )
+    assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
+
+
+def test_llm_node_default_value_continue_on_error():
+    """Test LLM node with default value error strategy"""
+    graph_config = {
+        "edges": DEFAULT_VALUE_EDGE,
+        "nodes": [
+            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
+            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"},
+            ContinueOnErrorTestHelper.get_llm_node(
+                "default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}]
+            ),
+        ],
+    }
+
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+    events = list(graph_engine.run())
+
+    assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
+    assert any(
+        isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events
+    )
+    assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
+
+
+def test_llm_node_fail_branch_continue_on_error():
+    """Test LLM node with fail-branch error strategy"""
+    graph_config = {
+        "edges": FAIL_BRANCH_EDGES,
+        "nodes": [
+            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
+            {
+                "data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
+                "id": "success",
+            },
+            {
+                "data": {"title": "error", "type": "answer", "answer": "LLM request failed"},
+                "id": "error",
+            },
+            ContinueOnErrorTestHelper.get_llm_node(),
+        ],
+    }
+
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+    events = list(graph_engine.run())
+
+    assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
+    assert any(
+        isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events
+    )
+    assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
+
+
+def test_status_code_error_http_node_fail_branch_continue_on_error():
+    """Test HTTP node with fail-branch error strategy"""
+    graph_config = {
+        "edges": FAIL_BRANCH_EDGES,
+        "nodes": [
+            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
+            {
+                "data": {"title": "success", "type": "answer", "answer": "http execute successful"},
+                "id": "success",
+            },
+            {
+                "data": {"title": "error", "type": "answer", "answer": "http execute failed"},
+                "id": "error",
+            },
+            ContinueOnErrorTestHelper.get_error_status_code_http_node(),
+        ],
+    }
+
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+    events = list(graph_engine.run())
+
+    assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
+    assert any(
+        isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events
+    )
+    assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
+
+
+def test_variable_pool_error_type_variable():
+    graph_config = {
+        "edges": FAIL_BRANCH_EDGES,
+        "nodes": [
+            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
+            {
+                "data": {"title": "success", "type": "answer", "answer": "http execute successful"},
+                "id": "success",
+            },
+            {
+                "data": {"title": "error", "type": "answer", "answer": "http execute failed"},
+                "id": "error",
+            },
+            ContinueOnErrorTestHelper.get_error_status_code_http_node(),
+        ],
+    }
+
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+    list(graph_engine.run())
+    error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"])
+    error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"])
+    assert error_message != None
+    assert error_type.value == "HTTPResponseCodeError"
+
+
+def test_no_node_in_fail_branch_continue_on_error():
+    """Test HTTP node with fail-branch error strategy"""
+    graph_config = {
+        "edges": FAIL_BRANCH_EDGES[:-1],
+        "nodes": [
+            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
+            {
+                "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
+                "id": "success",
+            },
+            ContinueOnErrorTestHelper.get_http_node(),
+        ],
+    }
+
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+    events = list(graph_engine.run())
+
+    assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
+    assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
+    assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0
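Note: these unit tests drive GraphEngine directly and assert on the emitted event stream rather than on persisted state. With the repository's dev dependencies installed, they should be runnable with pytest from the api directory, e.g. pytest -k continue_on_error.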