fix(workflow): duplicate nodes in parallel

This commit is contained in:
takatost 2024-08-28 19:01:45 +08:00
parent 5d34e080eb
commit 790dd3b22f
2 changed files with 20 additions and 5 deletions

View File

@ -335,8 +335,10 @@ class Graph(BaseModel):
if parallel_node_ids:
# all parallel_node_ids in node_parallel_mapping
parent_parallel_id = None
if all(node_id in node_parallel_mapping for node_id in parallel_node_ids):
parent_parallel_id = node_parallel_mapping[parallel_node_ids[0]]
for node_id in parallel_node_ids:
if node_id in node_parallel_mapping:
parent_parallel_id = node_parallel_mapping[node_id]
break
parent_parallel = None
if parent_parallel_id:
@ -392,7 +394,10 @@ class Graph(BaseModel):
outside_parallel_target_node_ids.add(target_node_id)
if len(outside_parallel_target_node_ids) == 1:
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id:
parallel.end_to_node_id = None
else:
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
for graph_edge in target_node_edges:
cls._recursively_add_parallels(
@ -436,7 +441,7 @@ class Graph(BaseModel):
"""
routes_node_ids: dict[str, list[str]] = {}
for parallel_node_id in parallel_node_ids:
routes_node_ids[parallel_node_id] = []
routes_node_ids[parallel_node_id] = [parallel_node_id]
# fetch routes node ids
cls._recursively_fetch_routes(
@ -479,12 +484,16 @@ class Graph(BaseModel):
in_branch_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
in_branch_node_ids[branch_node_id] = [branch_node_id]
in_branch_node_ids[branch_node_id] = []
if branch_node_id not in branches_merge_node_ids:
# all node ids in current branch is in this thread
in_branch_node_ids[branch_node_id].append(branch_node_id)
in_branch_node_ids[branch_node_id].extend(node_ids)
else:
merge_node_id = branches_merge_node_ids[branch_node_id]
if merge_node_id != branch_node_id:
in_branch_node_ids[branch_node_id].append(branch_node_id)
# fetch all node ids from branch_node_id and merge_node_id
cls._recursively_add_parallel_node_ids(
branch_node_ids=in_branch_node_ids[branch_node_id],

View File

@ -366,6 +366,12 @@ class GraphEngine:
# new thread
for edge in edge_mappings:
if (
edge.target_node_id not in self.graph.node_parallel_mapping
or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id
):
continue
thread = threading.Thread(target=self._run_parallel_node, kwargs={
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
'q': q,