Feat: continue on error (#11458)

Co-authored-by: Novice Lee <novicelee@NovicedeMacBook-Pro.local>
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
This commit is contained in:
Novice 2024-12-11 14:22:42 +08:00 committed by GitHub
parent bec5451f12
commit 79a710ce98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 1211 additions and 80 deletions

View File

@ -19,6 +19,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@ -31,6 +32,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@ -317,7 +319,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(
@ -384,6 +386,29 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
@ -401,6 +426,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
error=event.error,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)
yield self._workflow_finish_to_stream_response(

View File

@ -6,6 +6,7 @@ from core.app.entities.queue_entities import (
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
@ -34,7 +35,8 @@ class WorkflowAppQueueManager(AppQueueManager):
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent,
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
):
self.stop_listen()

View File

@ -15,6 +15,7 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@ -26,6 +27,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@ -276,7 +278,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(
@ -345,22 +347,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed(
workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
@ -368,6 +368,60 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
handle_args = {
"workflow_run": workflow_run,
"start_at": graph_runtime_state.start_at,
"total_tokens": graph_runtime_state.total_tokens,
"total_steps": graph_runtime_state.node_run_steps,
"status": WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
"error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
"conversation_id": None,
"trace_manager": trace_manager,
"exceptions_count": event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
}
workflow_run = self._handle_workflow_run_failed(**handle_args)
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
handle_args = {
"workflow_run": workflow_run,
"start_at": graph_runtime_state.start_at,
"total_tokens": graph_runtime_state.total_tokens,
"total_steps": graph_runtime_state.node_run_steps,
"status": WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
"error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
"conversation_id": None,
"trace_manager": trace_manager,
"exceptions_count": event.exceptions_count,
}
workflow_run = self._handle_workflow_run_partial_success(**handle_args)
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)

View File

@ -8,6 +8,7 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@ -18,6 +19,7 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@ -25,6 +27,7 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
@ -32,6 +35,7 @@ from core.workflow.graph_engine.entities.event import (
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
@ -176,8 +180,12 @@ class WorkflowBasedAppRunner(AppRunner):
)
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self._publish_event(
QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error))
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
@ -253,6 +261,36 @@ class WorkflowBasedAppRunner(AppRunner):
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeRunExceptionEvent):
self._publish_event(
QueueNodeExceptionEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(

View File

@ -25,12 +25,14 @@ class QueueEvent(StrEnum):
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_SUCCEEDED = "workflow_succeeded"
WORKFLOW_FAILED = "workflow_failed"
WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded"
ITERATION_START = "iteration_start"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
NODE_STARTED = "node_started"
NODE_SUCCEEDED = "node_succeeded"
NODE_FAILED = "node_failed"
NODE_EXCEPTION = "node_exception"
RETRIEVER_RESOURCES = "retriever_resources"
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
@ -237,6 +239,17 @@ class QueueWorkflowFailedEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.WORKFLOW_FAILED
error: str
exceptions_count: int
class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
"""
QueueWorkflowFailedEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
exceptions_count: int
outputs: Optional[dict[str, Any]] = None
class QueueNodeStartedEvent(AppQueueEvent):
@ -331,6 +344,37 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
error: str
class QueueNodeExceptionEvent(AppQueueEvent):
"""
QueueNodeExceptionEvent entity
"""
event: QueueEvent = QueueEvent.NODE_EXCEPTION
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: str
class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity

View File

@ -213,6 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
created_by: Optional[dict] = None
created_at: int
finished_at: int
exceptions_count: Optional[int] = 0
files: Optional[Sequence[Mapping[str, Any]]] = []
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED

View File

@ -12,6 +12,7 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@ -164,6 +165,55 @@ class WorkflowCycleManage:
return workflow_run
def _handle_workflow_run_partial_success(
self,
workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
exceptions_count: int = 0,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
"""
Workflow run success
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param conversation_id: conversation id
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
outputs = WorkflowEntry.handle_special_values(outputs)
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
db.session.commit()
db.session.refresh(workflow_run)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
db.session.close()
return workflow_run
def _handle_workflow_run_failed(
self,
workflow_run: WorkflowRun,
@ -174,6 +224,7 @@ class WorkflowCycleManage:
error: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
exceptions_count: int = 0,
) -> WorkflowRun:
"""
Workflow run failed
@ -193,7 +244,7 @@ class WorkflowCycleManage:
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
db.session.commit()
running_workflow_node_executions = (
@ -318,7 +369,7 @@ class WorkflowCycleManage:
return workflow_node_execution
def _handle_workflow_node_execution_failed(
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
@ -337,7 +388,11 @@ class WorkflowCycleManage:
)
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.status: (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
),
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
@ -351,8 +406,11 @@ class WorkflowCycleManage:
db.session.commit()
db.session.close()
process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.status = (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
)
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
@ -433,6 +491,7 @@ class WorkflowCycleManage:
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
exceptions_count=workflow_run.exceptions_count,
),
)
@ -483,7 +542,10 @@ class WorkflowCycleManage:
def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
event: QueueNodeSucceededEvent
| QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:

View File

@ -24,6 +24,12 @@ BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
class MaxRetriesExceededError(Exception):
"""Raised when the maximum number of retries is exceeded."""
pass
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
@ -64,7 +70,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if retries <= max_retries:
time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}")
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):

View File

@ -4,6 +4,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
@ -39,6 +40,8 @@ class WorkflowLoggingCallback(WorkflowCallback):
self.print_text("\n[GraphRunStartedEvent]", color="pink")
elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[GraphRunSucceededEvent]", color="green")
elif isinstance(event, GraphRunPartialSucceededEvent):
self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink")
elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
elif isinstance(event, NodeRunStartedEvent):

View File

@ -25,6 +25,7 @@ class NodeRunMetadataKey(StrEnum):
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
class NodeRunResult(BaseModel):
@ -43,3 +44,4 @@ class NodeRunResult(BaseModel):
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
error: Optional[str] = None # error message if status is failed
error_type: Optional[str] = None # error type if status is failed

View File

@ -33,6 +33,12 @@ class GraphRunSucceededEvent(BaseGraphEvent):
class GraphRunFailedEvent(BaseGraphEvent):
error: str = Field(..., description="failed reason")
exceptions_count: Optional[int] = Field(description="exception count", default=0)
class GraphRunPartialSucceededEvent(BaseGraphEvent):
exceptions_count: int = Field(..., description="exception count")
outputs: Optional[dict[str, Any]] = None
###########################################
@ -83,6 +89,10 @@ class NodeRunFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeRunExceptionEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")

View File

@ -64,13 +64,21 @@ class Graph(BaseModel):
edge_configs = graph_config.get("edges")
if edge_configs is None:
edge_configs = []
# node configs
node_configs = graph_config.get("nodes")
if not node_configs:
raise ValueError("Graph must have at least one node")
edge_configs = cast(list, edge_configs)
node_configs = cast(list, node_configs)
# reorganize edges mapping
edge_mapping: dict[str, list[GraphEdge]] = {}
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
target_edge_ids = set()
fail_branch_source_node_id = [
node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch"
]
for edge_config in edge_configs:
source_node_id = edge_config.get("source")
if not source_node_id:
@ -90,8 +98,16 @@ class Graph(BaseModel):
# parse run condition
run_condition = None
if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source":
run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle"))
if edge_config.get("sourceHandle"):
if (
edge_config.get("source") in fail_branch_source_node_id
and edge_config.get("sourceHandle") != "fail-branch"
):
run_condition = RunCondition(type="branch_identify", branch_identify="success-branch")
elif edge_config.get("sourceHandle") != "source":
run_condition = RunCondition(
type="branch_identify", branch_identify=edge_config.get("sourceHandle")
)
graph_edge = GraphEdge(
source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
@ -100,13 +116,6 @@ class Graph(BaseModel):
edge_mapping[source_node_id].append(graph_edge)
reverse_edge_mapping[target_node_id].append(graph_edge)
# node configs
node_configs = graph_config.get("nodes")
if not node_configs:
raise ValueError("Graph must have at least one node")
node_configs = cast(list, node_configs)
# fetch nodes that have no predecessor node
root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {}

View File

@ -15,6 +15,7 @@ class RouteNodeState(BaseModel):
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
EXCEPTION = "exception"
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""node state id"""
@ -51,7 +52,11 @@ class RouteNodeState(BaseModel):
:param run_result: run result
"""
if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}:
if self.status in {
RouteNodeState.Status.SUCCESS,
RouteNodeState.Status.FAILED,
RouteNodeState.Status.EXCEPTION,
}:
raise Exception(f"Route state {self.id} already finished")
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
@ -59,6 +64,9 @@ class RouteNodeState(BaseModel):
elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
self.status = RouteNodeState.Status.FAILED
self.failed_reason = run_result.error
elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
self.status = RouteNodeState.Status.EXCEPTION
self.failed_reason = run_result.error
else:
raise Exception(f"Invalid route status {run_result.status}")

View File

@ -5,21 +5,23 @@ import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from typing import Any, Optional
from typing import Any, Optional, cast
from flask import Flask, current_app
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
BaseIterationEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
@ -36,7 +38,9 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
@ -128,6 +132,7 @@ class GraphEngine:
def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event
yield GraphRunStartedEvent()
handle_exceptions = []
try:
if self.init_params.workflow_type == WorkflowType.CHAT:
@ -140,13 +145,17 @@ class GraphEngine:
)
# run graph
generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id))
generator = stream_processor.process(
self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions)
)
for item in generator:
try:
yield item
if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.")
yield GraphRunFailedEvent(
error=item.route_node_state.failed_reason or "Unknown error.",
exceptions_count=len(handle_exceptions),
)
return
elif isinstance(item, NodeRunSucceededEvent):
if item.node_type == NodeType.END:
@ -172,19 +181,24 @@ class GraphEngine:
].strip()
except Exception as e:
logger.exception("Graph run failed")
yield GraphRunFailedEvent(error=str(e))
yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
return
# trigger graph run success event
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
# count exceptions to determine partial success
if len(handle_exceptions) > 0:
yield GraphRunPartialSucceededEvent(
exceptions_count=len(handle_exceptions), outputs=self.graph_runtime_state.outputs
)
else:
# trigger graph run success event
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
self._release_thread()
except GraphRunFailedError as e:
yield GraphRunFailedEvent(error=e.error)
yield GraphRunFailedEvent(error=e.error, exceptions_count=len(handle_exceptions))
self._release_thread()
return
except Exception as e:
logger.exception("Unknown Error when graph running")
yield GraphRunFailedEvent(error=str(e))
yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
self._release_thread()
raise e
@ -198,6 +212,7 @@ class GraphEngine:
in_parallel_id: Optional[str] = None,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent, None, None]:
parallel_start_node_id = None
if in_parallel_id:
@ -242,7 +257,7 @@ class GraphEngine:
previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id,
)
node_instance = cast(BaseNode[BaseNodeData], node_instance)
try:
# run node
generator = self._run_node(
@ -252,6 +267,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
handle_exceptions=handle_exceptions,
)
for item in generator:
@ -301,7 +317,12 @@ class GraphEngine:
if len(edge_mappings) == 1:
edge = edge_mappings[0]
if (
previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
and edge.run_condition is None
):
break
if edge.run_condition:
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
@ -334,7 +355,7 @@ class GraphEngine:
if len(sub_edge_mappings) == 0:
continue
edge = sub_edge_mappings[0]
edge = cast(GraphEdge, sub_edge_mappings[0])
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
@ -355,6 +376,7 @@ class GraphEngine:
edge_mappings=sub_edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
handle_exceptions=handle_exceptions,
)
for item in parallel_generator:
@ -369,11 +391,18 @@ class GraphEngine:
break
next_node_id = final_node_id
elif (
node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
and node_instance.should_continue_on_error
and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
):
break
else:
parallel_generator = self._run_parallel_branches(
edge_mappings=edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
handle_exceptions=handle_exceptions,
)
for item in parallel_generator:
@ -395,6 +424,7 @@ class GraphEngine:
edge_mappings: list[GraphEdge],
in_parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent | str, None, None]:
# if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
@ -438,6 +468,7 @@ class GraphEngine:
"parallel_start_node_id": edge.target_node_id,
"parent_parallel_id": in_parallel_id,
"parent_parallel_start_node_id": parallel_start_node_id,
"handle_exceptions": handle_exceptions,
},
)
@ -481,6 +512,7 @@ class GraphEngine:
parallel_start_node_id: str,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> None:
"""
Run parallel nodes
@ -502,6 +534,7 @@ class GraphEngine:
in_parallel_id=parallel_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
handle_exceptions=handle_exceptions,
)
for item in generator:
@ -548,6 +581,7 @@ class GraphEngine:
parallel_start_node_id: Optional[str] = None,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent, None, None]:
"""
Run node
@ -587,19 +621,55 @@ class GraphEngine:
route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
@ -735,6 +805,56 @@ class GraphEngine:
new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
return new_instance
def _handle_continue_on_error(
self,
node_instance: BaseNode[BaseNodeData],
error_result: NodeRunResult,
variable_pool: VariablePool,
handle_exceptions: list[str] = [],
) -> NodeRunResult:
"""
handle continue on error when self._should_continue_on_error is True
:param error_result (NodeRunResult): error run result
:param variable_pool (VariablePool): variable pool
:return: excption run result
"""
# add error message and error type to variable pool
variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
# add error message to handle_exceptions
handle_exceptions.append(error_result.error)
node_error_args = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": error_result.error,
"inputs": error_result.inputs,
"metadata": {
NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
},
}
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
return NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
"error_message": error_result.error,
"error_type": error_result.error_type,
},
)
elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH:
if self.graph.edge_mapping.get(node_instance.node_id):
node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED
return NodeRunResult(
**node_error_args,
outputs={
"error_message": error_result.error,
"error_type": error_result.error_type,
},
)
return error_result
class GraphRunFailedError(Exception):
def __init__(self, error: str):

View File

@ -6,7 +6,7 @@ from core.workflow.nodes.answer.entities import (
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
@ -148,13 +148,18 @@ class AnswerStreamGeneratorRouter:
for edge in reverse_edges:
source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
if source_node_type in {
NodeType.ANSWER,
NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION,
NodeType.VARIABLE_ASSIGNER,
}:
source_node_data = node_id_config_mapping[source_node_id].get("data", {})
if (
source_node_type
in {
NodeType.ANSWER,
NodeType.IF_ELSE,
NodeType.QUESTION_CLASSIFIER,
NodeType.ITERATION,
NodeType.VARIABLE_ASSIGNER,
}
or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH
):
answer_dependencies[answer_node_id].append(source_node_id)
else:
cls._recursive_fetch_answer_dependencies(

View File

@ -6,6 +6,7 @@ from core.file import FILE_MODEL_IDENTITY, File
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunExceptionEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
@ -50,7 +51,7 @@ class AnswerStreamProcessor(StreamProcessor):
for _ in stream_out_answer_node_ids:
yield event
elif isinstance(event, NodeRunSucceededEvent):
elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream event finished

View File

@ -1,14 +1,124 @@
import json
from abc import ABC
from typing import Optional
from enum import StrEnum
from typing import Any, Optional, Union
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from core.workflow.nodes.base.exc import DefaultValueTypeError
from core.workflow.nodes.enums import ErrorStrategy
class DefaultValueType(StrEnum):
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY_NUMBER = "array[number]"
ARRAY_STRING = "array[string]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILES = "array[file]"
NumberType = Union[int, float]
class DefaultValue(BaseModel):
value: Any
type: DefaultValueType
key: str
@staticmethod
def _parse_json(value: str) -> Any:
"""Unified JSON parsing handler"""
try:
return json.loads(value)
except json.JSONDecodeError:
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
@staticmethod
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
"""Unified array type validation"""
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
@staticmethod
def _convert_number(value: str) -> float:
"""Unified number conversion handler"""
try:
return float(value)
except ValueError:
raise DefaultValueTypeError(f"Cannot convert to number: {value}")
@model_validator(mode="after")
def validate_value_type(self) -> "DefaultValue":
if self.type is None:
raise DefaultValueTypeError("type field is required")
# Type validation configuration
type_validators = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,
},
DefaultValueType.NUMBER: {
"type": NumberType,
"converter": self._convert_number,
},
DefaultValueType.OBJECT: {
"type": dict,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_NUMBER: {
"type": list,
"element_type": NumberType,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_STRING: {
"type": list,
"element_type": str,
"converter": self._parse_json,
},
DefaultValueType.ARRAY_OBJECT: {
"type": list,
"element_type": dict,
"converter": self._parse_json,
},
}
validator = type_validators.get(self.type)
if not validator:
if self.type == DefaultValueType.ARRAY_FILES:
# Handle files type
return self
raise DefaultValueTypeError(f"Unsupported type: {self.type}")
# Handle string input cases
if isinstance(self.value, str) and self.type != DefaultValueType.STRING:
self.value = validator["converter"](self.value)
# Validate base type
if not isinstance(self.value, validator["type"]):
raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}")
# Validate array element types
if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]):
raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}")
return self
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
version: str = "1"
@property
def default_value_dict(self):
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}
class BaseIterationNodeData(BaseNodeData):
start_node_id: Optional[str] = None

View File

@ -0,0 +1,10 @@
class BaseNodeError(Exception):
"""Base class for node errors."""
pass
class DefaultValueTypeError(BaseNodeError):
"""Raised when the default value type is invalid."""
pass

View File

@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from models.workflow import WorkflowNodeExecutionStatus
@ -72,10 +72,7 @@ class BaseNode(Generic[GenericNodeData]):
result = self._run()
except Exception as e:
logger.exception(f"Node {self.node_id} failed to run")
result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError")
if isinstance(result, NodeRunResult):
yield RunCompletedEvent(run_result=result)
@ -137,3 +134,12 @@ class BaseNode(Generic[GenericNodeData]):
:return:
"""
return self._node_type
@property
def should_continue_on_error(self) -> bool:
"""judge if should continue on error
Returns:
bool: if should continue on error
"""
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE

View File

@ -61,7 +61,9 @@ class CodeNode(BaseNode[CodeNodeData]):
# Transform result
result = self._transform_result(result, self.node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
)
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)

