mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 01:39:04 +08:00
Feat: Retry on node execution errors (#11871)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
This commit is contained in:
parent
f6247fe67c
commit
7abc7fa573
@ -22,6 +22,7 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueNodeExceptionEvent,
|
QueueNodeExceptionEvent,
|
||||||
QueueNodeFailedEvent,
|
QueueNodeFailedEvent,
|
||||||
QueueNodeInIterationFailedEvent,
|
QueueNodeInIterationFailedEvent,
|
||||||
|
QueueNodeRetryEvent,
|
||||||
QueueNodeStartedEvent,
|
QueueNodeStartedEvent,
|
||||||
QueueNodeSucceededEvent,
|
QueueNodeSucceededEvent,
|
||||||
QueueParallelBranchRunFailedEvent,
|
QueueParallelBranchRunFailedEvent,
|
||||||
@ -328,6 +329,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||||||
workflow_node_execution=workflow_node_execution,
|
workflow_node_execution=workflow_node_execution,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if response:
|
||||||
|
yield response
|
||||||
|
elif isinstance(
|
||||||
|
event,
|
||||||
|
QueueNodeRetryEvent,
|
||||||
|
):
|
||||||
|
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
||||||
|
workflow_run=workflow_run, event=event
|
||||||
|
)
|
||||||
|
|
||||||
|
response = self._workflow_node_retry_to_stream_response(
|
||||||
|
event=event,
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_node_execution=workflow_node_execution,
|
||||||
|
)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
yield response
|
yield response
|
||||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||||
|
@ -18,6 +18,7 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueNodeExceptionEvent,
|
QueueNodeExceptionEvent,
|
||||||
QueueNodeFailedEvent,
|
QueueNodeFailedEvent,
|
||||||
QueueNodeInIterationFailedEvent,
|
QueueNodeInIterationFailedEvent,
|
||||||
|
QueueNodeRetryEvent,
|
||||||
QueueNodeStartedEvent,
|
QueueNodeStartedEvent,
|
||||||
QueueNodeSucceededEvent,
|
QueueNodeSucceededEvent,
|
||||||
QueueParallelBranchRunFailedEvent,
|
QueueParallelBranchRunFailedEvent,
|
||||||
@ -286,9 +287,25 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_node_execution=workflow_node_execution,
|
workflow_node_execution=workflow_node_execution,
|
||||||
)
|
)
|
||||||
|
|
||||||
if node_failed_response:
|
if node_failed_response:
|
||||||
yield node_failed_response
|
yield node_failed_response
|
||||||
|
elif isinstance(
|
||||||
|
event,
|
||||||
|
QueueNodeRetryEvent,
|
||||||
|
):
|
||||||
|
workflow_node_execution = self._handle_workflow_node_execution_retried(
|
||||||
|
workflow_run=workflow_run, event=event
|
||||||
|
)
|
||||||
|
|
||||||
|
response = self._workflow_node_retry_to_stream_response(
|
||||||
|
event=event,
|
||||||
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
workflow_node_execution=workflow_node_execution,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response:
|
||||||
|
yield response
|
||||||
|
|
||||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||||
if not workflow_run:
|
if not workflow_run:
|
||||||
raise Exception("Workflow run not initialized.")
|
raise Exception("Workflow run not initialized.")
|
||||||
|
@ -11,6 +11,7 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueNodeExceptionEvent,
|
QueueNodeExceptionEvent,
|
||||||
QueueNodeFailedEvent,
|
QueueNodeFailedEvent,
|
||||||
QueueNodeInIterationFailedEvent,
|
QueueNodeInIterationFailedEvent,
|
||||||
|
QueueNodeRetryEvent,
|
||||||
QueueNodeStartedEvent,
|
QueueNodeStartedEvent,
|
||||||
QueueNodeSucceededEvent,
|
QueueNodeSucceededEvent,
|
||||||
QueueParallelBranchRunFailedEvent,
|
QueueParallelBranchRunFailedEvent,
|
||||||
@ -38,6 +39,7 @@ from core.workflow.graph_engine.entities.event import (
|
|||||||
NodeRunExceptionEvent,
|
NodeRunExceptionEvent,
|
||||||
NodeRunFailedEvent,
|
NodeRunFailedEvent,
|
||||||
NodeRunRetrieverResourceEvent,
|
NodeRunRetrieverResourceEvent,
|
||||||
|
NodeRunRetryEvent,
|
||||||
NodeRunStartedEvent,
|
NodeRunStartedEvent,
|
||||||
NodeRunStreamChunkEvent,
|
NodeRunStreamChunkEvent,
|
||||||
NodeRunSucceededEvent,
|
NodeRunSucceededEvent,
|
||||||
@ -420,6 +422,36 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||||||
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
|
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
elif isinstance(event, NodeRunRetryEvent):
|
||||||
|
self._publish_event(
|
||||||
|
QueueNodeRetryEvent(
|
||||||
|
node_execution_id=event.id,
|
||||||
|
node_id=event.node_id,
|
||||||
|
node_type=event.node_type,
|
||||||
|
node_data=event.node_data,
|
||||||
|
parallel_id=event.parallel_id,
|
||||||
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
|
parent_parallel_id=event.parent_parallel_id,
|
||||||
|
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||||
|
start_at=event.start_at,
|
||||||
|
inputs=event.route_node_state.node_run_result.inputs
|
||||||
|
if event.route_node_state.node_run_result
|
||||||
|
else {},
|
||||||
|
process_data=event.route_node_state.node_run_result.process_data
|
||||||
|
if event.route_node_state.node_run_result
|
||||||
|
else {},
|
||||||
|
outputs=event.route_node_state.node_run_result.outputs
|
||||||
|
if event.route_node_state.node_run_result
|
||||||
|
else {},
|
||||||
|
error=event.error,
|
||||||
|
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||||
|
if event.route_node_state.node_run_result
|
||||||
|
else {},
|
||||||
|
in_iteration_id=event.in_iteration_id,
|
||||||
|
retry_index=event.retry_index,
|
||||||
|
start_index=event.start_index,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||||
"""
|
"""
|
||||||
|
@ -43,6 +43,7 @@ class QueueEvent(StrEnum):
|
|||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
PING = "ping"
|
PING = "ping"
|
||||||
STOP = "stop"
|
STOP = "stop"
|
||||||
|
RETRY = "retry"
|
||||||
|
|
||||||
|
|
||||||
class AppQueueEvent(BaseModel):
|
class AppQueueEvent(BaseModel):
|
||||||
@ -313,6 +314,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
|||||||
iteration_duration_map: Optional[dict[str, float]] = None
|
iteration_duration_map: Optional[dict[str, float]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class QueueNodeRetryEvent(AppQueueEvent):
|
||||||
|
"""QueueNodeRetryEvent entity"""
|
||||||
|
|
||||||
|
event: QueueEvent = QueueEvent.RETRY
|
||||||
|
|
||||||
|
node_execution_id: str
|
||||||
|
node_id: str
|
||||||
|
node_type: NodeType
|
||||||
|
node_data: BaseNodeData
|
||||||
|
parallel_id: Optional[str] = None
|
||||||
|
"""parallel id if node is in parallel"""
|
||||||
|
parallel_start_node_id: Optional[str] = None
|
||||||
|
"""parallel start node id if node is in parallel"""
|
||||||
|
parent_parallel_id: Optional[str] = None
|
||||||
|
"""parent parallel id if node is in parallel"""
|
||||||
|
parent_parallel_start_node_id: Optional[str] = None
|
||||||
|
"""parent parallel start node id if node is in parallel"""
|
||||||
|
in_iteration_id: Optional[str] = None
|
||||||
|
"""iteration id if node is in iteration"""
|
||||||
|
start_at: datetime
|
||||||
|
|
||||||
|
inputs: Optional[dict[str, Any]] = None
|
||||||
|
process_data: Optional[dict[str, Any]] = None
|
||||||
|
outputs: Optional[dict[str, Any]] = None
|
||||||
|
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
|
||||||
|
|
||||||
|
error: str
|
||||||
|
retry_index: int # retry index
|
||||||
|
start_index: int # start index
|
||||||
|
|
||||||
|
|
||||||
class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
||||||
"""
|
"""
|
||||||
QueueNodeInIterationFailedEvent entity
|
QueueNodeInIterationFailedEvent entity
|
||||||
|
@ -52,6 +52,7 @@ class StreamEvent(Enum):
|
|||||||
WORKFLOW_FINISHED = "workflow_finished"
|
WORKFLOW_FINISHED = "workflow_finished"
|
||||||
NODE_STARTED = "node_started"
|
NODE_STARTED = "node_started"
|
||||||
NODE_FINISHED = "node_finished"
|
NODE_FINISHED = "node_finished"
|
||||||
|
NODE_RETRY = "node_retry"
|
||||||
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
|
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
|
||||||
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
|
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
|
||||||
ITERATION_STARTED = "iteration_started"
|
ITERATION_STARTED = "iteration_started"
|
||||||
@ -342,6 +343,75 @@ class NodeFinishStreamResponse(StreamResponse):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class NodeRetryStreamResponse(StreamResponse):
|
||||||
|
"""
|
||||||
|
NodeFinishStreamResponse entity
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Data(BaseModel):
|
||||||
|
"""
|
||||||
|
Data entity
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
node_id: str
|
||||||
|
node_type: str
|
||||||
|
title: str
|
||||||
|
index: int
|
||||||
|
predecessor_node_id: Optional[str] = None
|
||||||
|
inputs: Optional[dict] = None
|
||||||
|
process_data: Optional[dict] = None
|
||||||
|
outputs: Optional[dict] = None
|
||||||
|
status: str
|
||||||
|
error: Optional[str] = None
|
||||||
|
elapsed_time: float
|
||||||
|
execution_metadata: Optional[dict] = None
|
||||||
|
created_at: int
|
||||||
|
finished_at: int
|
||||||
|
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||||
|
parallel_id: Optional[str] = None
|
||||||
|
parallel_start_node_id: Optional[str] = None
|
||||||
|
parent_parallel_id: Optional[str] = None
|
||||||
|
parent_parallel_start_node_id: Optional[str] = None
|
||||||
|
iteration_id: Optional[str] = None
|
||||||
|
retry_index: int = 0
|
||||||
|
|
||||||
|
event: StreamEvent = StreamEvent.NODE_RETRY
|
||||||
|
workflow_run_id: str
|
||||||
|
data: Data
|
||||||
|
|
||||||
|
def to_ignore_detail_dict(self):
|
||||||
|
return {
|
||||||
|
"event": self.event.value,
|
||||||
|
"task_id": self.task_id,
|
||||||
|
"workflow_run_id": self.workflow_run_id,
|
||||||
|
"data": {
|
||||||
|
"id": self.data.id,
|
||||||
|
"node_id": self.data.node_id,
|
||||||
|
"node_type": self.data.node_type,
|
||||||
|
"title": self.data.title,
|
||||||
|
"index": self.data.index,
|
||||||
|
"predecessor_node_id": self.data.predecessor_node_id,
|
||||||
|
"inputs": None,
|
||||||
|
"process_data": None,
|
||||||
|
"outputs": None,
|
||||||
|
"status": self.data.status,
|
||||||
|
"error": None,
|
||||||
|
"elapsed_time": self.data.elapsed_time,
|
||||||
|
"execution_metadata": None,
|
||||||
|
"created_at": self.data.created_at,
|
||||||
|
"finished_at": self.data.finished_at,
|
||||||
|
"files": [],
|
||||||
|
"parallel_id": self.data.parallel_id,
|
||||||
|
"parallel_start_node_id": self.data.parallel_start_node_id,
|
||||||
|
"parent_parallel_id": self.data.parent_parallel_id,
|
||||||
|
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
|
||||||
|
"iteration_id": self.data.iteration_id,
|
||||||
|
"retry_index": self.data.retry_index,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ParallelBranchStartStreamResponse(StreamResponse):
|
class ParallelBranchStartStreamResponse(StreamResponse):
|
||||||
"""
|
"""
|
||||||
ParallelBranchStartStreamResponse entity
|
ParallelBranchStartStreamResponse entity
|
||||||
|
@ -15,6 +15,7 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueNodeExceptionEvent,
|
QueueNodeExceptionEvent,
|
||||||
QueueNodeFailedEvent,
|
QueueNodeFailedEvent,
|
||||||
QueueNodeInIterationFailedEvent,
|
QueueNodeInIterationFailedEvent,
|
||||||
|
QueueNodeRetryEvent,
|
||||||
QueueNodeStartedEvent,
|
QueueNodeStartedEvent,
|
||||||
QueueNodeSucceededEvent,
|
QueueNodeSucceededEvent,
|
||||||
QueueParallelBranchRunFailedEvent,
|
QueueParallelBranchRunFailedEvent,
|
||||||
@ -26,6 +27,7 @@ from core.app.entities.task_entities import (
|
|||||||
IterationNodeNextStreamResponse,
|
IterationNodeNextStreamResponse,
|
||||||
IterationNodeStartStreamResponse,
|
IterationNodeStartStreamResponse,
|
||||||
NodeFinishStreamResponse,
|
NodeFinishStreamResponse,
|
||||||
|
NodeRetryStreamResponse,
|
||||||
NodeStartStreamResponse,
|
NodeStartStreamResponse,
|
||||||
ParallelBranchFinishedStreamResponse,
|
ParallelBranchFinishedStreamResponse,
|
||||||
ParallelBranchStartStreamResponse,
|
ParallelBranchStartStreamResponse,
|
||||||
@ -423,6 +425,52 @@ class WorkflowCycleManage:
|
|||||||
|
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
|
def _handle_workflow_node_execution_retried(
|
||||||
|
self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
|
||||||
|
) -> WorkflowNodeExecution:
|
||||||
|
"""
|
||||||
|
Workflow node execution failed
|
||||||
|
:param event: queue node failed event
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
created_at = event.start_at
|
||||||
|
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
elapsed_time = (finished_at - created_at).total_seconds()
|
||||||
|
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||||
|
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||||
|
|
||||||
|
workflow_node_execution = WorkflowNodeExecution()
|
||||||
|
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||||
|
workflow_node_execution.app_id = workflow_run.app_id
|
||||||
|
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||||
|
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||||
|
workflow_node_execution.workflow_run_id = workflow_run.id
|
||||||
|
workflow_node_execution.node_execution_id = event.node_execution_id
|
||||||
|
workflow_node_execution.node_id = event.node_id
|
||||||
|
workflow_node_execution.node_type = event.node_type.value
|
||||||
|
workflow_node_execution.title = event.node_data.title
|
||||||
|
workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
|
||||||
|
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||||
|
workflow_node_execution.created_by = workflow_run.created_by
|
||||||
|
workflow_node_execution.created_at = created_at
|
||||||
|
workflow_node_execution.finished_at = finished_at
|
||||||
|
workflow_node_execution.elapsed_time = elapsed_time
|
||||||
|
workflow_node_execution.error = event.error
|
||||||
|
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||||
|
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||||
|
workflow_node_execution.execution_metadata = json.dumps(
|
||||||
|
{
|
||||||
|
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
workflow_node_execution.index = event.start_index
|
||||||
|
|
||||||
|
db.session.add(workflow_node_execution)
|
||||||
|
db.session.commit()
|
||||||
|
db.session.refresh(workflow_node_execution)
|
||||||
|
|
||||||
|
return workflow_node_execution
|
||||||
|
|
||||||
#################################################
|
#################################################
|
||||||
# to stream responses #
|
# to stream responses #
|
||||||
#################################################
|
#################################################
|
||||||
@ -587,6 +635,51 @@ class WorkflowCycleManage:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _workflow_node_retry_to_stream_response(
|
||||||
|
self,
|
||||||
|
event: QueueNodeRetryEvent,
|
||||||
|
task_id: str,
|
||||||
|
workflow_node_execution: WorkflowNodeExecution,
|
||||||
|
) -> Optional[NodeFinishStreamResponse]:
|
||||||
|
"""
|
||||||
|
Workflow node finish to stream response.
|
||||||
|
:param event: queue node succeeded or failed event
|
||||||
|
:param task_id: task id
|
||||||
|
:param workflow_node_execution: workflow node execution
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return NodeRetryStreamResponse(
|
||||||
|
task_id=task_id,
|
||||||
|
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||||
|
data=NodeRetryStreamResponse.Data(
|
||||||
|
id=workflow_node_execution.id,
|
||||||
|
node_id=workflow_node_execution.node_id,
|
||||||
|
node_type=workflow_node_execution.node_type,
|
||||||
|
index=workflow_node_execution.index,
|
||||||
|
title=workflow_node_execution.title,
|
||||||
|
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||||
|
inputs=workflow_node_execution.inputs_dict,
|
||||||
|
process_data=workflow_node_execution.process_data_dict,
|
||||||
|
outputs=workflow_node_execution.outputs_dict,
|
||||||
|
status=workflow_node_execution.status,
|
||||||
|
error=workflow_node_execution.error,
|
||||||
|
elapsed_time=workflow_node_execution.elapsed_time,
|
||||||
|
execution_metadata=workflow_node_execution.execution_metadata_dict,
|
||||||
|
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||||
|
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||||
|
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
||||||
|
parallel_id=event.parallel_id,
|
||||||
|
parallel_start_node_id=event.parallel_start_node_id,
|
||||||
|
parent_parallel_id=event.parent_parallel_id,
|
||||||
|
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||||
|
iteration_id=event.in_iteration_id,
|
||||||
|
retry_index=event.retry_index,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def _workflow_parallel_branch_start_to_stream_response(
|
def _workflow_parallel_branch_start_to_stream_response(
|
||||||
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
|
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
|
||||||
) -> ParallelBranchStartStreamResponse:
|
) -> ParallelBranchStartStreamResponse:
|
||||||
|
@ -45,7 +45,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
|||||||
)
|
)
|
||||||
|
|
||||||
retries = 0
|
retries = 0
|
||||||
stream = kwargs.pop("stream", False)
|
|
||||||
while retries <= max_retries:
|
while retries <= max_retries:
|
||||||
try:
|
try:
|
||||||
if dify_config.SSRF_PROXY_ALL_URL:
|
if dify_config.SSRF_PROXY_ALL_URL:
|
||||||
|
@ -45,3 +45,6 @@ class NodeRunResult(BaseModel):
|
|||||||
|
|
||||||
error: Optional[str] = None # error message if status is failed
|
error: Optional[str] = None # error message if status is failed
|
||||||
error_type: Optional[str] = None # error type if status is failed
|
error_type: Optional[str] = None # error type if status is failed
|
||||||
|
|
||||||
|
# single step node run retry
|
||||||
|
retry_index: int = 0
|
||||||
|
@ -97,6 +97,13 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
|
|||||||
error: str = Field(..., description="error")
|
error: str = Field(..., description="error")
|
||||||
|
|
||||||
|
|
||||||
|
class NodeRunRetryEvent(BaseNodeEvent):
|
||||||
|
error: str = Field(..., description="error")
|
||||||
|
retry_index: int = Field(..., description="which retry attempt is about to be performed")
|
||||||
|
start_at: datetime = Field(..., description="retry start time")
|
||||||
|
start_index: int = Field(..., description="retry start index")
|
||||||
|
|
||||||
|
|
||||||
###########################################
|
###########################################
|
||||||
# Parallel Branch Events
|
# Parallel Branch Events
|
||||||
###########################################
|
###########################################
|
||||||
|
@ -5,6 +5,7 @@ import uuid
|
|||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from concurrent.futures import ThreadPoolExecutor, wait
|
from concurrent.futures import ThreadPoolExecutor, wait
|
||||||
from copy import copy, deepcopy
|
from copy import copy, deepcopy
|
||||||
|
from datetime import UTC, datetime
|
||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
@ -25,6 +26,7 @@ from core.workflow.graph_engine.entities.event import (
|
|||||||
NodeRunExceptionEvent,
|
NodeRunExceptionEvent,
|
||||||
NodeRunFailedEvent,
|
NodeRunFailedEvent,
|
||||||
NodeRunRetrieverResourceEvent,
|
NodeRunRetrieverResourceEvent,
|
||||||
|
NodeRunRetryEvent,
|
||||||
NodeRunStartedEvent,
|
NodeRunStartedEvent,
|
||||||
NodeRunStreamChunkEvent,
|
NodeRunStreamChunkEvent,
|
||||||
NodeRunSucceededEvent,
|
NodeRunSucceededEvent,
|
||||||
@ -581,7 +583,7 @@ class GraphEngine:
|
|||||||
|
|
||||||
def _run_node(
|
def _run_node(
|
||||||
self,
|
self,
|
||||||
node_instance: BaseNode,
|
node_instance: BaseNode[BaseNodeData],
|
||||||
route_node_state: RouteNodeState,
|
route_node_state: RouteNodeState,
|
||||||
parallel_id: Optional[str] = None,
|
parallel_id: Optional[str] = None,
|
||||||
parallel_start_node_id: Optional[str] = None,
|
parallel_start_node_id: Optional[str] = None,
|
||||||
@ -607,36 +609,121 @@ class GraphEngine:
|
|||||||
)
|
)
|
||||||
|
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
max_retries = node_instance.node_data.retry_config.max_retries
|
||||||
|
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
|
||||||
|
retries = 0
|
||||||
|
shoudl_continue_retry = True
|
||||||
|
while shoudl_continue_retry and retries <= max_retries:
|
||||||
|
try:
|
||||||
|
# run node
|
||||||
|
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
generator = node_instance.run()
|
||||||
|
for item in generator:
|
||||||
|
if isinstance(item, GraphEngineEvent):
|
||||||
|
if isinstance(item, BaseIterationEvent):
|
||||||
|
# add parallel info to iteration event
|
||||||
|
item.parallel_id = parallel_id
|
||||||
|
item.parallel_start_node_id = parallel_start_node_id
|
||||||
|
item.parent_parallel_id = parent_parallel_id
|
||||||
|
item.parent_parallel_start_node_id = parent_parallel_start_node_id
|
||||||
|
|
||||||
try:
|
yield item
|
||||||
# run node
|
else:
|
||||||
generator = node_instance.run()
|
if isinstance(item, RunCompletedEvent):
|
||||||
for item in generator:
|
run_result = item.run_result
|
||||||
if isinstance(item, GraphEngineEvent):
|
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||||
if isinstance(item, BaseIterationEvent):
|
if (
|
||||||
# add parallel info to iteration event
|
retries == max_retries
|
||||||
item.parallel_id = parallel_id
|
and node_instance.node_type == NodeType.HTTP_REQUEST
|
||||||
item.parallel_start_node_id = parallel_start_node_id
|
and run_result.outputs
|
||||||
item.parent_parallel_id = parent_parallel_id
|
and not node_instance.should_continue_on_error
|
||||||
item.parent_parallel_start_node_id = parent_parallel_start_node_id
|
):
|
||||||
|
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||||
|
if node_instance.should_retry and retries < max_retries:
|
||||||
|
retries += 1
|
||||||
|
self.graph_runtime_state.node_run_steps += 1
|
||||||
|
route_node_state.node_run_result = run_result
|
||||||
|
yield NodeRunRetryEvent(
|
||||||
|
id=node_instance.id,
|
||||||
|
node_id=node_instance.node_id,
|
||||||
|
node_type=node_instance.node_type,
|
||||||
|
node_data=node_instance.node_data,
|
||||||
|
route_node_state=route_node_state,
|
||||||
|
error=run_result.error,
|
||||||
|
retry_index=retries,
|
||||||
|
parallel_id=parallel_id,
|
||||||
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
|
parent_parallel_id=parent_parallel_id,
|
||||||
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
|
start_at=retry_start_at,
|
||||||
|
start_index=self.graph_runtime_state.node_run_steps,
|
||||||
|
)
|
||||||
|
time.sleep(retry_interval)
|
||||||
|
continue
|
||||||
|
route_node_state.set_finished(run_result=run_result)
|
||||||
|
|
||||||
yield item
|
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||||
else:
|
if node_instance.should_continue_on_error:
|
||||||
if isinstance(item, RunCompletedEvent):
|
# if run failed, handle error
|
||||||
run_result = item.run_result
|
run_result = self._handle_continue_on_error(
|
||||||
route_node_state.set_finished(run_result=run_result)
|
node_instance,
|
||||||
|
item.run_result,
|
||||||
|
self.graph_runtime_state.variable_pool,
|
||||||
|
handle_exceptions=handle_exceptions,
|
||||||
|
)
|
||||||
|
route_node_state.node_run_result = run_result
|
||||||
|
route_node_state.status = RouteNodeState.Status.EXCEPTION
|
||||||
|
if run_result.outputs:
|
||||||
|
for variable_key, variable_value in run_result.outputs.items():
|
||||||
|
# append variables to variable pool recursively
|
||||||
|
self._append_variables_recursively(
|
||||||
|
node_id=node_instance.node_id,
|
||||||
|
variable_key_list=[variable_key],
|
||||||
|
variable_value=variable_value,
|
||||||
|
)
|
||||||
|
yield NodeRunExceptionEvent(
|
||||||
|
error=run_result.error or "System Error",
|
||||||
|
id=node_instance.id,
|
||||||
|
node_id=node_instance.node_id,
|
||||||
|
node_type=node_instance.node_type,
|
||||||
|
node_data=node_instance.node_data,
|
||||||
|
route_node_state=route_node_state,
|
||||||
|
parallel_id=parallel_id,
|
||||||
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
|
parent_parallel_id=parent_parallel_id,
|
||||||
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
|
)
|
||||||
|
shoudl_continue_retry = False
|
||||||
|
else:
|
||||||
|
yield NodeRunFailedEvent(
|
||||||
|
error=route_node_state.failed_reason or "Unknown error.",
|
||||||
|
id=node_instance.id,
|
||||||
|
node_id=node_instance.node_id,
|
||||||
|
node_type=node_instance.node_type,
|
||||||
|
node_data=node_instance.node_data,
|
||||||
|
route_node_state=route_node_state,
|
||||||
|
parallel_id=parallel_id,
|
||||||
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
|
parent_parallel_id=parent_parallel_id,
|
||||||
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
|
)
|
||||||
|
shoudl_continue_retry = False
|
||||||
|
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||||
|
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
|
||||||
|
node_instance.node_id
|
||||||
|
):
|
||||||
|
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
||||||
|
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||||
|
# plus state total_tokens
|
||||||
|
self.graph_runtime_state.total_tokens += int(
|
||||||
|
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
if run_result.llm_usage:
|
||||||
if node_instance.should_continue_on_error:
|
# use the latest usage
|
||||||
# if run failed, handle error
|
self.graph_runtime_state.llm_usage += run_result.llm_usage
|
||||||
run_result = self._handle_continue_on_error(
|
|
||||||
node_instance,
|
# append node output variables to variable pool
|
||||||
item.run_result,
|
|
||||||
self.graph_runtime_state.variable_pool,
|
|
||||||
handle_exceptions=handle_exceptions,
|
|
||||||
)
|
|
||||||
route_node_state.node_run_result = run_result
|
|
||||||
route_node_state.status = RouteNodeState.Status.EXCEPTION
|
|
||||||
if run_result.outputs:
|
if run_result.outputs:
|
||||||
for variable_key, variable_value in run_result.outputs.items():
|
for variable_key, variable_value in run_result.outputs.items():
|
||||||
# append variables to variable pool recursively
|
# append variables to variable pool recursively
|
||||||
@ -645,21 +732,23 @@ class GraphEngine:
|
|||||||
variable_key_list=[variable_key],
|
variable_key_list=[variable_key],
|
||||||
variable_value=variable_value,
|
variable_value=variable_value,
|
||||||
)
|
)
|
||||||
yield NodeRunExceptionEvent(
|
|
||||||
error=run_result.error or "System Error",
|
# add parallel info to run result metadata
|
||||||
id=node_instance.id,
|
if parallel_id and parallel_start_node_id:
|
||||||
node_id=node_instance.node_id,
|
if not run_result.metadata:
|
||||||
node_type=node_instance.node_type,
|
run_result.metadata = {}
|
||||||
node_data=node_instance.node_data,
|
|
||||||
route_node_state=route_node_state,
|
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
||||||
parallel_id=parallel_id,
|
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
|
||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id
|
||||||
parent_parallel_id=parent_parallel_id,
|
)
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
if parent_parallel_id and parent_parallel_start_node_id:
|
||||||
)
|
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
|
||||||
else:
|
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
|
||||||
yield NodeRunFailedEvent(
|
parent_parallel_start_node_id
|
||||||
error=route_node_state.failed_reason or "Unknown error.",
|
)
|
||||||
|
|
||||||
|
yield NodeRunSucceededEvent(
|
||||||
id=node_instance.id,
|
id=node_instance.id,
|
||||||
node_id=node_instance.node_id,
|
node_id=node_instance.node_id,
|
||||||
node_type=node_instance.node_type,
|
node_type=node_instance.node_type,
|
||||||
@ -670,108 +759,59 @@ class GraphEngine:
|
|||||||
parent_parallel_id=parent_parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
)
|
)
|
||||||
|
shoudl_continue_retry = False
|
||||||
|
|
||||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
break
|
||||||
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
|
elif isinstance(item, RunStreamChunkEvent):
|
||||||
node_instance.node_id
|
yield NodeRunStreamChunkEvent(
|
||||||
):
|
|
||||||
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
|
|
||||||
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
|
||||||
# plus state total_tokens
|
|
||||||
self.graph_runtime_state.total_tokens += int(
|
|
||||||
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
|
|
||||||
if run_result.llm_usage:
|
|
||||||
# use the latest usage
|
|
||||||
self.graph_runtime_state.llm_usage += run_result.llm_usage
|
|
||||||
|
|
||||||
# append node output variables to variable pool
|
|
||||||
if run_result.outputs:
|
|
||||||
for variable_key, variable_value in run_result.outputs.items():
|
|
||||||
# append variables to variable pool recursively
|
|
||||||
self._append_variables_recursively(
|
|
||||||
node_id=node_instance.node_id,
|
|
||||||
variable_key_list=[variable_key],
|
|
||||||
variable_value=variable_value,
|
|
||||||
)
|
|
||||||
|
|
||||||
# add parallel info to run result metadata
|
|
||||||
if parallel_id and parallel_start_node_id:
|
|
||||||
if not run_result.metadata:
|
|
||||||
run_result.metadata = {}
|
|
||||||
|
|
||||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
|
||||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
|
|
||||||
if parent_parallel_id and parent_parallel_start_node_id:
|
|
||||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
|
|
||||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
|
|
||||||
parent_parallel_start_node_id
|
|
||||||
)
|
|
||||||
|
|
||||||
yield NodeRunSucceededEvent(
|
|
||||||
id=node_instance.id,
|
id=node_instance.id,
|
||||||
node_id=node_instance.node_id,
|
node_id=node_instance.node_id,
|
||||||
node_type=node_instance.node_type,
|
node_type=node_instance.node_type,
|
||||||
node_data=node_instance.node_data,
|
node_data=node_instance.node_data,
|
||||||
|
chunk_content=item.chunk_content,
|
||||||
|
from_variable_selector=item.from_variable_selector,
|
||||||
route_node_state=route_node_state,
|
route_node_state=route_node_state,
|
||||||
parallel_id=parallel_id,
|
parallel_id=parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
parent_parallel_id=parent_parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
)
|
)
|
||||||
|
elif isinstance(item, RunRetrieverResourceEvent):
|
||||||
break
|
yield NodeRunRetrieverResourceEvent(
|
||||||
elif isinstance(item, RunStreamChunkEvent):
|
id=node_instance.id,
|
||||||
yield NodeRunStreamChunkEvent(
|
node_id=node_instance.node_id,
|
||||||
id=node_instance.id,
|
node_type=node_instance.node_type,
|
||||||
node_id=node_instance.node_id,
|
node_data=node_instance.node_data,
|
||||||
node_type=node_instance.node_type,
|
retriever_resources=item.retriever_resources,
|
||||||
node_data=node_instance.node_data,
|
context=item.context,
|
||||||
chunk_content=item.chunk_content,
|
route_node_state=route_node_state,
|
||||||
from_variable_selector=item.from_variable_selector,
|
parallel_id=parallel_id,
|
||||||
route_node_state=route_node_state,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
parallel_id=parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
parent_parallel_id=parent_parallel_id,
|
)
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
except GenerateTaskStoppedError:
|
||||||
)
|
# trigger node run failed event
|
||||||
elif isinstance(item, RunRetrieverResourceEvent):
|
route_node_state.status = RouteNodeState.Status.FAILED
|
||||||
yield NodeRunRetrieverResourceEvent(
|
route_node_state.failed_reason = "Workflow stopped."
|
||||||
id=node_instance.id,
|
yield NodeRunFailedEvent(
|
||||||
node_id=node_instance.node_id,
|
error="Workflow stopped.",
|
||||||
node_type=node_instance.node_type,
|
id=node_instance.id,
|
||||||
node_data=node_instance.node_data,
|
node_id=node_instance.node_id,
|
||||||
retriever_resources=item.retriever_resources,
|
node_type=node_instance.node_type,
|
||||||
context=item.context,
|
node_data=node_instance.node_data,
|
||||||
route_node_state=route_node_state,
|
route_node_state=route_node_state,
|
||||||
parallel_id=parallel_id,
|
parallel_id=parallel_id,
|
||||||
parallel_start_node_id=parallel_start_node_id,
|
parallel_start_node_id=parallel_start_node_id,
|
||||||
parent_parallel_id=parent_parallel_id,
|
parent_parallel_id=parent_parallel_id,
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||||
)
|
)
|
||||||
except GenerateTaskStoppedError:
|
return
|
||||||
# trigger node run failed event
|
except Exception as e:
|
||||||
route_node_state.status = RouteNodeState.Status.FAILED
|
logger.exception(f"Node {node_instance.node_data.title} run failed")
|
||||||
route_node_state.failed_reason = "Workflow stopped."
|
raise e
|
||||||
yield NodeRunFailedEvent(
|
finally:
|
||||||
error="Workflow stopped.",
|
db.session.close()
|
||||||
id=node_instance.id,
|
|
||||||
node_id=node_instance.node_id,
|
|
||||||
node_type=node_instance.node_type,
|
|
||||||
node_data=node_instance.node_data,
|
|
||||||
route_node_state=route_node_state,
|
|
||||||
parallel_id=parallel_id,
|
|
||||||
parallel_start_node_id=parallel_start_node_id,
|
|
||||||
parent_parallel_id=parent_parallel_id,
|
|
||||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Node {node_instance.node_data.title} run failed")
|
|
||||||
raise e
|
|
||||||
finally:
|
|
||||||
db.session.close()
|
|
||||||
|
|
||||||
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
|
||||||
"""
|
"""
|
||||||
|
@ -106,12 +106,25 @@ class DefaultValue(BaseModel):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class RetryConfig(BaseModel):
|
||||||
|
"""node retry config"""
|
||||||
|
|
||||||
|
max_retries: int = 0 # max retry times
|
||||||
|
retry_interval: int = 0 # retry interval in milliseconds
|
||||||
|
retry_enabled: bool = False # whether retry is enabled
|
||||||
|
|
||||||
|
@property
|
||||||
|
def retry_interval_seconds(self) -> float:
|
||||||
|
return self.retry_interval / 1000
|
||||||
|
|
||||||
|
|
||||||
class BaseNodeData(ABC, BaseModel):
|
class BaseNodeData(ABC, BaseModel):
|
||||||
title: str
|
title: str
|
||||||
desc: Optional[str] = None
|
desc: Optional[str] = None
|
||||||
error_strategy: Optional[ErrorStrategy] = None
|
error_strategy: Optional[ErrorStrategy] = None
|
||||||
default_value: Optional[list[DefaultValue]] = None
|
default_value: Optional[list[DefaultValue]] = None
|
||||||
version: str = "1"
|
version: str = "1"
|
||||||
|
retry_config: RetryConfig = RetryConfig()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_value_dict(self):
|
def default_value_dict(self):
|
||||||
|
@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
|
|||||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
||||||
|
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
|
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
|
||||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
@ -147,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]):
|
|||||||
bool: if should continue on error
|
bool: if should continue on error
|
||||||
"""
|
"""
|
||||||
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
|
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
|
||||||
|
|
||||||
|
@property
|
||||||
|
def should_retry(self) -> bool:
|
||||||
|
"""judge if should retry
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: if should retry
|
||||||
|
"""
|
||||||
|
return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
|
||||||
|
@ -35,3 +35,4 @@ class FailBranchSourceHandle(StrEnum):
|
|||||||
|
|
||||||
|
|
||||||
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
|
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
|
||||||
|
RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST]
|
||||||
|
@ -1,4 +1,10 @@
|
|||||||
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
from .event import (
|
||||||
|
ModelInvokeCompletedEvent,
|
||||||
|
RunCompletedEvent,
|
||||||
|
RunRetrieverResourceEvent,
|
||||||
|
RunRetryEvent,
|
||||||
|
RunStreamChunkEvent,
|
||||||
|
)
|
||||||
from .types import NodeEvent
|
from .types import NodeEvent
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -6,5 +12,6 @@ __all__ = [
|
|||||||
"NodeEvent",
|
"NodeEvent",
|
||||||
"RunCompletedEvent",
|
"RunCompletedEvent",
|
||||||
"RunRetrieverResourceEvent",
|
"RunRetrieverResourceEvent",
|
||||||
|
"RunRetryEvent",
|
||||||
"RunStreamChunkEvent",
|
"RunStreamChunkEvent",
|
||||||
]
|
]
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
class RunCompletedEvent(BaseModel):
|
class RunCompletedEvent(BaseModel):
|
||||||
@ -26,3 +29,25 @@ class ModelInvokeCompletedEvent(BaseModel):
|
|||||||
text: str
|
text: str
|
||||||
usage: LLMUsage
|
usage: LLMUsage
|
||||||
finish_reason: str | None = None
|
finish_reason: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class RunRetryEvent(BaseModel):
|
||||||
|
"""Node Run Retry event"""
|
||||||
|
|
||||||
|
error: str = Field(..., description="error")
|
||||||
|
retry_index: int = Field(..., description="Retry attempt number")
|
||||||
|
start_at: datetime = Field(..., description="Retry start time")
|
||||||
|
|
||||||
|
|
||||||
|
class SingleStepRetryEvent(BaseModel):
|
||||||
|
"""Single step retry event"""
|
||||||
|
|
||||||
|
status: str = WorkflowNodeExecutionStatus.RETRY.value
|
||||||
|
|
||||||
|
inputs: dict | None = Field(..., description="input")
|
||||||
|
error: str = Field(..., description="error")
|
||||||
|
outputs: dict = Field(..., description="output")
|
||||||
|
retry_index: int = Field(..., description="Retry attempt number")
|
||||||
|
error: str = Field(..., description="error")
|
||||||
|
elapsed_time: float = Field(..., description="elapsed time")
|
||||||
|
execution_metadata: dict | None = Field(..., description="execution metadata")
|
||||||
|
@ -45,6 +45,7 @@ class Executor:
|
|||||||
headers: dict[str, str]
|
headers: dict[str, str]
|
||||||
auth: HttpRequestNodeAuthorization
|
auth: HttpRequestNodeAuthorization
|
||||||
timeout: HttpRequestNodeTimeout
|
timeout: HttpRequestNodeTimeout
|
||||||
|
max_retries: int
|
||||||
|
|
||||||
boundary: str
|
boundary: str
|
||||||
|
|
||||||
@ -54,6 +55,7 @@ class Executor:
|
|||||||
node_data: HttpRequestNodeData,
|
node_data: HttpRequestNodeData,
|
||||||
timeout: HttpRequestNodeTimeout,
|
timeout: HttpRequestNodeTimeout,
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
|
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||||
):
|
):
|
||||||
# If authorization API key is present, convert the API key using the variable pool
|
# If authorization API key is present, convert the API key using the variable pool
|
||||||
if node_data.authorization.type == "api-key":
|
if node_data.authorization.type == "api-key":
|
||||||
@ -73,6 +75,7 @@ class Executor:
|
|||||||
self.files = None
|
self.files = None
|
||||||
self.data = None
|
self.data = None
|
||||||
self.json = None
|
self.json = None
|
||||||
|
self.max_retries = max_retries
|
||||||
|
|
||||||
# init template
|
# init template
|
||||||
self.variable_pool = variable_pool
|
self.variable_pool = variable_pool
|
||||||
@ -241,6 +244,7 @@ class Executor:
|
|||||||
"params": self.params,
|
"params": self.params,
|
||||||
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
|
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
|
||||||
"follow_redirects": True,
|
"follow_redirects": True,
|
||||||
|
"max_retries": self.max_retries,
|
||||||
}
|
}
|
||||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||||
try:
|
try:
|
||||||
|
@ -52,6 +52,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
|||||||
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
|
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"retry_config": {
|
||||||
|
"max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||||
|
"retry_interval": 0.5 * (2**2),
|
||||||
|
"retry_enabled": True,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult:
|
def _run(self) -> NodeRunResult:
|
||||||
@ -61,12 +66,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
|||||||
node_data=self.node_data,
|
node_data=self.node_data,
|
||||||
timeout=self._get_request_timeout(self.node_data),
|
timeout=self._get_request_timeout(self.node_data),
|
||||||
variable_pool=self.graph_runtime_state.variable_pool,
|
variable_pool=self.graph_runtime_state.variable_pool,
|
||||||
|
max_retries=0,
|
||||||
)
|
)
|
||||||
process_data["request"] = http_executor.to_log()
|
process_data["request"] = http_executor.to_log()
|
||||||
|
|
||||||
response = http_executor.invoke()
|
response = http_executor.invoke()
|
||||||
files = self.extract_files(url=http_executor.url, response=response)
|
files = self.extract_files(url=http_executor.url, response=response)
|
||||||
if not response.response.is_success and self.should_continue_on_error:
|
if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
outputs={
|
outputs={
|
||||||
|
@ -29,6 +29,7 @@ workflow_run_for_list_fields = {
|
|||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
"finished_at": TimestampField,
|
"finished_at": TimestampField,
|
||||||
"exceptions_count": fields.Integer,
|
"exceptions_count": fields.Integer,
|
||||||
|
"retry_index": fields.Integer,
|
||||||
}
|
}
|
||||||
|
|
||||||
advanced_chat_workflow_run_for_list_fields = {
|
advanced_chat_workflow_run_for_list_fields = {
|
||||||
@ -45,6 +46,7 @@ advanced_chat_workflow_run_for_list_fields = {
|
|||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
"finished_at": TimestampField,
|
"finished_at": TimestampField,
|
||||||
"exceptions_count": fields.Integer,
|
"exceptions_count": fields.Integer,
|
||||||
|
"retry_index": fields.Integer,
|
||||||
}
|
}
|
||||||
|
|
||||||
advanced_chat_workflow_run_pagination_fields = {
|
advanced_chat_workflow_run_pagination_fields = {
|
||||||
@ -79,6 +81,17 @@ workflow_run_detail_fields = {
|
|||||||
"exceptions_count": fields.Integer,
|
"exceptions_count": fields.Integer,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
retry_event_field = {
|
||||||
|
"error": fields.String,
|
||||||
|
"retry_index": fields.Integer,
|
||||||
|
"inputs": fields.Raw(attribute="inputs"),
|
||||||
|
"elapsed_time": fields.Float,
|
||||||
|
"execution_metadata": fields.Raw(attribute="execution_metadata_dict"),
|
||||||
|
"status": fields.String,
|
||||||
|
"outputs": fields.Raw(attribute="outputs"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
workflow_run_node_execution_fields = {
|
workflow_run_node_execution_fields = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
"index": fields.Integer,
|
"index": fields.Integer,
|
||||||
@ -99,6 +112,7 @@ workflow_run_node_execution_fields = {
|
|||||||
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
|
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
|
||||||
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
|
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
|
||||||
"finished_at": TimestampField,
|
"finished_at": TimestampField,
|
||||||
|
"retry_events": fields.List(fields.Nested(retry_event_field)),
|
||||||
}
|
}
|
||||||
|
|
||||||
workflow_run_node_execution_list_fields = {
|
workflow_run_node_execution_list_fields = {
|
||||||
|
@ -0,0 +1,33 @@
|
|||||||
|
"""add retry_index field to node-execution model
|
||||||
|
|
||||||
|
Revision ID: 348cb0a93d53
|
||||||
|
Revises: cf8f4fc45278
|
||||||
|
Create Date: 2024-12-16 01:23:13.093432
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '348cb0a93d53'
|
||||||
|
down_revision = 'cf8f4fc45278'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('retry_index')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
@ -529,6 +529,7 @@ class WorkflowNodeExecutionStatus(Enum):
|
|||||||
SUCCEEDED = "succeeded"
|
SUCCEEDED = "succeeded"
|
||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
EXCEPTION = "exception"
|
EXCEPTION = "exception"
|
||||||
|
RETRY = "retry"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
|
def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
|
||||||
@ -639,6 +640,7 @@ class WorkflowNodeExecution(db.Model):
|
|||||||
created_by_role = db.Column(db.String(255), nullable=False)
|
created_by_role = db.Column(db.String(255), nullable=False)
|
||||||
created_by = db.Column(StringUUID, nullable=False)
|
created_by = db.Column(StringUUID, nullable=False)
|
||||||
finished_at = db.Column(db.DateTime)
|
finished_at = db.Column(db.DateTime)
|
||||||
|
retry_index = db.Column(db.Integer, server_default=db.text("0"))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def created_by_account(self):
|
def created_by_account(self):
|
||||||
|
@ -15,6 +15,7 @@ from core.workflow.nodes.base.entities import BaseNodeData
|
|||||||
from core.workflow.nodes.base.node import BaseNode
|
from core.workflow.nodes.base.node import BaseNode
|
||||||
from core.workflow.nodes.enums import ErrorStrategy
|
from core.workflow.nodes.enums import ErrorStrategy
|
||||||
from core.workflow.nodes.event import RunCompletedEvent
|
from core.workflow.nodes.event import RunCompletedEvent
|
||||||
|
from core.workflow.nodes.event.event import SingleStepRetryEvent
|
||||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||||
@ -220,56 +221,99 @@ class WorkflowService:
|
|||||||
|
|
||||||
# run draft workflow node
|
# run draft workflow node
|
||||||
start_at = time.perf_counter()
|
start_at = time.perf_counter()
|
||||||
|
retries = 0
|
||||||
|
max_retries = 0
|
||||||
|
should_retry = True
|
||||||
|
retry_events = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
node_instance, generator = WorkflowEntry.single_step_run(
|
while retries <= max_retries and should_retry:
|
||||||
workflow=draft_workflow,
|
retry_start_at = time.perf_counter()
|
||||||
node_id=node_id,
|
node_instance, generator = WorkflowEntry.single_step_run(
|
||||||
user_inputs=user_inputs,
|
workflow=draft_workflow,
|
||||||
user_id=account.id,
|
node_id=node_id,
|
||||||
)
|
user_inputs=user_inputs,
|
||||||
node_instance = cast(BaseNode[BaseNodeData], node_instance)
|
user_id=account.id,
|
||||||
node_run_result: NodeRunResult | None = None
|
)
|
||||||
for event in generator:
|
node_instance = cast(BaseNode[BaseNodeData], node_instance)
|
||||||
if isinstance(event, RunCompletedEvent):
|
max_retries = (
|
||||||
node_run_result = event.run_result
|
node_instance.node_data.retry_config.max_retries if node_instance.node_data.retry_config else 0
|
||||||
|
)
|
||||||
|
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
|
||||||
|
node_run_result: NodeRunResult | None = None
|
||||||
|
for event in generator:
|
||||||
|
if isinstance(event, RunCompletedEvent):
|
||||||
|
node_run_result = event.run_result
|
||||||
|
|
||||||
# sign output files
|
# sign output files
|
||||||
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
|
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
|
||||||
break
|
break
|
||||||
|
|
||||||
if not node_run_result:
|
if not node_run_result:
|
||||||
raise ValueError("Node run failed with no run result")
|
raise ValueError("Node run failed with no run result")
|
||||||
# single step debug mode error handling return
|
# single step debug mode error handling return
|
||||||
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
|
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||||
node_error_args = {
|
if (
|
||||||
"status": WorkflowNodeExecutionStatus.EXCEPTION,
|
retries == max_retries
|
||||||
"error": node_run_result.error,
|
and node_instance.node_type == NodeType.HTTP_REQUEST
|
||||||
"inputs": node_run_result.inputs,
|
and node_run_result.outputs
|
||||||
"metadata": {"error_strategy": node_instance.node_data.error_strategy},
|
and not node_instance.should_continue_on_error
|
||||||
}
|
):
|
||||||
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
|
node_run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
|
||||||
node_run_result = NodeRunResult(
|
should_retry = False
|
||||||
**node_error_args,
|
else:
|
||||||
outputs={
|
if node_instance.should_retry:
|
||||||
**node_instance.node_data.default_value_dict,
|
node_run_result.status = WorkflowNodeExecutionStatus.RETRY
|
||||||
"error_message": node_run_result.error,
|
retries += 1
|
||||||
"error_type": node_run_result.error_type,
|
node_run_result.retry_index = retries
|
||||||
},
|
retry_events.append(
|
||||||
)
|
SingleStepRetryEvent(
|
||||||
else:
|
inputs=WorkflowEntry.handle_special_values(node_run_result.inputs)
|
||||||
node_run_result = NodeRunResult(
|
if node_run_result.inputs
|
||||||
**node_error_args,
|
else None,
|
||||||
outputs={
|
error=node_run_result.error,
|
||||||
"error_message": node_run_result.error,
|
outputs=WorkflowEntry.handle_special_values(node_run_result.outputs)
|
||||||
"error_type": node_run_result.error_type,
|
if node_run_result.outputs
|
||||||
},
|
else None,
|
||||||
)
|
retry_index=node_run_result.retry_index,
|
||||||
run_succeeded = node_run_result.status in (
|
elapsed_time=time.perf_counter() - retry_start_at,
|
||||||
WorkflowNodeExecutionStatus.SUCCEEDED,
|
execution_metadata=WorkflowEntry.handle_special_values(node_run_result.metadata)
|
||||||
WorkflowNodeExecutionStatus.EXCEPTION,
|
if node_run_result.metadata
|
||||||
)
|
else None,
|
||||||
error = node_run_result.error if not run_succeeded else None
|
)
|
||||||
|
)
|
||||||
|
time.sleep(retry_interval)
|
||||||
|
else:
|
||||||
|
should_retry = False
|
||||||
|
if node_instance.should_continue_on_error:
|
||||||
|
node_error_args = {
|
||||||
|
"status": WorkflowNodeExecutionStatus.EXCEPTION,
|
||||||
|
"error": node_run_result.error,
|
||||||
|
"inputs": node_run_result.inputs,
|
||||||
|
"metadata": {"error_strategy": node_instance.node_data.error_strategy},
|
||||||
|
}
|
||||||
|
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
|
||||||
|
node_run_result = NodeRunResult(
|
||||||
|
**node_error_args,
|
||||||
|
outputs={
|
||||||
|
**node_instance.node_data.default_value_dict,
|
||||||
|
"error_message": node_run_result.error,
|
||||||
|
"error_type": node_run_result.error_type,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
node_run_result = NodeRunResult(
|
||||||
|
**node_error_args,
|
||||||
|
outputs={
|
||||||
|
"error_message": node_run_result.error,
|
||||||
|
"error_type": node_run_result.error_type,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
run_succeeded = node_run_result.status in (
|
||||||
|
WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
WorkflowNodeExecutionStatus.EXCEPTION,
|
||||||
|
)
|
||||||
|
error = node_run_result.error if not run_succeeded else None
|
||||||
except WorkflowNodeRunFailedError as e:
|
except WorkflowNodeRunFailedError as e:
|
||||||
node_instance = e.node_instance
|
node_instance = e.node_instance
|
||||||
run_succeeded = False
|
run_succeeded = False
|
||||||
@ -318,6 +362,7 @@ class WorkflowService:
|
|||||||
|
|
||||||
db.session.add(workflow_node_execution)
|
db.session.add(workflow_node_execution)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
workflow_node_execution.retry_events = retry_events
|
||||||
|
|
||||||
return workflow_node_execution
|
return workflow_node_execution
|
||||||
|
|
||||||
|
@ -2,7 +2,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
GraphRunPartialSucceededEvent,
|
GraphRunPartialSucceededEvent,
|
||||||
GraphRunSucceededEvent,
|
|
||||||
NodeRunExceptionEvent,
|
NodeRunExceptionEvent,
|
||||||
NodeRunStreamChunkEvent,
|
NodeRunStreamChunkEvent,
|
||||||
)
|
)
|
||||||
@ -14,7 +13,9 @@ from models.workflow import WorkflowType
|
|||||||
|
|
||||||
class ContinueOnErrorTestHelper:
|
class ContinueOnErrorTestHelper:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None):
|
def get_code_node(
|
||||||
|
code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {}
|
||||||
|
):
|
||||||
"""Helper method to create a code node configuration"""
|
"""Helper method to create a code node configuration"""
|
||||||
node = {
|
node = {
|
||||||
"id": "node",
|
"id": "node",
|
||||||
@ -26,6 +27,7 @@ class ContinueOnErrorTestHelper:
|
|||||||
"code_language": "python3",
|
"code_language": "python3",
|
||||||
"code": "\n".join([line[4:] for line in code.split("\n")]),
|
"code": "\n".join([line[4:] for line in code.split("\n")]),
|
||||||
"type": "code",
|
"type": "code",
|
||||||
|
**retry_config,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if default_value:
|
if default_value:
|
||||||
@ -34,7 +36,10 @@ class ContinueOnErrorTestHelper:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_http_node(
|
def get_http_node(
|
||||||
error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False
|
error_strategy: str = "fail-branch",
|
||||||
|
default_value: dict | None = None,
|
||||||
|
authorization_success: bool = False,
|
||||||
|
retry_config: dict = {},
|
||||||
):
|
):
|
||||||
"""Helper method to create a http node configuration"""
|
"""Helper method to create a http node configuration"""
|
||||||
authorization = (
|
authorization = (
|
||||||
@ -65,6 +70,7 @@ class ContinueOnErrorTestHelper:
|
|||||||
"body": None,
|
"body": None,
|
||||||
"type": "http-request",
|
"type": "http-request",
|
||||||
"error_strategy": error_strategy,
|
"error_strategy": error_strategy,
|
||||||
|
**retry_config,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if default_value:
|
if default_value:
|
||||||
|
73
api/tests/unit_tests/core/workflow/nodes/test_retry.py
Normal file
73
api/tests/unit_tests/core/workflow/nodes/test_retry.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
from core.workflow.graph_engine.entities.event import (
|
||||||
|
GraphRunFailedEvent,
|
||||||
|
GraphRunPartialSucceededEvent,
|
||||||
|
GraphRunSucceededEvent,
|
||||||
|
NodeRunRetryEvent,
|
||||||
|
)
|
||||||
|
from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper
|
||||||
|
|
||||||
|
DEFAULT_VALUE_EDGE = [
|
||||||
|
{
|
||||||
|
"id": "start-source-node-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "node",
|
||||||
|
"sourceHandle": "source",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "node-source-answer-target",
|
||||||
|
"source": "node",
|
||||||
|
"target": "answer",
|
||||||
|
"sourceHandle": "source",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_retry_default_value_partial_success():
|
||||||
|
"""retry default value node with partial success status"""
|
||||||
|
graph_config = {
|
||||||
|
"edges": DEFAULT_VALUE_EDGE,
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||||
|
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||||
|
ContinueOnErrorTestHelper.get_http_node(
|
||||||
|
"default-value",
|
||||||
|
[{"key": "result", "type": "string", "value": "http node got error response"}],
|
||||||
|
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||||
|
events = list(graph_engine.run())
|
||||||
|
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
|
||||||
|
assert events[-1].outputs == {"answer": "http node got error response"}
|
||||||
|
assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
|
||||||
|
assert len(events) == 11
|
||||||
|
|
||||||
|
|
||||||
|
def test_retry_failed():
|
||||||
|
"""retry failed with success status"""
|
||||||
|
error_code = """
|
||||||
|
def main() -> dict:
|
||||||
|
return {
|
||||||
|
"result": 1 / 0,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph_config = {
|
||||||
|
"edges": DEFAULT_VALUE_EDGE,
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||||
|
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||||
|
ContinueOnErrorTestHelper.get_http_node(
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||||
|
events = list(graph_engine.run())
|
||||||
|
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
|
||||||
|
assert any(isinstance(e, GraphRunFailedEvent) for e in events)
|
||||||
|
assert len(events) == 8
|
Loading…
x
Reference in New Issue
Block a user