diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 32f2859659..748b49bdb8 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -357,21 +357,42 @@ class Graph(BaseModel): ) # collect all branches node ids - for branch_node_id, node_ids in in_branch_node_ids.items(): + parallel_node_ids = [] + for _, node_ids in in_branch_node_ids.items(): for node_id in node_ids: + parallel_node_ids.append(node_id) node_parallel_mapping[node_id] = parallel.id - end_to_node_id: Optional[str] = None - for node_id in node_parallel_mapping: - node_edges = edge_mapping.get(node_id) - if not end_to_node_id and node_edges and len(node_edges) == 1: - target_node_id = node_edges[0].target_node_id - if node_parallel_mapping.get(target_node_id) == parent_parallel_id: - end_to_node_id = target_node_id - break + outside_parallel_target_node_ids = set() + for node_id in parallel_node_ids: + if node_id == parallel.start_from_node_id: + continue - if end_to_node_id: - parallel.end_to_node_id = end_to_node_id + node_edges = edge_mapping.get(node_id) + if not node_edges: + continue + + if len(node_edges) > 1: + continue + + target_node_id = node_edges[0].target_node_id + if target_node_id in parallel_node_ids: + continue + + if parent_parallel_id: + parent_parallel = parallel_mapping.get(parent_parallel_id) + if not parent_parallel: + continue + + if ( + (node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id) + or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id) + or (not node_parallel_mapping.get(target_node_id) and not parent_parallel) + ): + outside_parallel_target_node_ids.add(target_node_id) + + if len(outside_parallel_target_node_ids) == 1: + parallel.end_to_node_id = outside_parallel_target_node_ids.pop() for graph_edge in target_node_edges: cls._recursively_add_parallels( diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 8e5bb06c40..92db8cfcca 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -270,10 +270,10 @@ class GraphEngine: next_node_id = edge.target_node_id else: + final_node_id = None + if any(edge.run_condition for edge in edge_mappings): # if nodes has run conditions, get node id which branch to take based on the run condition results - final_node_id = None - condition_edge_mappings = {} for edge in edge_mappings: if edge.run_condition: @@ -331,13 +331,15 @@ class GraphEngine: for item in parallel_generator: if isinstance(item, str): - next_node_id = item + final_node_id = item else: yield item - if not next_node_id: + if not final_node_id: break + next_node_id = final_node_id + if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id: break