View File

@ -22,3 +22,16 @@ class NodeType(StrEnum):
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
class ErrorStrategy(StrEnum):
FAIL_BRANCH = "fail-branch"
DEFAULT_VALUE = "default-value"
class FailBranchSourceHandle(StrEnum):
FAILED = "fail-branch"
SUCCESS = "success-branch"
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]

View File

@ -21,6 +21,7 @@ from .entities import (
from .exc import (
AuthorizationConfigError,
FileFetchError,
HttpRequestNodeError,
InvalidHttpMethodError,
ResponseSizeError,
)
@ -208,8 +209,10 @@ class Executor:
"follow_redirects": True,
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
response = getattr(ssrf_proxy, self.method)(**request_args)
try:
response = getattr(ssrf_proxy, self.method)(**request_args)
except ssrf_proxy.MaxRetriesExceededError as e:
raise HttpRequestNodeError(str(e))
return response
def invoke(self) -> Response:

View File

@ -65,6 +65,21 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response)
if not response.response.is_success and self.should_continue_on_error:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
"status_code": response.status_code,
"body": response.text if not files else "",
"headers": response.headers,
"files": files,
},
process_data={
"request": http_executor.to_log(),
},
error=f"Request failed with status code {response.status_code}",
error_type="HTTPResponseCodeError",
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
@ -83,6 +98,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
process_data=process_data,
error_type=type(e).__name__,
)
@staticmethod

