fix(fail-branch): prevent streaming output in exception branches (#17153)

This commit is contained in:
Novice 2025-04-16 22:34:07 +08:00 committed by GitHub
parent 44cdb3dcea
commit c91045a9d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 73 additions and 7 deletions

View File

@ -155,9 +155,28 @@ class AnswerStreamProcessor(StreamProcessor):
for answer_node_id, route_position in self.route_position.items(): for answer_node_id, route_position in self.route_position.items():
if answer_node_id not in self.rest_node_ids: if answer_node_id not in self.rest_node_ids:
continue continue
# exclude current node id # Remove current node id from answer dependencies to support stream output if it is a success branch
answer_dependencies = self.generate_routes.answer_dependencies answer_dependencies = self.generate_routes.answer_dependencies
if event.node_id in answer_dependencies[answer_node_id]: edge_mapping = self.graph.edge_mapping.get(event.node_id)
success_edge = (
next(
(
edge
for edge in edge_mapping
if edge.run_condition
and edge.run_condition.type == "branch_identify"
and edge.run_condition.branch_identify == "success-branch"
),
None,
)
if edge_mapping
else None
)
if (
event.node_id in answer_dependencies[answer_node_id]
and success_edge
and success_edge.target_node_id == answer_node_id
):
answer_dependencies[answer_node_id].remove(event.node_id) answer_dependencies[answer_node_id].remove(event.node_id)
answer_dependencies_ids = answer_dependencies.get(answer_node_id, []) answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
# all depends on answer node id not in rest node ids # all depends on answer node id not in rest node ids

View File

@ -1,14 +1,20 @@
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent, GraphRunPartialSucceededEvent,
NodeRunExceptionEvent, NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunStreamChunkEvent, NodeRunStreamChunkEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import WorkflowType from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
class ContinueOnErrorTestHelper: class ContinueOnErrorTestHelper:
@ -492,10 +498,7 @@ def test_no_node_in_fail_branch_continue_on_error():
"edges": FAIL_BRANCH_EDGES[:-1], "edges": FAIL_BRANCH_EDGES[:-1],
"nodes": [ "nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{ {"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, "id": "success"},
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
"id": "success",
},
ContinueOnErrorTestHelper.get_http_node(), ContinueOnErrorTestHelper.get_http_node(),
], ],
} }
@ -506,3 +509,47 @@ def test_no_node_in_fail_branch_continue_on_error():
assert any(isinstance(e, NodeRunExceptionEvent) for e in events) assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events) assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0 assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0
def test_stream_output_with_fail_branch_continue_on_error():
"""Test stream output with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"},
"id": "error",
},
ContinueOnErrorTestHelper.get_llm_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
def llm_generator(self):
contents = ["hi", "bye", "good morning"]
yield RunStreamChunkEvent(chunk_content=contents[0], from_variable_selector=[self.node_id, "text"])
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={},
process_data={},
outputs={},
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: 1,
NodeRunMetadataKey.TOTAL_PRICE: 1,
NodeRunMetadataKey.CURRENCY: "USD",
},
)
)
with patch.object(LLMNode, "_run", new=llm_generator):
events = list(graph_engine.run())
assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1
assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events)