From 7abc7fa5735fb9f168ba47e4a79a60373072c6af Mon Sep 17 00:00:00 2001 From: Novice <857526207@qq.com> Date: Fri, 20 Dec 2024 14:14:06 +0800 Subject: [PATCH] Feat: Retry on node execution errors (#11871) Co-authored-by: Novice Lee --- .../advanced_chat/generate_task_pipeline.py | 17 + .../apps/workflow/generate_task_pipeline.py | 19 +- api/core/app/apps/workflow_app_runner.py | 32 ++ api/core/app/entities/queue_entities.py | 32 ++ api/core/app/entities/task_entities.py | 70 ++++ .../task_pipeline/workflow_cycle_manage.py | 93 ++++++ api/core/helper/ssrf_proxy.py | 1 - api/core/workflow/entities/node_entities.py | 3 + .../workflow/graph_engine/entities/event.py | 7 + .../workflow/graph_engine/graph_engine.py | 308 ++++++++++-------- api/core/workflow/nodes/base/entities.py | 13 + api/core/workflow/nodes/base/node.py | 11 +- api/core/workflow/nodes/enums.py | 1 + api/core/workflow/nodes/event/__init__.py | 9 +- api/core/workflow/nodes/event/event.py | 25 ++ .../workflow/nodes/http_request/executor.py | 4 + api/core/workflow/nodes/http_request/node.py | 8 +- api/fields/workflow_run_fields.py | 14 + ...dd_retry_index_field_to_node_execution_.py | 33 ++ api/models/workflow.py | 2 + api/services/workflow_service.py | 137 +++++--- .../workflow/nodes/test_continue_on_error.py | 12 +- .../core/workflow/nodes/test_retry.py | 73 +++++ 23 files changed, 736 insertions(+), 188 deletions(-) create mode 100644 api/migrations/versions/2024_12_16_0123-348cb0a93d53_add_retry_index_field_to_node_execution_.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/test_retry.py diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 32a23a7fdb..ce0e959627 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -22,6 +22,7 @@ from core.app.entities.queue_entities import ( QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -328,6 +329,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc workflow_node_execution=workflow_node_execution, ) + if response: + yield response + elif isinstance( + event, + QueueNodeRetryEvent, + ): + workflow_node_execution = self._handle_workflow_node_execution_retried( + workflow_run=workflow_run, event=event + ) + + response = self._workflow_node_retry_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + if response: yield response elif isinstance(event, QueueParallelBranchRunStartedEvent): diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 2330229f43..79e5e2bcb9 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -18,6 +18,7 @@ from core.app.entities.queue_entities import ( QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -286,9 +287,25 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, ) - if node_failed_response: yield node_failed_response + elif isinstance( + event, + QueueNodeRetryEvent, + ): + workflow_node_execution = self._handle_workflow_node_execution_retried( + workflow_run=workflow_run, event=event + ) + + response = self._workflow_node_retry_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if response: + yield response + elif isinstance(event, QueueParallelBranchRunStartedEvent): if not workflow_run: raise Exception("Workflow run not initialized.") diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 97c2cc5bb9..2fbf711175 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -11,6 +11,7 @@ from core.app.entities.queue_entities import ( QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -38,6 +39,7 @@ from core.workflow.graph_engine.entities.event import ( NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunRetrieverResourceEvent, + NodeRunRetryEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, @@ -420,6 +422,36 @@ class WorkflowBasedAppRunner(AppRunner): error=event.error if isinstance(event, IterationRunFailedEvent) else None, ) ) + elif isinstance(event, NodeRunRetryEvent): + self._publish_event( + QueueNodeRetryEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result + else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result + else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result + else {}, + error=event.error, + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result + else {}, + in_iteration_id=event.in_iteration_id, + retry_index=event.retry_index, + start_index=event.start_index, + ) + ) def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: """ diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 5b2036c7f9..49b7e80246 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -43,6 +43,7 @@ class QueueEvent(StrEnum): ERROR = "error" PING = "ping" STOP = "stop" + RETRY = "retry" class AppQueueEvent(BaseModel): @@ -313,6 +314,37 @@ class QueueNodeSucceededEvent(AppQueueEvent): iteration_duration_map: Optional[dict[str, float]] = None +class QueueNodeRetryEvent(AppQueueEvent): + """QueueNodeRetryEvent entity""" + + event: QueueEvent = QueueEvent.RETRY + + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + + error: str + retry_index: int # retry index + start_index: int # start index + + class QueueNodeInIterationFailedEvent(AppQueueEvent): """ QueueNodeInIterationFailedEvent entity diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 7fe06b3af8..dd088a8978 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -52,6 +52,7 @@ class StreamEvent(Enum): WORKFLOW_FINISHED = "workflow_finished" NODE_STARTED = "node_started" NODE_FINISHED = "node_finished" + NODE_RETRY = "node_retry" PARALLEL_BRANCH_STARTED = "parallel_branch_started" PARALLEL_BRANCH_FINISHED = "parallel_branch_finished" ITERATION_STARTED = "iteration_started" @@ -342,6 +343,75 @@ class NodeFinishStreamResponse(StreamResponse): } +class NodeRetryStreamResponse(StreamResponse): + """ + NodeFinishStreamResponse entity + """ + + class Data(BaseModel): + """ + Data entity + """ + + id: str + node_id: str + node_type: str + title: str + index: int + predecessor_node_id: Optional[str] = None + inputs: Optional[dict] = None + process_data: Optional[dict] = None + outputs: Optional[dict] = None + status: str + error: Optional[str] = None + elapsed_time: float + execution_metadata: Optional[dict] = None + created_at: int + finished_at: int + files: Optional[Sequence[Mapping[str, Any]]] = [] + parallel_id: Optional[str] = None + parallel_start_node_id: Optional[str] = None + parent_parallel_id: Optional[str] = None + parent_parallel_start_node_id: Optional[str] = None + iteration_id: Optional[str] = None + retry_index: int = 0 + + event: StreamEvent = StreamEvent.NODE_RETRY + workflow_run_id: str + data: Data + + def to_ignore_detail_dict(self): + return { + "event": self.event.value, + "task_id": self.task_id, + "workflow_run_id": self.workflow_run_id, + "data": { + "id": self.data.id, + "node_id": self.data.node_id, + "node_type": self.data.node_type, + "title": self.data.title, + "index": self.data.index, + "predecessor_node_id": self.data.predecessor_node_id, + "inputs": None, + "process_data": None, + "outputs": None, + "status": self.data.status, + "error": None, + "elapsed_time": self.data.elapsed_time, + "execution_metadata": None, + "created_at": self.data.created_at, + "finished_at": self.data.finished_at, + "files": [], + "parallel_id": self.data.parallel_id, + "parallel_start_node_id": self.data.parallel_start_node_id, + "parent_parallel_id": self.data.parent_parallel_id, + "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id, + "iteration_id": self.data.iteration_id, + "retry_index": self.data.retry_index, + }, + } + + class ParallelBranchStartStreamResponse(StreamResponse): """ ParallelBranchStartStreamResponse entity diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 8b588b48eb..e2fa12b1cd 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -15,6 +15,7 @@ from core.app.entities.queue_entities import ( QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, + QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -26,6 +27,7 @@ from core.app.entities.task_entities import ( IterationNodeNextStreamResponse, IterationNodeStartStreamResponse, NodeFinishStreamResponse, + NodeRetryStreamResponse, NodeStartStreamResponse, ParallelBranchFinishedStreamResponse, ParallelBranchStartStreamResponse, @@ -423,6 +425,52 @@ class WorkflowCycleManage: return workflow_node_execution + def _handle_workflow_node_execution_retried( + self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent + ) -> WorkflowNodeExecution: + """ + Workflow node execution failed + :param event: queue node failed event + :return: + """ + created_at = event.start_at + finished_at = datetime.now(UTC).replace(tzinfo=None) + elapsed_time = (finished_at - created_at).total_seconds() + inputs = WorkflowEntry.handle_special_values(event.inputs) + outputs = WorkflowEntry.handle_special_values(event.outputs) + + workflow_node_execution = WorkflowNodeExecution() + workflow_node_execution.tenant_id = workflow_run.tenant_id + workflow_node_execution.app_id = workflow_run.app_id + workflow_node_execution.workflow_id = workflow_run.workflow_id + workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + workflow_node_execution.workflow_run_id = workflow_run.id + workflow_node_execution.node_execution_id = event.node_execution_id + workflow_node_execution.node_id = event.node_id + workflow_node_execution.node_type = event.node_type.value + workflow_node_execution.title = event.node_data.title + workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value + workflow_node_execution.created_by_role = workflow_run.created_by_role + workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.created_at = created_at + workflow_node_execution.finished_at = finished_at + workflow_node_execution.elapsed_time = elapsed_time + workflow_node_execution.error = event.error + workflow_node_execution.inputs = json.dumps(inputs) if inputs else None + workflow_node_execution.outputs = json.dumps(outputs) if outputs else None + workflow_node_execution.execution_metadata = json.dumps( + { + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + } + ) + workflow_node_execution.index = event.start_index + + db.session.add(workflow_node_execution) + db.session.commit() + db.session.refresh(workflow_node_execution) + + return workflow_node_execution + ################################################# # to stream responses # ################################################# @@ -587,6 +635,51 @@ class WorkflowCycleManage: ), ) + def _workflow_node_retry_to_stream_response( + self, + event: QueueNodeRetryEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, + ) -> Optional[NodeFinishStreamResponse]: + """ + Workflow node finish to stream response. + :param event: queue node succeeded or failed event + :param task_id: task id + :param workflow_node_execution: workflow node execution + :return: + """ + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + return None + + return NodeRetryStreamResponse( + task_id=task_id, + workflow_run_id=workflow_node_execution.workflow_run_id, + data=NodeRetryStreamResponse.Data( + id=workflow_node_execution.id, + node_id=workflow_node_execution.node_id, + node_type=workflow_node_execution.node_type, + index=workflow_node_execution.index, + title=workflow_node_execution.title, + predecessor_node_id=workflow_node_execution.predecessor_node_id, + inputs=workflow_node_execution.inputs_dict, + process_data=workflow_node_execution.process_data_dict, + outputs=workflow_node_execution.outputs_dict, + status=workflow_node_execution.status, + error=workflow_node_execution.error, + elapsed_time=workflow_node_execution.elapsed_time, + execution_metadata=workflow_node_execution.execution_metadata_dict, + created_at=int(workflow_node_execution.created_at.timestamp()), + finished_at=int(workflow_node_execution.finished_at.timestamp()), + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + retry_index=event.retry_index, + ), + ) + def _workflow_parallel_branch_start_to_stream_response( self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent ) -> ParallelBranchStartStreamResponse: diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index ef4516b404..d8aa805364 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -45,7 +45,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): ) retries = 0 - stream = kwargs.pop("stream", False) while retries <= max_retries: try: if dify_config.SSRF_PROXY_ALL_URL: diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 976a5ef74e..ca01dcd7d8 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -45,3 +45,6 @@ class NodeRunResult(BaseModel): error: Optional[str] = None # error message if status is failed error_type: Optional[str] = None # error type if status is failed + + # single step node run retry + retry_index: int = 0 diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 73450349de..9997153164 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -97,6 +97,13 @@ class NodeInIterationFailedEvent(BaseNodeEvent): error: str = Field(..., description="error") +class NodeRunRetryEvent(BaseNodeEvent): + error: str = Field(..., description="error") + retry_index: int = Field(..., description="which retry attempt is about to be performed") + start_at: datetime = Field(..., description="retry start time") + start_index: int = Field(..., description="retry start index") + + ########################################### # Parallel Branch Events ########################################### diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 034b4bd399..e292b09968 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -5,6 +5,7 @@ import uuid from collections.abc import Generator, Mapping from concurrent.futures import ThreadPoolExecutor, wait from copy import copy, deepcopy +from datetime import UTC, datetime from typing import Any, Optional, cast from flask import Flask, current_app @@ -25,6 +26,7 @@ from core.workflow.graph_engine.entities.event import ( NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunRetrieverResourceEvent, + NodeRunRetryEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, @@ -581,7 +583,7 @@ class GraphEngine: def _run_node( self, - node_instance: BaseNode, + node_instance: BaseNode[BaseNodeData], route_node_state: RouteNodeState, parallel_id: Optional[str] = None, parallel_start_node_id: Optional[str] = None, @@ -607,36 +609,121 @@ class GraphEngine: ) db.session.close() + max_retries = node_instance.node_data.retry_config.max_retries + retry_interval = node_instance.node_data.retry_config.retry_interval_seconds + retries = 0 + shoudl_continue_retry = True + while shoudl_continue_retry and retries <= max_retries: + try: + # run node + retry_start_at = datetime.now(UTC).replace(tzinfo=None) + generator = node_instance.run() + for item in generator: + if isinstance(item, GraphEngineEvent): + if isinstance(item, BaseIterationEvent): + # add parallel info to iteration event + item.parallel_id = parallel_id + item.parallel_start_node_id = parallel_start_node_id + item.parent_parallel_id = parent_parallel_id + item.parent_parallel_start_node_id = parent_parallel_start_node_id - try: - # run node - generator = node_instance.run() - for item in generator: - if isinstance(item, GraphEngineEvent): - if isinstance(item, BaseIterationEvent): - # add parallel info to iteration event - item.parallel_id = parallel_id - item.parallel_start_node_id = parallel_start_node_id - item.parent_parallel_id = parent_parallel_id - item.parent_parallel_start_node_id = parent_parallel_start_node_id + yield item + else: + if isinstance(item, RunCompletedEvent): + run_result = item.run_result + if run_result.status == WorkflowNodeExecutionStatus.FAILED: + if ( + retries == max_retries + and node_instance.node_type == NodeType.HTTP_REQUEST + and run_result.outputs + and not node_instance.should_continue_on_error + ): + run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED + if node_instance.should_retry and retries < max_retries: + retries += 1 + self.graph_runtime_state.node_run_steps += 1 + route_node_state.node_run_result = run_result + yield NodeRunRetryEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + error=run_result.error, + retry_index=retries, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + start_at=retry_start_at, + start_index=self.graph_runtime_state.node_run_steps, + ) + time.sleep(retry_interval) + continue + route_node_state.set_finished(run_result=run_result) - yield item - else: - if isinstance(item, RunCompletedEvent): - run_result = item.run_result - route_node_state.set_finished(run_result=run_result) + if run_result.status == WorkflowNodeExecutionStatus.FAILED: + if node_instance.should_continue_on_error: + # if run failed, handle error + run_result = self._handle_continue_on_error( + node_instance, + item.run_result, + self.graph_runtime_state.variable_pool, + handle_exceptions=handle_exceptions, + ) + route_node_state.node_run_result = run_result + route_node_state.status = RouteNodeState.Status.EXCEPTION + if run_result.outputs: + for variable_key, variable_value in run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + node_id=node_instance.node_id, + variable_key_list=[variable_key], + variable_value=variable_value, + ) + yield NodeRunExceptionEvent( + error=run_result.error or "System Error", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + shoudl_continue_retry = False + else: + yield NodeRunFailedEvent( + error=route_node_state.failed_reason or "Unknown error.", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + shoudl_continue_retry = False + elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + if node_instance.should_continue_on_error and self.graph.edge_mapping.get( + node_instance.node_id + ): + run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS + if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + # plus state total_tokens + self.graph_runtime_state.total_tokens += int( + run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] + ) - if run_result.status == WorkflowNodeExecutionStatus.FAILED: - if node_instance.should_continue_on_error: - # if run failed, handle error - run_result = self._handle_continue_on_error( - node_instance, - item.run_result, - self.graph_runtime_state.variable_pool, - handle_exceptions=handle_exceptions, - ) - route_node_state.node_run_result = run_result - route_node_state.status = RouteNodeState.Status.EXCEPTION + if run_result.llm_usage: + # use the latest usage + self.graph_runtime_state.llm_usage += run_result.llm_usage + + # append node output variables to variable pool if run_result.outputs: for variable_key, variable_value in run_result.outputs.items(): # append variables to variable pool recursively @@ -645,21 +732,23 @@ class GraphEngine: variable_key_list=[variable_key], variable_value=variable_value, ) - yield NodeRunExceptionEvent( - error=run_result.error or "System Error", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - ) - else: - yield NodeRunFailedEvent( - error=route_node_state.failed_reason or "Unknown error.", + + # add parallel info to run result metadata + if parallel_id and parallel_start_node_id: + if not run_result.metadata: + run_result.metadata = {} + + run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id + run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = ( + parallel_start_node_id + ) + if parent_parallel_id and parent_parallel_start_node_id: + run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id + run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( + parent_parallel_start_node_id + ) + + yield NodeRunSucceededEvent( id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, @@ -670,108 +759,59 @@ class GraphEngine: parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, ) + shoudl_continue_retry = False - elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: - if node_instance.should_continue_on_error and self.graph.edge_mapping.get( - node_instance.node_id - ): - run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS - if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): - # plus state total_tokens - self.graph_runtime_state.total_tokens += int( - run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] - ) - - if run_result.llm_usage: - # use the latest usage - self.graph_runtime_state.llm_usage += run_result.llm_usage - - # append node output variables to variable pool - if run_result.outputs: - for variable_key, variable_value in run_result.outputs.items(): - # append variables to variable pool recursively - self._append_variables_recursively( - node_id=node_instance.node_id, - variable_key_list=[variable_key], - variable_value=variable_value, - ) - - # add parallel info to run result metadata - if parallel_id and parallel_start_node_id: - if not run_result.metadata: - run_result.metadata = {} - - run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id - run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id - if parent_parallel_id and parent_parallel_start_node_id: - run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id - run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( - parent_parallel_start_node_id - ) - - yield NodeRunSucceededEvent( + break + elif isinstance(item, RunStreamChunkEvent): + yield NodeRunStreamChunkEvent( id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, node_data=node_instance.node_data, + chunk_content=item.chunk_content, + from_variable_selector=item.from_variable_selector, route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, ) - - break - elif isinstance(item, RunStreamChunkEvent): - yield NodeRunStreamChunkEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - chunk_content=item.chunk_content, - from_variable_selector=item.from_variable_selector, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - ) - elif isinstance(item, RunRetrieverResourceEvent): - yield NodeRunRetrieverResourceEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - retriever_resources=item.retriever_resources, - context=item.context, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - ) - except GenerateTaskStoppedError: - # trigger node run failed event - route_node_state.status = RouteNodeState.Status.FAILED - route_node_state.failed_reason = "Workflow stopped." - yield NodeRunFailedEvent( - error="Workflow stopped.", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - ) - return - except Exception as e: - logger.exception(f"Node {node_instance.node_data.title} run failed") - raise e - finally: - db.session.close() + elif isinstance(item, RunRetrieverResourceEvent): + yield NodeRunRetrieverResourceEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + retriever_resources=item.retriever_resources, + context=item.context, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + except GenerateTaskStoppedError: + # trigger node run failed event + route_node_state.status = RouteNodeState.Status.FAILED + route_node_state.failed_reason = "Workflow stopped." + yield NodeRunFailedEvent( + error="Workflow stopped.", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + return + except Exception as e: + logger.exception(f"Node {node_instance.node_data.title} run failed") + raise e + finally: + db.session.close() def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): """ diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 9271867aff..529fd7be74 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -106,12 +106,25 @@ class DefaultValue(BaseModel): return self +class RetryConfig(BaseModel): + """node retry config""" + + max_retries: int = 0 # max retry times + retry_interval: int = 0 # retry interval in milliseconds + retry_enabled: bool = False # whether retry is enabled + + @property + def retry_interval_seconds(self) -> float: + return self.retry_interval / 1000 + + class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None error_strategy: Optional[ErrorStrategy] = None default_value: Optional[list[DefaultValue]] = None version: str = "1" + retry_config: RetryConfig = RetryConfig() @property def default_value_dict(self): diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index ee274c2501..b799e74266 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast from core.workflow.entities.node_entities import NodeRunResult -from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType +from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from models.workflow import WorkflowNodeExecutionStatus @@ -147,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]): bool: if should continue on error """ return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE + + @property + def should_retry(self) -> bool: + """judge if should retry + + Returns: + bool: if should retry + """ + return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 6d8ca6f701..32fdc048d1 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -35,3 +35,4 @@ class FailBranchSourceHandle(StrEnum): CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST] +RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST] diff --git a/api/core/workflow/nodes/event/__init__.py b/api/core/workflow/nodes/event/__init__.py index 5e3b31e48b..08c47d5e57 100644 --- a/api/core/workflow/nodes/event/__init__.py +++ b/api/core/workflow/nodes/event/__init__.py @@ -1,4 +1,10 @@ -from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from .event import ( + ModelInvokeCompletedEvent, + RunCompletedEvent, + RunRetrieverResourceEvent, + RunRetryEvent, + RunStreamChunkEvent, +) from .types import NodeEvent __all__ = [ @@ -6,5 +12,6 @@ __all__ = [ "NodeEvent", "RunCompletedEvent", "RunRetrieverResourceEvent", + "RunRetryEvent", "RunStreamChunkEvent", ] diff --git a/api/core/workflow/nodes/event/event.py b/api/core/workflow/nodes/event/event.py index b7034561bf..0dc35e7d77 100644 --- a/api/core/workflow/nodes/event/event.py +++ b/api/core/workflow/nodes/event/event.py @@ -1,7 +1,10 @@ +from datetime import datetime + from pydantic import BaseModel, Field from core.model_runtime.entities.llm_entities import LLMUsage from core.workflow.entities.node_entities import NodeRunResult +from models.workflow import WorkflowNodeExecutionStatus class RunCompletedEvent(BaseModel): @@ -26,3 +29,25 @@ class ModelInvokeCompletedEvent(BaseModel): text: str usage: LLMUsage finish_reason: str | None = None + + +class RunRetryEvent(BaseModel): + """Node Run Retry event""" + + error: str = Field(..., description="error") + retry_index: int = Field(..., description="Retry attempt number") + start_at: datetime = Field(..., description="Retry start time") + + +class SingleStepRetryEvent(BaseModel): + """Single step retry event""" + + status: str = WorkflowNodeExecutionStatus.RETRY.value + + inputs: dict | None = Field(..., description="input") + error: str = Field(..., description="error") + outputs: dict = Field(..., description="output") + retry_index: int = Field(..., description="Retry attempt number") + error: str = Field(..., description="error") + elapsed_time: float = Field(..., description="elapsed time") + execution_metadata: dict | None = Field(..., description="execution metadata") diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 90251c27a8..92f190091b 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -45,6 +45,7 @@ class Executor: headers: dict[str, str] auth: HttpRequestNodeAuthorization timeout: HttpRequestNodeTimeout + max_retries: int boundary: str @@ -54,6 +55,7 @@ class Executor: node_data: HttpRequestNodeData, timeout: HttpRequestNodeTimeout, variable_pool: VariablePool, + max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES, ): # If authorization API key is present, convert the API key using the variable pool if node_data.authorization.type == "api-key": @@ -73,6 +75,7 @@ class Executor: self.files = None self.data = None self.json = None + self.max_retries = max_retries # init template self.variable_pool = variable_pool @@ -241,6 +244,7 @@ class Executor: "params": self.params, "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), "follow_redirects": True, + "max_retries": self.max_retries, } # request_args = {k: v for k, v in request_args.items() if v is not None} try: diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index a0fc8acaef..171389a34c 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -52,6 +52,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, }, }, + "retry_config": { + "max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES, + "retry_interval": 0.5 * (2**2), + "retry_enabled": True, + }, } def _run(self) -> NodeRunResult: @@ -61,12 +66,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): node_data=self.node_data, timeout=self._get_request_timeout(self.node_data), variable_pool=self.graph_runtime_state.variable_pool, + max_retries=0, ) process_data["request"] = http_executor.to_log() response = http_executor.invoke() files = self.extract_files(url=http_executor.url, response=response) - if not response.response.is_success and self.should_continue_on_error: + if not response.response.is_success and (self.should_continue_on_error or self.should_retry): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, outputs={ diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 8390c66556..7c01ffc2c6 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -29,6 +29,7 @@ workflow_run_for_list_fields = { "created_at": TimestampField, "finished_at": TimestampField, "exceptions_count": fields.Integer, + "retry_index": fields.Integer, } advanced_chat_workflow_run_for_list_fields = { @@ -45,6 +46,7 @@ advanced_chat_workflow_run_for_list_fields = { "created_at": TimestampField, "finished_at": TimestampField, "exceptions_count": fields.Integer, + "retry_index": fields.Integer, } advanced_chat_workflow_run_pagination_fields = { @@ -79,6 +81,17 @@ workflow_run_detail_fields = { "exceptions_count": fields.Integer, } +retry_event_field = { + "error": fields.String, + "retry_index": fields.Integer, + "inputs": fields.Raw(attribute="inputs"), + "elapsed_time": fields.Float, + "execution_metadata": fields.Raw(attribute="execution_metadata_dict"), + "status": fields.String, + "outputs": fields.Raw(attribute="outputs"), +} + + workflow_run_node_execution_fields = { "id": fields.String, "index": fields.Integer, @@ -99,6 +112,7 @@ workflow_run_node_execution_fields = { "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), "finished_at": TimestampField, + "retry_events": fields.List(fields.Nested(retry_event_field)), } workflow_run_node_execution_list_fields = { diff --git a/api/migrations/versions/2024_12_16_0123-348cb0a93d53_add_retry_index_field_to_node_execution_.py b/api/migrations/versions/2024_12_16_0123-348cb0a93d53_add_retry_index_field_to_node_execution_.py new file mode 100644 index 0000000000..eef8224dea --- /dev/null +++ b/api/migrations/versions/2024_12_16_0123-348cb0a93d53_add_retry_index_field_to_node_execution_.py @@ -0,0 +1,33 @@ +"""add retry_index field to node-execution model + +Revision ID: 348cb0a93d53 +Revises: cf8f4fc45278 +Create Date: 2024-12-16 01:23:13.093432 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '348cb0a93d53' +down_revision = 'cf8f4fc45278' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op: + batch_op.drop_column('retry_index') + + # ### end Alembic commands ### diff --git a/api/models/workflow.py b/api/models/workflow.py index 4ddf2e9082..e933382a84 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -529,6 +529,7 @@ class WorkflowNodeExecutionStatus(Enum): SUCCEEDED = "succeeded" FAILED = "failed" EXCEPTION = "exception" + RETRY = "retry" @classmethod def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": @@ -639,6 +640,7 @@ class WorkflowNodeExecution(db.Model): created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) finished_at = db.Column(db.DateTime) + retry_index = db.Column(db.Integer, server_default=db.text("0")) @property def created_by_account(self): diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 84768d5af0..ead552d6c2 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -15,6 +15,7 @@ from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.enums import ErrorStrategy from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.event.event import SingleStepRetryEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated @@ -220,56 +221,99 @@ class WorkflowService: # run draft workflow node start_at = time.perf_counter() + retries = 0 + max_retries = 0 + should_retry = True + retry_events = [] try: - node_instance, generator = WorkflowEntry.single_step_run( - workflow=draft_workflow, - node_id=node_id, - user_inputs=user_inputs, - user_id=account.id, - ) - node_instance = cast(BaseNode[BaseNodeData], node_instance) - node_run_result: NodeRunResult | None = None - for event in generator: - if isinstance(event, RunCompletedEvent): - node_run_result = event.run_result + while retries <= max_retries and should_retry: + retry_start_at = time.perf_counter() + node_instance, generator = WorkflowEntry.single_step_run( + workflow=draft_workflow, + node_id=node_id, + user_inputs=user_inputs, + user_id=account.id, + ) + node_instance = cast(BaseNode[BaseNodeData], node_instance) + max_retries = ( + node_instance.node_data.retry_config.max_retries if node_instance.node_data.retry_config else 0 + ) + retry_interval = node_instance.node_data.retry_config.retry_interval_seconds + node_run_result: NodeRunResult | None = None + for event in generator: + if isinstance(event, RunCompletedEvent): + node_run_result = event.run_result - # sign output files - node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) - break + # sign output files + node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) + break - if not node_run_result: - raise ValueError("Node run failed with no run result") - # single step debug mode error handling return - if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: - node_error_args = { - "status": WorkflowNodeExecutionStatus.EXCEPTION, - "error": node_run_result.error, - "inputs": node_run_result.inputs, - "metadata": {"error_strategy": node_instance.node_data.error_strategy}, - } - if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: - node_run_result = NodeRunResult( - **node_error_args, - outputs={ - **node_instance.node_data.default_value_dict, - "error_message": node_run_result.error, - "error_type": node_run_result.error_type, - }, - ) - else: - node_run_result = NodeRunResult( - **node_error_args, - outputs={ - "error_message": node_run_result.error, - "error_type": node_run_result.error_type, - }, - ) - run_succeeded = node_run_result.status in ( - WorkflowNodeExecutionStatus.SUCCEEDED, - WorkflowNodeExecutionStatus.EXCEPTION, - ) - error = node_run_result.error if not run_succeeded else None + if not node_run_result: + raise ValueError("Node run failed with no run result") + # single step debug mode error handling return + if node_run_result.status == WorkflowNodeExecutionStatus.FAILED: + if ( + retries == max_retries + and node_instance.node_type == NodeType.HTTP_REQUEST + and node_run_result.outputs + and not node_instance.should_continue_on_error + ): + node_run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED + should_retry = False + else: + if node_instance.should_retry: + node_run_result.status = WorkflowNodeExecutionStatus.RETRY + retries += 1 + node_run_result.retry_index = retries + retry_events.append( + SingleStepRetryEvent( + inputs=WorkflowEntry.handle_special_values(node_run_result.inputs) + if node_run_result.inputs + else None, + error=node_run_result.error, + outputs=WorkflowEntry.handle_special_values(node_run_result.outputs) + if node_run_result.outputs + else None, + retry_index=node_run_result.retry_index, + elapsed_time=time.perf_counter() - retry_start_at, + execution_metadata=WorkflowEntry.handle_special_values(node_run_result.metadata) + if node_run_result.metadata + else None, + ) + ) + time.sleep(retry_interval) + else: + should_retry = False + if node_instance.should_continue_on_error: + node_error_args = { + "status": WorkflowNodeExecutionStatus.EXCEPTION, + "error": node_run_result.error, + "inputs": node_run_result.inputs, + "metadata": {"error_strategy": node_instance.node_data.error_strategy}, + } + if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + **node_instance.node_data.default_value_dict, + "error_message": node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) + else: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + "error_message": node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) + run_succeeded = node_run_result.status in ( + WorkflowNodeExecutionStatus.SUCCEEDED, + WorkflowNodeExecutionStatus.EXCEPTION, + ) + error = node_run_result.error if not run_succeeded else None except WorkflowNodeRunFailedError as e: node_instance = e.node_instance run_succeeded = False @@ -318,6 +362,7 @@ class WorkflowService: db.session.add(workflow_node_execution) db.session.commit() + workflow_node_execution.retry_events = retry_events return workflow_node_execution diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index ba209e4020..2d74be9da9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -2,7 +2,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( GraphRunPartialSucceededEvent, - GraphRunSucceededEvent, NodeRunExceptionEvent, NodeRunStreamChunkEvent, ) @@ -14,7 +13,9 @@ from models.workflow import WorkflowType class ContinueOnErrorTestHelper: @staticmethod - def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None): + def get_code_node( + code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {} + ): """Helper method to create a code node configuration""" node = { "id": "node", @@ -26,6 +27,7 @@ class ContinueOnErrorTestHelper: "code_language": "python3", "code": "\n".join([line[4:] for line in code.split("\n")]), "type": "code", + **retry_config, }, } if default_value: @@ -34,7 +36,10 @@ class ContinueOnErrorTestHelper: @staticmethod def get_http_node( - error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False + error_strategy: str = "fail-branch", + default_value: dict | None = None, + authorization_success: bool = False, + retry_config: dict = {}, ): """Helper method to create a http node configuration""" authorization = ( @@ -65,6 +70,7 @@ class ContinueOnErrorTestHelper: "body": None, "type": "http-request", "error_strategy": error_strategy, + **retry_config, }, } if default_value: diff --git a/api/tests/unit_tests/core/workflow/nodes/test_retry.py b/api/tests/unit_tests/core/workflow/nodes/test_retry.py new file mode 100644 index 0000000000..c232875ce5 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_retry.py @@ -0,0 +1,73 @@ +from core.workflow.graph_engine.entities.event import ( + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunSucceededEvent, + NodeRunRetryEvent, +) +from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper + +DEFAULT_VALUE_EDGE = [ + { + "id": "start-source-node-target", + "source": "start", + "target": "node", + "sourceHandle": "source", + }, + { + "id": "node-source-answer-target", + "source": "node", + "target": "answer", + "sourceHandle": "source", + }, +] + + +def test_retry_default_value_partial_success(): + """retry default value node with partial success status""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_http_node( + "default-value", + [{"key": "result", "type": "string", "value": "http node got error response"}], + retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}}, + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2 + assert events[-1].outputs == {"answer": "http node got error response"} + assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events) + assert len(events) == 11 + + +def test_retry_failed(): + """retry failed with success status""" + error_code = """ + def main() -> dict: + return { + "result": 1 / 0, + } + """ + + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_http_node( + None, + None, + retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}}, + ), + ], + } + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2 + assert any(isinstance(e, GraphRunFailedEvent) for e in events) + assert len(events) == 8