diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index d8ad1dbd49..ba6ba16e36 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -155,9 +155,28 @@ class AnswerStreamProcessor(StreamProcessor): for answer_node_id, route_position in self.route_position.items(): if answer_node_id not in self.rest_node_ids: continue - # exclude current node id + # Remove current node id from answer dependencies to support stream output if it is a success branch answer_dependencies = self.generate_routes.answer_dependencies - if event.node_id in answer_dependencies[answer_node_id]: + edge_mapping = self.graph.edge_mapping.get(event.node_id) + success_edge = ( + next( + ( + edge + for edge in edge_mapping + if edge.run_condition + and edge.run_condition.type == "branch_identify" + and edge.run_condition.branch_identify == "success-branch" + ), + None, + ) + if edge_mapping + else None + ) + if ( + event.node_id in answer_dependencies[answer_node_id] + and success_edge + and success_edge.target_node_id == answer_node_id + ): answer_dependencies[answer_node_id].remove(event.node_id) answer_dependencies_ids = answer_dependencies.get(answer_node_id, []) # all depends on answer node id not in rest node ids diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index ed35d8a32a..f28853e8fd 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -1,14 +1,20 @@ +from unittest.mock import patch + from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( GraphRunPartialSucceededEvent, NodeRunExceptionEvent, + NodeRunFailedEvent, NodeRunStreamChunkEvent, ) from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.graph_engine import GraphEngine +from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent +from core.workflow.nodes.llm.node import LLMNode from models.enums import UserFrom -from models.workflow import WorkflowType +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType class ContinueOnErrorTestHelper: @@ -489,11 +495,16 @@ def test_variable_pool_error_type_variable(): def test_no_node_in_fail_branch_continue_on_error(): """Test HTTP node with fail-branch error strategy""" graph_config = { - "edges": FAIL_BRANCH_EDGES[:-1], + "edges": FAIL_BRANCH_EDGES + + [{"id": "fail-source-answer-target", "source": "node", "target": "answer", "sourceHandle": "source"}], "nodes": [ {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, { - "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, + "data": {"title": "success", "type": "answer", "answer": "LLM request successful"}, + "id": "success", + }, + { + "data": {"title": "success", "type": "answer", "answer": "{{#node.query#}}"}, "id": "success", }, ContinueOnErrorTestHelper.get_http_node(), @@ -506,3 +517,47 @@ def test_no_node_in_fail_branch_continue_on_error(): assert any(isinstance(e, NodeRunExceptionEvent) for e in events) assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events) assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0 + + +def test_stream_output_with_fail_branch_continue_on_error(): + """Test stream output with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "LLM request successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_llm_node(), + ], + } + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + + def llm_generator(self): + contents = ["hi", "bye", "good morning"] + + yield RunStreamChunkEvent(chunk_content=contents[0], from_variable_selector=[self.node_id, "text"]) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={}, + process_data={}, + outputs={}, + metadata={ + NodeRunMetadataKey.TOTAL_TOKENS: 1, + NodeRunMetadataKey.TOTAL_PRICE: 1, + NodeRunMetadataKey.CURRENCY: "USD", + }, + ) + ) + + with patch.object(LLMNode, "_run", new=llm_generator): + events = list(graph_engine.run()) + assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1 + assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events)