mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-20 19:29:13 +08:00
fix(graph_engine): fix execute loops in parallel
This commit is contained in:
parent
4418fa1d2b
commit
74c8004944
@ -357,21 +357,42 @@ class Graph(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# collect all branches node ids
|
# collect all branches node ids
|
||||||
for branch_node_id, node_ids in in_branch_node_ids.items():
|
parallel_node_ids = []
|
||||||
|
for _, node_ids in in_branch_node_ids.items():
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
|
parallel_node_ids.append(node_id)
|
||||||
node_parallel_mapping[node_id] = parallel.id
|
node_parallel_mapping[node_id] = parallel.id
|
||||||
|
|
||||||
end_to_node_id: Optional[str] = None
|
outside_parallel_target_node_ids = set()
|
||||||
for node_id in node_parallel_mapping:
|
for node_id in parallel_node_ids:
|
||||||
node_edges = edge_mapping.get(node_id)
|
if node_id == parallel.start_from_node_id:
|
||||||
if not end_to_node_id and node_edges and len(node_edges) == 1:
|
continue
|
||||||
target_node_id = node_edges[0].target_node_id
|
|
||||||
if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
|
|
||||||
end_to_node_id = target_node_id
|
|
||||||
break
|
|
||||||
|
|
||||||
if end_to_node_id:
|
node_edges = edge_mapping.get(node_id)
|
||||||
parallel.end_to_node_id = end_to_node_id
|
if not node_edges:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(node_edges) > 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
target_node_id = node_edges[0].target_node_id
|
||||||
|
if target_node_id in parallel_node_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if parent_parallel_id:
|
||||||
|
parent_parallel = parallel_mapping.get(parent_parallel_id)
|
||||||
|
if not parent_parallel:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if (
|
||||||
|
(node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id)
|
||||||
|
or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id)
|
||||||
|
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
|
||||||
|
):
|
||||||
|
outside_parallel_target_node_ids.add(target_node_id)
|
||||||
|
|
||||||
|
if len(outside_parallel_target_node_ids) == 1:
|
||||||
|
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
|
||||||
|
|
||||||
for graph_edge in target_node_edges:
|
for graph_edge in target_node_edges:
|
||||||
cls._recursively_add_parallels(
|
cls._recursively_add_parallels(
|
||||||
|
@ -270,10 +270,10 @@ class GraphEngine:
|
|||||||
|
|
||||||
next_node_id = edge.target_node_id
|
next_node_id = edge.target_node_id
|
||||||
else:
|
else:
|
||||||
|
final_node_id = None
|
||||||
|
|
||||||
if any(edge.run_condition for edge in edge_mappings):
|
if any(edge.run_condition for edge in edge_mappings):
|
||||||
# if nodes has run conditions, get node id which branch to take based on the run condition results
|
# if nodes has run conditions, get node id which branch to take based on the run condition results
|
||||||
final_node_id = None
|
|
||||||
|
|
||||||
condition_edge_mappings = {}
|
condition_edge_mappings = {}
|
||||||
for edge in edge_mappings:
|
for edge in edge_mappings:
|
||||||
if edge.run_condition:
|
if edge.run_condition:
|
||||||
@ -331,13 +331,15 @@ class GraphEngine:
|
|||||||
|
|
||||||
for item in parallel_generator:
|
for item in parallel_generator:
|
||||||
if isinstance(item, str):
|
if isinstance(item, str):
|
||||||
next_node_id = item
|
final_node_id = item
|
||||||
else:
|
else:
|
||||||
yield item
|
yield item
|
||||||
|
|
||||||
if not next_node_id:
|
if not final_node_id:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
next_node_id = final_node_id
|
||||||
|
|
||||||
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
|
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user