diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 16bf80f34a..48ae33f29a 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -508,11 +508,17 @@ class Graph(BaseModel): branch_node_id != branch_node_id2 and node_id in inner_route2 and len(reverse_edge_mapping.get(node_id, [])) > 1 + and cls._is_node_in_routes( + reverse_edge_mapping=reverse_edge_mapping, + start_node_id=node_id, + routes_node_ids=routes_node_ids + ) ): if node_id not in merge_branch_node_ids: merge_branch_node_ids[node_id] = [] - merge_branch_node_ids[node_id].append(branch_node_id2) + if branch_node_id2 not in merge_branch_node_ids[node_id]: + merge_branch_node_ids[node_id].append(branch_node_id2) # sorted merge_branch_node_ids by branch_node_ids length desc merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True)) @@ -596,6 +602,28 @@ class Graph(BaseModel): routes_node_ids=routes_node_ids ) + @classmethod + def _is_node_in_routes(cls, + reverse_edge_mapping: dict[str, list[GraphEdge]], + start_node_id: str, + routes_node_ids: dict[str, list[str]]) -> bool: + """ + Recursively check if the node is in the routes + """ + if start_node_id not in reverse_edge_mapping: + return False + + all_routes_node_ids = [] + for _, node_ids in routes_node_ids.items(): + for node_id in node_ids: + all_routes_node_ids.append(node_id) + + for graph_edge in reverse_edge_mapping[start_node_id]: + if graph_edge.source_node_id not in all_routes_node_ids: + return False + + return True + @classmethod def _is_node2_after_node1( cls,