fix(graph_engine): fix execute loops in parallel

This commit is contained in:
takatost 2024-08-28 17:42:32 +08:00
parent 4418fa1d2b
commit 74c8004944
2 changed files with 38 additions and 15 deletions

View File

@ -357,21 +357,42 @@ class Graph(BaseModel):
) )
# collect all branches node ids # collect all branches node ids
for branch_node_id, node_ids in in_branch_node_ids.items(): parallel_node_ids = []
for _, node_ids in in_branch_node_ids.items():
for node_id in node_ids: for node_id in node_ids:
parallel_node_ids.append(node_id)
node_parallel_mapping[node_id] = parallel.id node_parallel_mapping[node_id] = parallel.id
end_to_node_id: Optional[str] = None outside_parallel_target_node_ids = set()
for node_id in node_parallel_mapping: for node_id in parallel_node_ids:
node_edges = edge_mapping.get(node_id) if node_id == parallel.start_from_node_id:
if not end_to_node_id and node_edges and len(node_edges) == 1: continue
target_node_id = node_edges[0].target_node_id
if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
end_to_node_id = target_node_id
break
if end_to_node_id: node_edges = edge_mapping.get(node_id)
parallel.end_to_node_id = end_to_node_id if not node_edges:
continue
if len(node_edges) > 1:
continue
target_node_id = node_edges[0].target_node_id
if target_node_id in parallel_node_ids:
continue
if parent_parallel_id:
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
continue
if (
(node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id)
or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id)
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
):
outside_parallel_target_node_ids.add(target_node_id)
if len(outside_parallel_target_node_ids) == 1:
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
for graph_edge in target_node_edges: for graph_edge in target_node_edges:
cls._recursively_add_parallels( cls._recursively_add_parallels(

View File

@ -270,10 +270,10 @@ class GraphEngine:
next_node_id = edge.target_node_id next_node_id = edge.target_node_id
else: else:
final_node_id = None
if any(edge.run_condition for edge in edge_mappings): if any(edge.run_condition for edge in edge_mappings):
# if nodes has run conditions, get node id which branch to take based on the run condition results # if nodes has run conditions, get node id which branch to take based on the run condition results
final_node_id = None
condition_edge_mappings = {} condition_edge_mappings = {}
for edge in edge_mappings: for edge in edge_mappings:
if edge.run_condition: if edge.run_condition:
@ -331,13 +331,15 @@ class GraphEngine:
for item in parallel_generator: for item in parallel_generator:
if isinstance(item, str): if isinstance(item, str):
next_node_id = item final_node_id = item
else: else:
yield item yield item
if not next_node_id: if not final_node_id:
break break
next_node_id = final_node_id
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id: if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
break break