diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 6dcd0a52f8..16bf80f34a 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -174,6 +174,7 @@ class Graph(BaseModel): node_parallel_mapping: dict[str, str] = {} cls._recursively_add_parallels( edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, start_node_id=root_node_id, parallel_mapping=parallel_mapping, node_parallel_mapping=node_parallel_mapping @@ -310,6 +311,7 @@ class Graph(BaseModel): @classmethod def _recursively_add_parallels(cls, edge_mapping: dict[str, list[GraphEdge]], + reverse_edge_mapping: dict[str, list[GraphEdge]], start_node_id: str, parallel_mapping: dict[str, GraphParallel], node_parallel_mapping: dict[str, str]) -> None: @@ -365,6 +367,7 @@ class Graph(BaseModel): in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, parallel_branch_node_ids=parallel_branch_node_ids ) @@ -412,6 +415,7 @@ class Graph(BaseModel): for graph_edge in target_node_edges: cls._recursively_add_parallels( edge_mapping=edge_mapping, + reverse_edge_mapping=reverse_edge_mapping, start_node_id=graph_edge.target_node_id, parallel_mapping=parallel_mapping, node_parallel_mapping=node_parallel_mapping @@ -472,6 +476,7 @@ class Graph(BaseModel): @classmethod def _fetch_all_node_ids_in_parallels(cls, edge_mapping: dict[str, list[GraphEdge]], + reverse_edge_mapping: dict[str, list[GraphEdge]], parallel_branch_node_ids: list[str]) -> dict[str, list[str]]: """ Fetch all node ids in parallels @@ -499,7 +504,11 @@ class Graph(BaseModel): leaf_node_ids[branch_node_id].append(node_id) for branch_node_id2, inner_route2 in routes_node_ids.items(): - if branch_node_id != branch_node_id2 and node_id in inner_route2: + if ( + branch_node_id != branch_node_id2 + and node_id in inner_route2 + and len(reverse_edge_mapping.get(node_id, [])) > 1 + ): if node_id not in merge_branch_node_ids: merge_branch_node_ids[node_id] = []