mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 04:15:56 +08:00
fix(workflow): in multi-parallel execution with multiple conditional branches (#8221)
This commit is contained in:
parent
ffd4bf8bf0
commit
f6dfe23cf8
@ -304,11 +304,14 @@ class Graph(BaseModel):
|
|||||||
parallel = None
|
parallel = None
|
||||||
if len(target_node_edges) > 1:
|
if len(target_node_edges) > 1:
|
||||||
# fetch all node ids in current parallels
|
# fetch all node ids in current parallels
|
||||||
parallel_branch_node_ids = []
|
parallel_branch_node_ids = {}
|
||||||
condition_edge_mappings = {}
|
condition_edge_mappings = {}
|
||||||
for graph_edge in target_node_edges:
|
for graph_edge in target_node_edges:
|
||||||
if graph_edge.run_condition is None:
|
if graph_edge.run_condition is None:
|
||||||
parallel_branch_node_ids.append(graph_edge.target_node_id)
|
if "default" not in parallel_branch_node_ids:
|
||||||
|
parallel_branch_node_ids["default"] = []
|
||||||
|
|
||||||
|
parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
|
||||||
else:
|
else:
|
||||||
condition_hash = graph_edge.run_condition.hash
|
condition_hash = graph_edge.run_condition.hash
|
||||||
if not condition_hash in condition_edge_mappings:
|
if not condition_hash in condition_edge_mappings:
|
||||||
@ -316,13 +319,19 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
condition_edge_mappings[condition_hash].append(graph_edge)
|
condition_edge_mappings[condition_hash].append(graph_edge)
|
||||||
|
|
||||||
for _, graph_edges in condition_edge_mappings.items():
|
for condition_hash, graph_edges in condition_edge_mappings.items():
|
||||||
if len(graph_edges) > 1:
|
if len(graph_edges) > 1:
|
||||||
for graph_edge in graph_edges:
|
if condition_hash not in parallel_branch_node_ids:
|
||||||
parallel_branch_node_ids.append(graph_edge.target_node_id)
|
parallel_branch_node_ids[condition_hash] = []
|
||||||
|
|
||||||
|
for graph_edge in graph_edges:
|
||||||
|
parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)
|
||||||
|
|
||||||
|
condition_parallels = {}
|
||||||
|
for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items():
|
||||||
# any target node id in node_parallel_mapping
|
# any target node id in node_parallel_mapping
|
||||||
if parallel_branch_node_ids:
|
parallel = None
|
||||||
|
if condition_parallel_branch_node_ids:
|
||||||
parent_parallel_id = parent_parallel.id if parent_parallel else None
|
parent_parallel_id = parent_parallel.id if parent_parallel else None
|
||||||
|
|
||||||
parallel = GraphParallel(
|
parallel = GraphParallel(
|
||||||
@ -331,11 +340,12 @@ class Graph(BaseModel):
|
|||||||
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
|
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
|
||||||
)
|
)
|
||||||
parallel_mapping[parallel.id] = parallel
|
parallel_mapping[parallel.id] = parallel
|
||||||
|
condition_parallels[condition_hash] = parallel
|
||||||
|
|
||||||
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
|
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
|
||||||
edge_mapping=edge_mapping,
|
edge_mapping=edge_mapping,
|
||||||
reverse_edge_mapping=reverse_edge_mapping,
|
reverse_edge_mapping=reverse_edge_mapping,
|
||||||
parallel_branch_node_ids=parallel_branch_node_ids,
|
parallel_branch_node_ids=condition_parallel_branch_node_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# collect all branches node ids
|
# collect all branches node ids
|
||||||
@ -399,7 +409,69 @@ class Graph(BaseModel):
|
|||||||
else:
|
else:
|
||||||
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
|
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
|
||||||
|
|
||||||
|
if condition_edge_mappings:
|
||||||
|
for condition_hash, graph_edges in condition_edge_mappings.items():
|
||||||
|
current_parallel = cls._get_current_parallel(
|
||||||
|
parallel_mapping=parallel_mapping,
|
||||||
|
graph_edge=graph_edge,
|
||||||
|
parallel=condition_parallels.get(condition_hash),
|
||||||
|
parent_parallel=parent_parallel,
|
||||||
|
)
|
||||||
|
|
||||||
|
cls._recursively_add_parallels(
|
||||||
|
edge_mapping=edge_mapping,
|
||||||
|
reverse_edge_mapping=reverse_edge_mapping,
|
||||||
|
start_node_id=graph_edge.target_node_id,
|
||||||
|
parallel_mapping=parallel_mapping,
|
||||||
|
node_parallel_mapping=node_parallel_mapping,
|
||||||
|
parent_parallel=current_parallel,
|
||||||
|
)
|
||||||
|
else:
|
||||||
for graph_edge in target_node_edges:
|
for graph_edge in target_node_edges:
|
||||||
|
current_parallel = cls._get_current_parallel(
|
||||||
|
parallel_mapping=parallel_mapping,
|
||||||
|
graph_edge=graph_edge,
|
||||||
|
parallel=parallel,
|
||||||
|
parent_parallel=parent_parallel,
|
||||||
|
)
|
||||||
|
|
||||||
|
cls._recursively_add_parallels(
|
||||||
|
edge_mapping=edge_mapping,
|
||||||
|
reverse_edge_mapping=reverse_edge_mapping,
|
||||||
|
start_node_id=graph_edge.target_node_id,
|
||||||
|
parallel_mapping=parallel_mapping,
|
||||||
|
node_parallel_mapping=node_parallel_mapping,
|
||||||
|
parent_parallel=current_parallel,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for graph_edge in target_node_edges:
|
||||||
|
current_parallel = cls._get_current_parallel(
|
||||||
|
parallel_mapping=parallel_mapping,
|
||||||
|
graph_edge=graph_edge,
|
||||||
|
parallel=parallel,
|
||||||
|
parent_parallel=parent_parallel,
|
||||||
|
)
|
||||||
|
|
||||||
|
cls._recursively_add_parallels(
|
||||||
|
edge_mapping=edge_mapping,
|
||||||
|
reverse_edge_mapping=reverse_edge_mapping,
|
||||||
|
start_node_id=graph_edge.target_node_id,
|
||||||
|
parallel_mapping=parallel_mapping,
|
||||||
|
node_parallel_mapping=node_parallel_mapping,
|
||||||
|
parent_parallel=current_parallel,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_current_parallel(
|
||||||
|
cls,
|
||||||
|
parallel_mapping: dict[str, GraphParallel],
|
||||||
|
graph_edge: GraphEdge,
|
||||||
|
parallel: Optional[GraphParallel] = None,
|
||||||
|
parent_parallel: Optional[GraphParallel] = None,
|
||||||
|
) -> Optional[GraphParallel]:
|
||||||
|
"""
|
||||||
|
Get current parallel
|
||||||
|
"""
|
||||||
current_parallel = None
|
current_parallel = None
|
||||||
if parallel:
|
if parallel:
|
||||||
current_parallel = parallel
|
current_parallel = parallel
|
||||||
@ -422,14 +494,7 @@ class Graph(BaseModel):
|
|||||||
):
|
):
|
||||||
current_parallel = parent_parallel_parent_parallel
|
current_parallel = parent_parallel_parent_parallel
|
||||||
|
|
||||||
cls._recursively_add_parallels(
|
return current_parallel
|
||||||
edge_mapping=edge_mapping,
|
|
||||||
reverse_edge_mapping=reverse_edge_mapping,
|
|
||||||
start_node_id=graph_edge.target_node_id,
|
|
||||||
parallel_mapping=parallel_mapping,
|
|
||||||
node_parallel_mapping=node_parallel_mapping,
|
|
||||||
parent_parallel=current_parallel,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_exceed_parallel_limit(
|
def _check_exceed_parallel_limit(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user