mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-06-04 11:14:10 +08:00
fix(fail-branch): prevent streaming output in exception branches (#17153)
This commit is contained in:
parent
44cdb3dcea
commit
c91045a9d0
@ -155,9 +155,28 @@ class AnswerStreamProcessor(StreamProcessor):
|
|||||||
for answer_node_id, route_position in self.route_position.items():
|
for answer_node_id, route_position in self.route_position.items():
|
||||||
if answer_node_id not in self.rest_node_ids:
|
if answer_node_id not in self.rest_node_ids:
|
||||||
continue
|
continue
|
||||||
# exclude current node id
|
# Remove current node id from answer dependencies to support stream output if it is a success branch
|
||||||
answer_dependencies = self.generate_routes.answer_dependencies
|
answer_dependencies = self.generate_routes.answer_dependencies
|
||||||
if event.node_id in answer_dependencies[answer_node_id]:
|
edge_mapping = self.graph.edge_mapping.get(event.node_id)
|
||||||
|
success_edge = (
|
||||||
|
next(
|
||||||
|
(
|
||||||
|
edge
|
||||||
|
for edge in edge_mapping
|
||||||
|
if edge.run_condition
|
||||||
|
and edge.run_condition.type == "branch_identify"
|
||||||
|
and edge.run_condition.branch_identify == "success-branch"
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if edge_mapping
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
event.node_id in answer_dependencies[answer_node_id]
|
||||||
|
and success_edge
|
||||||
|
and success_edge.target_node_id == answer_node_id
|
||||||
|
):
|
||||||
answer_dependencies[answer_node_id].remove(event.node_id)
|
answer_dependencies[answer_node_id].remove(event.node_id)
|
||||||
answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
|
answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
|
||||||
# all depends on answer node id not in rest node ids
|
# all depends on answer node id not in rest node ids
|
||||||
|
@ -1,14 +1,20 @@
|
|||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
GraphRunPartialSucceededEvent,
|
GraphRunPartialSucceededEvent,
|
||||||
NodeRunExceptionEvent,
|
NodeRunExceptionEvent,
|
||||||
|
NodeRunFailedEvent,
|
||||||
NodeRunStreamChunkEvent,
|
NodeRunStreamChunkEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||||
|
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
|
||||||
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||||
|
|
||||||
|
|
||||||
class ContinueOnErrorTestHelper:
|
class ContinueOnErrorTestHelper:
|
||||||
@ -492,10 +498,7 @@ def test_no_node_in_fail_branch_continue_on_error():
|
|||||||
"edges": FAIL_BRANCH_EDGES[:-1],
|
"edges": FAIL_BRANCH_EDGES[:-1],
|
||||||
"nodes": [
|
"nodes": [
|
||||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||||
{
|
{"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, "id": "success"},
|
||||||
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
|
|
||||||
"id": "success",
|
|
||||||
},
|
|
||||||
ContinueOnErrorTestHelper.get_http_node(),
|
ContinueOnErrorTestHelper.get_http_node(),
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@ -506,3 +509,47 @@ def test_no_node_in_fail_branch_continue_on_error():
|
|||||||
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||||
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
|
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
|
||||||
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0
|
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_output_with_fail_branch_continue_on_error():
|
||||||
|
"""Test stream output with fail-branch error strategy"""
|
||||||
|
graph_config = {
|
||||||
|
"edges": FAIL_BRANCH_EDGES,
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||||
|
{
|
||||||
|
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
|
||||||
|
"id": "success",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"},
|
||||||
|
"id": "error",
|
||||||
|
},
|
||||||
|
ContinueOnErrorTestHelper.get_llm_node(),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||||
|
|
||||||
|
def llm_generator(self):
|
||||||
|
contents = ["hi", "bye", "good morning"]
|
||||||
|
|
||||||
|
yield RunStreamChunkEvent(chunk_content=contents[0], from_variable_selector=[self.node_id, "text"])
|
||||||
|
|
||||||
|
yield RunCompletedEvent(
|
||||||
|
run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
inputs={},
|
||||||
|
process_data={},
|
||||||
|
outputs={},
|
||||||
|
metadata={
|
||||||
|
NodeRunMetadataKey.TOTAL_TOKENS: 1,
|
||||||
|
NodeRunMetadataKey.TOTAL_PRICE: 1,
|
||||||
|
NodeRunMetadataKey.CURRENCY: "USD",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(LLMNode, "_run", new=llm_generator):
|
||||||
|
events = list(graph_engine.run())
|
||||||
|
assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1
|
||||||
|
assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user