mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-13 22:35:54 +08:00
fix(workflow): fix answer node stream processing in conditional branches (#12510)
This commit is contained in:
parent
831459b895
commit
54b5b80a07
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
|
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
|
||||||
@ -48,25 +49,35 @@ class StreamProcessor(ABC):
|
|||||||
# we remove the node maybe shortcut the answer node, so comment this code for now
|
# we remove the node maybe shortcut the answer node, so comment this code for now
|
||||||
# there is not effect on the answer node and the workflow, when we have a better solution
|
# there is not effect on the answer node and the workflow, when we have a better solution
|
||||||
# we can open this code. Issues: #11542 #9560 #10638 #10564
|
# we can open this code. Issues: #11542 #9560 #10638 #10564
|
||||||
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
|
# ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
|
||||||
if "answer" in ids:
|
# if "answer" in ids:
|
||||||
continue
|
# continue
|
||||||
else:
|
# else:
|
||||||
reachable_node_ids.extend(ids)
|
# reachable_node_ids.extend(ids)
|
||||||
|
|
||||||
|
# The branch_identify parameter is added to ensure that
|
||||||
|
# only nodes in the correct logical branch are included.
|
||||||
|
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
|
||||||
|
reachable_node_ids.extend(ids)
|
||||||
else:
|
else:
|
||||||
unreachable_first_node_ids.append(edge.target_node_id)
|
unreachable_first_node_ids.append(edge.target_node_id)
|
||||||
|
|
||||||
for node_id in unreachable_first_node_ids:
|
for node_id in unreachable_first_node_ids:
|
||||||
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
|
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
|
||||||
|
|
||||||
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
|
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
|
||||||
node_ids = []
|
node_ids = []
|
||||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||||
if edge.target_node_id == self.graph.root_node_id:
|
if edge.target_node_id == self.graph.root_node_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Only follow edges that match the branch_identify or have no run_condition
|
||||||
|
if edge.run_condition and edge.run_condition.branch_identify:
|
||||||
|
if not branch_identify or edge.run_condition.branch_identify != branch_identify:
|
||||||
|
continue
|
||||||
|
|
||||||
node_ids.append(edge.target_node_id)
|
node_ids.append(edge.target_node_id)
|
||||||
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
|
||||||
return node_ids
|
return node_ids
|
||||||
|
|
||||||
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
|
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user