From 790dd3b22f7740f25a890d5244b7d095aa686294 Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 28 Aug 2024 19:01:45 +0800 Subject: [PATCH] fix(workflow): duplicate nodes in parallel --- .../workflow/graph_engine/entities/graph.py | 19 ++++++++++++++----- .../workflow/graph_engine/graph_engine.py | 6 ++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 748b49bdb8..f6f8bf2717 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -335,8 +335,10 @@ class Graph(BaseModel): if parallel_node_ids: # all parallel_node_ids in node_parallel_mapping parent_parallel_id = None - if all(node_id in node_parallel_mapping for node_id in parallel_node_ids): - parent_parallel_id = node_parallel_mapping[parallel_node_ids[0]] + for node_id in parallel_node_ids: + if node_id in node_parallel_mapping: + parent_parallel_id = node_parallel_mapping[node_id] + break parent_parallel = None if parent_parallel_id: @@ -392,7 +394,10 @@ class Graph(BaseModel): outside_parallel_target_node_ids.add(target_node_id) if len(outside_parallel_target_node_ids) == 1: - parallel.end_to_node_id = outside_parallel_target_node_ids.pop() + if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id: + parallel.end_to_node_id = None + else: + parallel.end_to_node_id = outside_parallel_target_node_ids.pop() for graph_edge in target_node_edges: cls._recursively_add_parallels( @@ -436,7 +441,7 @@ class Graph(BaseModel): """ routes_node_ids: dict[str, list[str]] = {} for parallel_node_id in parallel_node_ids: - routes_node_ids[parallel_node_id] = [] + routes_node_ids[parallel_node_id] = [parallel_node_id] # fetch routes node ids cls._recursively_fetch_routes( @@ -479,12 +484,16 @@ class Graph(BaseModel): in_branch_node_ids: dict[str, list[str]] = {} for branch_node_id, node_ids in routes_node_ids.items(): - in_branch_node_ids[branch_node_id] = [branch_node_id] + in_branch_node_ids[branch_node_id] = [] if branch_node_id not in branches_merge_node_ids: # all node ids in current branch is in this thread + in_branch_node_ids[branch_node_id].append(branch_node_id) in_branch_node_ids[branch_node_id].extend(node_ids) else: merge_node_id = branches_merge_node_ids[branch_node_id] + if merge_node_id != branch_node_id: + in_branch_node_ids[branch_node_id].append(branch_node_id) + # fetch all node ids from branch_node_id and merge_node_id cls._recursively_add_parallel_node_ids( branch_node_ids=in_branch_node_ids[branch_node_id], diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 92db8cfcca..f0538e2671 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -366,6 +366,12 @@ class GraphEngine: # new thread for edge in edge_mappings: + if ( + edge.target_node_id not in self.graph.node_parallel_mapping + or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id + ): + continue + thread = threading.Thread(target=self._run_parallel_node, kwargs={ 'flask_app': current_app._get_current_object(), # type: ignore[attr-defined] 'q': q,