fix(graph_engine): fix execute loops in parallel

This commit is contained in:
takatost 2024-08-28 17:42:32 +08:00
parent 4418fa1d2b
commit 74c8004944
2 changed files with 38 additions and 15 deletions

View File

@ -357,21 +357,42 @@ class Graph(BaseModel):
)
# collect all branches node ids
for branch_node_id, node_ids in in_branch_node_ids.items():
parallel_node_ids = []
for _, node_ids in in_branch_node_ids.items():
for node_id in node_ids:
parallel_node_ids.append(node_id)
node_parallel_mapping[node_id] = parallel.id
end_to_node_id: Optional[str] = None
for node_id in node_parallel_mapping:
node_edges = edge_mapping.get(node_id)
if not end_to_node_id and node_edges and len(node_edges) == 1:
target_node_id = node_edges[0].target_node_id
if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
end_to_node_id = target_node_id
break
outside_parallel_target_node_ids = set()
for node_id in parallel_node_ids:
if node_id == parallel.start_from_node_id:
continue
if end_to_node_id:
parallel.end_to_node_id = end_to_node_id
node_edges = edge_mapping.get(node_id)
if not node_edges:
continue
if len(node_edges) > 1:
continue
target_node_id = node_edges[0].target_node_id
if target_node_id in parallel_node_ids:
continue
if parent_parallel_id:
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
continue
if (
(node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id)
or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id)
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
):
outside_parallel_target_node_ids.add(target_node_id)
if len(outside_parallel_target_node_ids) == 1:
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
for graph_edge in target_node_edges:
cls._recursively_add_parallels(

View File

@ -270,10 +270,10 @@ class GraphEngine:
next_node_id = edge.target_node_id
else:
final_node_id = None
if any(edge.run_condition for edge in edge_mappings):
# if nodes has run conditions, get node id which branch to take based on the run condition results
final_node_id = None
condition_edge_mappings = {}
for edge in edge_mappings:
if edge.run_condition:
@ -331,13 +331,15 @@ class GraphEngine:
for item in parallel_generator:
if isinstance(item, str):
next_node_id = item
final_node_id = item
else:
yield item
if not next_node_id:
if not final_node_id:
break
next_node_id = final_node_id
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
break