mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-20 12:49:14 +08:00
fix(graph_engine): fix execute loops in parallel
This commit is contained in:
parent
4418fa1d2b
commit
74c8004944
@ -357,21 +357,42 @@ class Graph(BaseModel):
|
||||
)
|
||||
|
||||
# collect all branches node ids
|
||||
for branch_node_id, node_ids in in_branch_node_ids.items():
|
||||
parallel_node_ids = []
|
||||
for _, node_ids in in_branch_node_ids.items():
|
||||
for node_id in node_ids:
|
||||
parallel_node_ids.append(node_id)
|
||||
node_parallel_mapping[node_id] = parallel.id
|
||||
|
||||
end_to_node_id: Optional[str] = None
|
||||
for node_id in node_parallel_mapping:
|
||||
node_edges = edge_mapping.get(node_id)
|
||||
if not end_to_node_id and node_edges and len(node_edges) == 1:
|
||||
target_node_id = node_edges[0].target_node_id
|
||||
if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
|
||||
end_to_node_id = target_node_id
|
||||
break
|
||||
outside_parallel_target_node_ids = set()
|
||||
for node_id in parallel_node_ids:
|
||||
if node_id == parallel.start_from_node_id:
|
||||
continue
|
||||
|
||||
if end_to_node_id:
|
||||
parallel.end_to_node_id = end_to_node_id
|
||||
node_edges = edge_mapping.get(node_id)
|
||||
if not node_edges:
|
||||
continue
|
||||
|
||||
if len(node_edges) > 1:
|
||||
continue
|
||||
|
||||
target_node_id = node_edges[0].target_node_id
|
||||
if target_node_id in parallel_node_ids:
|
||||
continue
|
||||
|
||||
if parent_parallel_id:
|
||||
parent_parallel = parallel_mapping.get(parent_parallel_id)
|
||||
if not parent_parallel:
|
||||
continue
|
||||
|
||||
if (
|
||||
(node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id)
|
||||
or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id)
|
||||
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
|
||||
):
|
||||
outside_parallel_target_node_ids.add(target_node_id)
|
||||
|
||||
if len(outside_parallel_target_node_ids) == 1:
|
||||
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
|
||||
|
||||
for graph_edge in target_node_edges:
|
||||
cls._recursively_add_parallels(
|
||||
|
@ -270,10 +270,10 @@ class GraphEngine:
|
||||
|
||||
next_node_id = edge.target_node_id
|
||||
else:
|
||||
final_node_id = None
|
||||
|
||||
if any(edge.run_condition for edge in edge_mappings):
|
||||
# if nodes has run conditions, get node id which branch to take based on the run condition results
|
||||
final_node_id = None
|
||||
|
||||
condition_edge_mappings = {}
|
||||
for edge in edge_mappings:
|
||||
if edge.run_condition:
|
||||
@ -331,13 +331,15 @@ class GraphEngine:
|
||||
|
||||
for item in parallel_generator:
|
||||
if isinstance(item, str):
|
||||
next_node_id = item
|
||||
final_node_id = item
|
||||
else:
|
||||
yield item
|
||||
|
||||
if not next_node_id:
|
||||
if not final_node_id:
|
||||
break
|
||||
|
||||
next_node_id = final_node_id
|
||||
|
||||
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
|
||||
break
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user