fix(fail-branch): prevent streaming output in exception branches

Novice 2025-03-31 10:01:04 +08:00
parent 46d235bca0
commit e42ab9f002
2 changed files with 79 additions and 5 deletions
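
This change makes AnswerStreamProcessor drop a finished node from an answer node's dependency list only when that node reaches the answer through a success branch. Previously the dependency was removed unconditionally, so an answer node sitting behind an exception (fail) branch could start streaming output it should never receive. The updated tests pin the behavior down: a fail-branch run must emit zero stream chunks, while a success-branch run still streams.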


@@ -155,9 +155,28 @@ class AnswerStreamProcessor(StreamProcessor):
         for answer_node_id, route_position in self.route_position.items():
             if answer_node_id not in self.rest_node_ids:
                 continue
-            # exclude current node id
+            # Remove current node id from answer dependencies to support stream output if it is a success branch
             answer_dependencies = self.generate_routes.answer_dependencies
-            if event.node_id in answer_dependencies[answer_node_id]:
+            edge_mapping = self.graph.edge_mapping.get(event.node_id)
+            success_edge = (
+                next(
+                    (
+                        edge
+                        for edge in edge_mapping
+                        if edge.run_condition
+                        and edge.run_condition.type == "branch_identify"
+                        and edge.run_condition.branch_identify == "success-branch"
+                    ),
+                    None,
+                )
+                if edge_mapping
+                else None
+            )
+            if (
+                event.node_id in answer_dependencies[answer_node_id]
+                and success_edge
+                and success_edge.target_node_id == answer_node_id
+            ):
                 answer_dependencies[answer_node_id].remove(event.node_id)
             answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
             # all depends on answer node id not in rest node ids
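
The heart of the fix is the success-edge lookup above. Below is a minimal, runnable sketch of the same selection logic, using dataclass stand-ins rather than the real graph entities (RunCondition, Edge, and find_success_edge are illustrative names here, not the production API):

from dataclasses import dataclass
from typing import Optional


@dataclass
class RunCondition:
    type: str                       # e.g. "branch_identify"
    branch_identify: Optional[str]  # e.g. "success-branch" or "fail-branch"


@dataclass
class Edge:
    target_node_id: str
    run_condition: Optional[RunCondition] = None


def find_success_edge(edge_mapping: Optional[list[Edge]]) -> Optional[Edge]:
    """Return the first outgoing edge guarded by a success-branch condition, if any."""
    if not edge_mapping:
        return None
    return next(
        (
            edge
            for edge in edge_mapping
            if edge.run_condition
            and edge.run_condition.type == "branch_identify"
            and edge.run_condition.branch_identify == "success-branch"
        ),
        None,
    )


edges = [
    Edge("error", RunCondition("branch_identify", "fail-branch")),
    Edge("answer", RunCondition("branch_identify", "success-branch")),
]
success_edge = find_success_edge(edges)
assert success_edge is not None and success_edge.target_node_id == "answer"

As in the diff, a dependency on the finished node is removed only when such an edge exists and its target_node_id is the answer node; fail-branch routes keep the dependency, so streaming toward that answer never unblocks.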


@@ -1,14 +1,20 @@
+from unittest.mock import patch
+
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import (
     GraphRunPartialSucceededEvent,
     NodeRunExceptionEvent,
+    NodeRunFailedEvent,
     NodeRunStreamChunkEvent,
 )
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.graph_engine.graph_engine import GraphEngine
+from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
+from core.workflow.nodes.llm.node import LLMNode
 from models.enums import UserFrom
-from models.workflow import WorkflowType
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
 
 
 class ContinueOnErrorTestHelper:
@@ -489,11 +495,16 @@ def test_variable_pool_error_type_variable():
 def test_no_node_in_fail_branch_continue_on_error():
     """Test HTTP node with fail-branch error strategy"""
     graph_config = {
-        "edges": FAIL_BRANCH_EDGES[:-1],
+        "edges": FAIL_BRANCH_EDGES
+        + [{"id": "fail-source-answer-target", "source": "node", "target": "answer", "sourceHandle": "source"}],
         "nodes": [
             {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
             {
-                "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
+                "data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
+                "id": "success",
+            },
+            {
+                "data": {"title": "success", "type": "answer", "answer": "{{#node.query#}}"},
                 "id": "success",
             },
             ContinueOnErrorTestHelper.get_http_node(),
@@ -506,3 +517,47 @@ def test_no_node_in_fail_branch_continue_on_error():
     assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
     assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
     assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0
+
+
+def test_stream_output_with_fail_branch_continue_on_error():
+    """Test stream output with fail-branch error strategy"""
+    graph_config = {
+        "edges": FAIL_BRANCH_EDGES,
+        "nodes": [
+            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
+            {
+                "data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
+                "id": "success",
+            },
+            {
+                "data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"},
+                "id": "error",
+            },
+            ContinueOnErrorTestHelper.get_llm_node(),
+        ],
+    }
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+
+    def llm_generator(self):
+        contents = ["hi", "bye", "good morning"]
+
+        yield RunStreamChunkEvent(chunk_content=contents[0], from_variable_selector=[self.node_id, "text"])
+
+        yield RunCompletedEvent(
+            run_result=NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                inputs={},
+                process_data={},
+                outputs={},
+                metadata={
+                    NodeRunMetadataKey.TOTAL_TOKENS: 1,
+                    NodeRunMetadataKey.TOTAL_PRICE: 1,
+                    NodeRunMetadataKey.CURRENCY: "USD",
+                },
+            )
+        )
+
+    with patch.object(LLMNode, "_run", new=llm_generator):
+        events = list(graph_engine.run())
+        assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1
+        assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events)
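
The new test swaps LLMNode._run for a plain generator with unittest.mock.patch.object, so no real model is invoked and exactly one chunk is produced. The same patching pattern in isolation (Producer and fake_run are hypothetical names used only for illustration):

from unittest.mock import patch


class Producer:
    def run(self):
        yield "real chunk"


def fake_run(self):
    # Stands in for the real method while the patch is active; emits one chunk.
    yield "patched chunk"


with patch.object(Producer, "run", new=fake_run):
    chunks = list(Producer().run())

assert chunks == ["patched chunk"]

Because the attribute is replaced at the class level, the engine-created node binds the fake generator as its own method. The assertions then count exactly one NodeRunStreamChunkEvent and require that no NodeRunFailedEvent or NodeRunExceptionEvent appears, confirming that the success branch streams while nothing fails.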