View File

@ -193,6 +193,7 @@ class LLMNode(BaseNode[LLMNodeData]):
error=str(e),
inputs=node_inputs,
process_data=process_data,
error_type=type(e).__name__,
)
)
return

View File

@ -139,7 +139,7 @@ class QuestionClassifierNode(LLMNode):
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
}
outputs = {"class_name": category_name}
outputs = {"class_name": category_name, "class_id": category_id}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,

View File

@ -56,6 +56,7 @@ class ToolNode(BaseNode[ToolNodeData]):
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to get tool runtime: {str(e)}",
error_type=type(e).__name__,
)
# get parameters
@ -89,6 +90,7 @@ class ToolNode(BaseNode[ToolNodeData]):
NodeRunMetadataKey.TOOL_INFO: tool_info,
},
error=f"Failed to invoke tool: {str(e)}",
error_type=type(e).__name__,
)
# convert tool messages

View File

@ -14,6 +14,7 @@ workflow_run_for_log_fields = {
"total_steps": fields.Integer,
"created_at": TimestampField,
"finished_at": TimestampField,
"exceptions_count": fields.Integer,
}
workflow_run_for_list_fields = {
@ -27,6 +28,7 @@ workflow_run_for_list_fields = {
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_at": TimestampField,
"finished_at": TimestampField,
"exceptions_count": fields.Integer,
}
advanced_chat_workflow_run_for_list_fields = {
@ -42,6 +44,7 @@ advanced_chat_workflow_run_for_list_fields = {
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_at": TimestampField,
"finished_at": TimestampField,
"exceptions_count": fields.Integer,
}
advanced_chat_workflow_run_pagination_fields = {
@ -73,6 +76,7 @@ workflow_run_detail_fields = {
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
"created_at": TimestampField,
"finished_at": TimestampField,
"exceptions_count": fields.Integer,
}
workflow_run_node_execution_fields = {

View File

@ -0,0 +1,33 @@
"""add exceptions_count field to WorkflowRun model
Revision ID: cf8f4fc45278
Revises: 01d6889832f7
Create Date: 2024-11-28 05:53:21.576178
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'cf8f4fc45278'
down_revision = '01d6889832f7'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.add_column(sa.Column('exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
batch_op.drop_column('exceptions_count')
# ### end Alembic commands ###

View File

@ -325,6 +325,7 @@ class WorkflowRunStatus(StrEnum):
SUCCEEDED = "succeeded"
FAILED = "failed"
STOPPED = "stopped"
PARTIAL_SUCCESSED = "partial-succeeded"
@classmethod
def value_of(cls, value: str) -> "WorkflowRunStatus":
@ -395,7 +396,7 @@ class WorkflowRun(db.Model):
version = db.Column(db.String(255), nullable=False)
graph = db.Column(db.Text)
inputs = db.Column(db.Text)
status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped
status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[str] = mapped_column(sa.Text, default="{}")
error = db.Column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
@ -405,6 +406,7 @@ class WorkflowRun(db.Model):
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
finished_at = db.Column(db.DateTime)
exceptions_count = db.Column(db.Integer, server_default=db.text("0"))
@property
def created_by_account(self):
@ -464,6 +466,7 @@ class WorkflowRun(db.Model):
"created_by": self.created_by,
"created_at": self.created_at,
"finished_at": self.finished_at,
"exceptions_count": self.exceptions_count,
}
@classmethod
@ -489,6 +492,7 @@ class WorkflowRun(db.Model):
created_by=data.get("created_by"),
created_at=data.get("created_at"),
finished_at=data.get("finished_at"),
exceptions_count=data.get("exceptions_count"),
)
@ -522,6 +526,7 @@ class WorkflowNodeExecutionStatus(Enum):
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
EXCEPTION = "exception"
@classmethod
def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":

View File

@ -2,7 +2,7 @@ import json
import time
from collections.abc import Sequence
from datetime import UTC, datetime
from typing import Optional
from typing import Optional, cast
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@ -11,6 +11,9 @@ from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.workflow_entry import WorkflowEntry
@ -225,7 +228,7 @@ class WorkflowService:
user_inputs=user_inputs,
user_id=account.id,
)
node_instance = cast(BaseNode[BaseNodeData], node_instance)
node_run_result: NodeRunResult | None = None
for event in generator:
if isinstance(event, RunCompletedEvent):
@ -237,8 +240,35 @@ class WorkflowService:
if not node_run_result:
raise ValueError("Node run failed with no run result")
run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False
# single step debug mode error handling return
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
node_error_args = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
"metadata": {"error_strategy": node_instance.node_data.error_strategy},
}
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
else:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
run_succeeded = node_run_result.status in (
WorkflowNodeExecutionStatus.SUCCEEDED,
WorkflowNodeExecutionStatus.EXCEPTION,
)
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
node_instance = e.node_instance
@ -260,7 +290,6 @@ class WorkflowService:
workflow_node_execution.created_by = account.id
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
if run_succeeded and node_run_result:
# create workflow node execution
inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
@ -277,7 +306,11 @@ class WorkflowService:
workflow_node_execution.execution_metadata = (
json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
)
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value
workflow_node_execution.error = node_run_result.error
else:
# create workflow node execution
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value

View File

@ -0,0 +1,502 @@
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunStreamChunkEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.graph_engine import GraphEngine
from models.enums import UserFrom
from models.workflow import WorkflowType
class ContinueOnErrorTestHelper:
@staticmethod
def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a code node configuration"""
node = {
"id": "node",
"data": {
"outputs": {"result": {"type": "number"}},
"error_strategy": error_strategy,
"title": "code",
"variables": [],
"code_language": "python3",
"code": "\n".join([line[4:] for line in code.split("\n")]),
"type": "code",
},
}
if default_value:
node["data"]["default_value"] = default_value
return node
@staticmethod
def get_http_node(
error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False
):
"""Helper method to create a http node configuration"""
authorization = (
{
"type": "api-key",
"config": {
"type": "basic",
"api_key": "ak-xxx",
"header": "api-key",
},
}
if authorization_success
else {
"type": "api-key",
# missing config field
}
)
node = {
"id": "node",
"data": {
"title": "http",
"desc": "",
"method": "get",
"url": "http://example.com",
"authorization": authorization,
"headers": "X-Header:123",
"params": "A:b",
"body": None,
"type": "http-request",
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node
@staticmethod
def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a http node configuration"""
node = {
"id": "node",
"data": {
"type": "http-request",
"title": "HTTP Request",
"desc": "",
"variables": [],
"method": "get",
"url": "https://api.github.com/issues",
"authorization": {"type": "no-auth", "config": None},
"headers": "",
"params": "",
"body": {"type": "none", "data": []},
"timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0},
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node
@staticmethod
def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a tool node configuration"""
node = {
"id": "node",
"data": {
"title": "a",
"desc": "a",
"provider_id": "maths",
"provider_type": "builtin",
"provider_name": "maths",
"tool_name": "eval_expression",
"tool_label": "eval_expression",
"tool_configurations": {},
"tool_parameters": {
"expression": {
"type": "variable",
"value": ["1", "123", "args1"],
}
},
"type": "tool",
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node
@staticmethod
def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
"""Helper method to create a llm node configuration"""
node = {
"id": "node",
"data": {
"title": "123",
"type": "llm",
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
"prompt_template": [
{"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
"error_strategy": error_strategy,
},
}
if default_value:
node["data"]["default_value"] = default_value
return node
@staticmethod
def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
"""Helper method to create a graph engine instance for testing"""
graph = Graph.init(graph_config=graph_config)
variable_pool = {
"system_variables": {
SystemVariableKey.QUERY: "clear",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
"user_inputs": user_inputs or {"uid": "takato"},
}
return GraphEngine(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
max_execution_steps=500,
max_execution_time=1200,
)
DEFAULT_VALUE_EDGE = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-source-answer-target",
"source": "node",
"target": "answer",
"sourceHandle": "source",
},
]
FAIL_BRANCH_EDGES = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-true-success-target",
"source": "node",
"target": "success",
"sourceHandle": "source",
},
{
"id": "node-false-error-target",
"source": "node",
"target": "error",
"sourceHandle": "fail-branch",
},
]
def test_code_default_value_continue_on_error():
error_code = """
def main() -> dict:
return {
"result": 1 / 0,
}
"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_code_node(
error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}]
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_code_fail_branch_continue_on_error():
error_code = """
def main() -> dict:
return {
"result": 1 / 0,
}
"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "node node run successfully"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "node node run failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_code_node(error_code),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events
)
def test_http_node_default_value_continue_on_error():
"""Test HTTP node with default value error strategy"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
"default-value", [{"key": "response", "type": "string", "value": "http node got error response"}]
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http node got error response"}
for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_http_node_fail_branch_continue_on_error():
"""Test HTTP node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "HTTP request failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_http_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_tool_node_default_value_continue_on_error():
"""Test tool node with default value error strategy"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_tool_node(
"default-value", [{"key": "result", "type": "string", "value": "default tool result"}]
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_tool_node_fail_branch_continue_on_error():
"""Test HTTP node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "tool execute successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "tool execute failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_tool_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_llm_node_default_value_continue_on_error():
"""Test LLM node with default value error strategy"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_llm_node(
"default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}]
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_llm_node_fail_branch_continue_on_error():
"""Test LLM node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "LLM request failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_llm_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_status_code_error_http_node_fail_branch_continue_on_error():
"""Test HTTP node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "http execute successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "http execute failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_error_status_code_http_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events
)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
def test_variable_pool_error_type_variable():
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "http execute successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "http execute failed"},
"id": "error",
},
ContinueOnErrorTestHelper.get_error_status_code_http_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
list(graph_engine.run())
error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"])
error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"])
assert error_message != None
assert error_type.value == "HTTPResponseCodeError"
def test_no_node_in_fail_branch_continue_on_error():
"""Test HTTP node with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES[:-1],
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
"id": "success",
},
ContinueOnErrorTestHelper.get_http_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0