fix(workflow): in multi-parallel execution with multiple conditional branches (#8221)

This commit is contained in:
takatost 2024-09-10 21:09:18 +08:00 committed by GitHub
parent ffd4bf8bf0
commit f6dfe23cf8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -304,11 +304,14 @@ class Graph(BaseModel):
parallel = None parallel = None
if len(target_node_edges) > 1: if len(target_node_edges) > 1:
# fetch all node ids in current parallels # fetch all node ids in current parallels
parallel_branch_node_ids = [] parallel_branch_node_ids = {}
condition_edge_mappings = {} condition_edge_mappings = {}
for graph_edge in target_node_edges: for graph_edge in target_node_edges:
if graph_edge.run_condition is None: if graph_edge.run_condition is None:
parallel_branch_node_ids.append(graph_edge.target_node_id) if "default" not in parallel_branch_node_ids:
parallel_branch_node_ids["default"] = []
parallel_branch_node_ids["default"].append(graph_edge.target_node_id)
else: else:
condition_hash = graph_edge.run_condition.hash condition_hash = graph_edge.run_condition.hash
if not condition_hash in condition_edge_mappings: if not condition_hash in condition_edge_mappings:
@ -316,13 +319,19 @@ class Graph(BaseModel):
condition_edge_mappings[condition_hash].append(graph_edge) condition_edge_mappings[condition_hash].append(graph_edge)
for _, graph_edges in condition_edge_mappings.items(): for condition_hash, graph_edges in condition_edge_mappings.items():
if len(graph_edges) > 1: if len(graph_edges) > 1:
for graph_edge in graph_edges: if condition_hash not in parallel_branch_node_ids:
parallel_branch_node_ids.append(graph_edge.target_node_id) parallel_branch_node_ids[condition_hash] = []
for graph_edge in graph_edges:
parallel_branch_node_ids[condition_hash].append(graph_edge.target_node_id)
condition_parallels = {}
for condition_hash, condition_parallel_branch_node_ids in parallel_branch_node_ids.items():
# any target node id in node_parallel_mapping # any target node id in node_parallel_mapping
if parallel_branch_node_ids: parallel = None
if condition_parallel_branch_node_ids:
parent_parallel_id = parent_parallel.id if parent_parallel else None parent_parallel_id = parent_parallel.id if parent_parallel else None
parallel = GraphParallel( parallel = GraphParallel(
@ -331,11 +340,12 @@ class Graph(BaseModel):
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None, parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
) )
parallel_mapping[parallel.id] = parallel parallel_mapping[parallel.id] = parallel
condition_parallels[condition_hash] = parallel
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
edge_mapping=edge_mapping, edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping, reverse_edge_mapping=reverse_edge_mapping,
parallel_branch_node_ids=parallel_branch_node_ids, parallel_branch_node_ids=condition_parallel_branch_node_ids,
) )
# collect all branches node ids # collect all branches node ids
@ -399,7 +409,69 @@ class Graph(BaseModel):
else: else:
parallel.end_to_node_id = outside_parallel_target_node_ids.pop() parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
if condition_edge_mappings:
for condition_hash, graph_edges in condition_edge_mappings.items():
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=condition_parallels.get(condition_hash),
parent_parallel=parent_parallel,
)
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
else:
for graph_edge in target_node_edges: for graph_edge in target_node_edges:
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=parallel,
parent_parallel=parent_parallel,
)
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
else:
for graph_edge in target_node_edges:
current_parallel = cls._get_current_parallel(
parallel_mapping=parallel_mapping,
graph_edge=graph_edge,
parallel=parallel,
parent_parallel=parent_parallel,
)
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
@classmethod
def _get_current_parallel(
cls,
parallel_mapping: dict[str, GraphParallel],
graph_edge: GraphEdge,
parallel: Optional[GraphParallel] = None,
parent_parallel: Optional[GraphParallel] = None,
) -> Optional[GraphParallel]:
"""
Get current parallel
"""
current_parallel = None current_parallel = None
if parallel: if parallel:
current_parallel = parallel current_parallel = parallel
@ -422,14 +494,7 @@ class Graph(BaseModel):
): ):
current_parallel = parent_parallel_parent_parallel current_parallel = parent_parallel_parent_parallel
cls._recursively_add_parallels( return current_parallel
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel,
)
@classmethod @classmethod
def _check_exceed_parallel_limit( def _check_exceed_parallel_limit(