fix(workflow): run node in multi parallel bugs

This commit is contained in:
takatost 2024-08-30 18:55:21 +08:00
parent e3295181d2
commit 77e62f7fee

View File

@ -179,6 +179,15 @@ class Graph(BaseModel):
node_parallel_mapping=node_parallel_mapping node_parallel_mapping=node_parallel_mapping
) )
# Check if it exceeds N layers of parallel
for parallel in parallel_mapping.values():
if parallel.parent_parallel_id:
cls._check_exceed_parallel_limit(
parallel_mapping=parallel_mapping,
level_limit=3,
parent_parallel_id=parallel.parent_parallel_id
)
# init answer stream generate routes # init answer stream generate routes
answer_stream_generate_routes = AnswerStreamGeneratorRouter.init( answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
node_id_config_mapping=node_id_config_mapping, node_id_config_mapping=node_id_config_mapping,
@ -315,11 +324,11 @@ class Graph(BaseModel):
target_node_edges = edge_mapping.get(start_node_id, []) target_node_edges = edge_mapping.get(start_node_id, [])
if len(target_node_edges) > 1: if len(target_node_edges) > 1:
# fetch all node ids in current parallels # fetch all node ids in current parallels
parallel_node_ids = [] parallel_branch_node_ids = []
condition_edge_mappings = {} condition_edge_mappings = {}
for graph_edge in target_node_edges: for graph_edge in target_node_edges:
if graph_edge.run_condition is None: if graph_edge.run_condition is None:
parallel_node_ids.append(graph_edge.target_node_id) parallel_branch_node_ids.append(graph_edge.target_node_id)
else: else:
condition_hash = graph_edge.run_condition.hash condition_hash = graph_edge.run_condition.hash
if not condition_hash in condition_edge_mappings: if not condition_hash in condition_edge_mappings:
@ -330,13 +339,13 @@ class Graph(BaseModel):
for _, graph_edges in condition_edge_mappings.items(): for _, graph_edges in condition_edge_mappings.items():
if len(graph_edges) > 1: if len(graph_edges) > 1:
for graph_edge in graph_edges: for graph_edge in graph_edges:
parallel_node_ids.append(graph_edge.target_node_id) parallel_branch_node_ids.append(graph_edge.target_node_id)
# any target node id in node_parallel_mapping # any target node id in node_parallel_mapping
if parallel_node_ids: if parallel_branch_node_ids:
# all parallel_node_ids in node_parallel_mapping # all parallel_branch_node_ids in node_parallel_mapping
parent_parallel_id = None parent_parallel_id = None
for node_id in parallel_node_ids: for node_id in parallel_branch_node_ids:
if node_id in node_parallel_mapping: if node_id in node_parallel_mapping:
parent_parallel_id = node_parallel_mapping[node_id] parent_parallel_id = node_parallel_mapping[node_id]
break break
@ -356,7 +365,7 @@ class Graph(BaseModel):
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
edge_mapping=edge_mapping, edge_mapping=edge_mapping,
parallel_node_ids=parallel_node_ids parallel_branch_node_ids=parallel_branch_node_ids
) )
# collect all branches node ids # collect all branches node ids
@ -408,6 +417,33 @@ class Graph(BaseModel):
node_parallel_mapping=node_parallel_mapping node_parallel_mapping=node_parallel_mapping
) )
@classmethod
def _check_exceed_parallel_limit(
    cls,
    parallel_mapping: dict[str, GraphParallel],
    level_limit: int,
    parent_parallel_id: str,
    current_level: int = 1
) -> None:
    """
    Ensure the chain of enclosing parallels is at most ``level_limit`` deep.

    Starting at ``parent_parallel_id``, follow each parallel's
    ``parent_parallel_id`` link upwards. Every ancestor that is found in
    ``parallel_mapping`` adds one level on top of ``current_level``.

    :param parallel_mapping: parallel id -> parallel object
    :param level_limit: maximum number of nested levels allowed
    :param parent_parallel_id: id of the first ancestor to inspect
    :param current_level: level already reached before inspecting ancestors
    :raises ValueError: when the counted level becomes greater than ``level_limit``
    """
    depth = current_level
    ancestor_id = parent_parallel_id
    while True:
        ancestor = parallel_mapping.get(ancestor_id)
        if not ancestor:
            # id not present in the mapping: nothing further to count
            return

        depth += 1
        if depth > level_limit:
            raise ValueError(f"Exceeds {level_limit} layers of parallel")

        ancestor_id = ancestor.parent_parallel_id
        if not ancestor_id:
            # reached an outermost parallel without exceeding the limit
            return
@classmethod @classmethod
def _recursively_add_parallel_node_ids(cls, def _recursively_add_parallel_node_ids(cls,
branch_node_ids: list[str], branch_node_ids: list[str],
@ -436,19 +472,19 @@ class Graph(BaseModel):
@classmethod @classmethod
def _fetch_all_node_ids_in_parallels(cls, def _fetch_all_node_ids_in_parallels(cls,
edge_mapping: dict[str, list[GraphEdge]], edge_mapping: dict[str, list[GraphEdge]],
parallel_node_ids: list[str]) -> dict[str, list[str]]: parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
""" """
Fetch all node ids in parallels Fetch all node ids in parallels
""" """
routes_node_ids: dict[str, list[str]] = {} routes_node_ids: dict[str, list[str]] = {}
for parallel_node_id in parallel_node_ids: for parallel_branch_node_id in parallel_branch_node_ids:
routes_node_ids[parallel_node_id] = [parallel_node_id] routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]
# fetch routes node ids # fetch routes node ids
cls._recursively_fetch_routes( cls._recursively_fetch_routes(
edge_mapping=edge_mapping, edge_mapping=edge_mapping,
start_node_id=parallel_node_id, start_node_id=parallel_branch_node_id,
routes_node_ids=routes_node_ids[parallel_node_id] routes_node_ids=routes_node_ids[parallel_branch_node_id]
) )
# fetch leaf node ids from routes node ids # fetch leaf node ids from routes node ids
@ -472,6 +508,31 @@ class Graph(BaseModel):
# sorted merge_branch_node_ids by branch_node_ids length desc # sorted merge_branch_node_ids by branch_node_ids length desc
merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True)) merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))
duplicate_end_node_ids = {}
for node_id, branch_node_ids in merge_branch_node_ids.items():
for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids:
duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids
for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
if node_id not in merge_branch_node_ids or node_id2 not in branch_node_ids:
continue
# check which node is after
if cls._is_node2_after_node1(
node1_id=node_id,
node2_id=node_id2,
edge_mapping=edge_mapping
):
del merge_branch_node_ids[node_id]
elif cls._is_node2_after_node1(
node1_id=node_id2,
node2_id=node_id,
edge_mapping=edge_mapping
):
del merge_branch_node_ids[node_id2]
branches_merge_node_ids: dict[str, str] = {} branches_merge_node_ids: dict[str, str] = {}
for node_id, branch_node_ids in merge_branch_node_ids.items(): for node_id, branch_node_ids in merge_branch_node_ids.items():
if len(branch_node_ids) <= 1: if len(branch_node_ids) <= 1:
@ -526,3 +587,29 @@ class Graph(BaseModel):
start_node_id=graph_edge.target_node_id, start_node_id=graph_edge.target_node_id,
routes_node_ids=routes_node_ids routes_node_ids=routes_node_ids
) )
@classmethod
def _is_node2_after_node1(
    cls,
    node1_id: str,
    node2_id: str,
    edge_mapping: dict[str, list[GraphEdge]]
) -> bool:
    """
    Tell whether node2 can be reached from node1 by following one or more edges.

    The search is an iterative depth-first walk that remembers visited nodes,
    so every node is expanded at most once. This keeps the check linear on
    fan-out/merge (diamond) graphs and guarantees termination even if the
    edges contain a cycle.

    :param node1_id: id of the node the search starts from
    :param node2_id: id of the node being looked for downstream
    :param edge_mapping: source node id -> outgoing edges of that node
    :return: True if some path node1 -> ... -> node2 exists, otherwise False
    """
    visited: set[str] = {node1_id}
    pending: list[str] = [node1_id]

    while pending:
        current_id = pending.pop()
        for graph_edge in edge_mapping.get(current_id, []):
            target_id = graph_edge.target_node_id
            # compare before the visited check so that a cycle leading back
            # to node1 itself is still reported when node1_id == node2_id
            if target_id == node2_id:
                return True

            if target_id not in visited:
                visited.add(target_id)
                pending.append(target_id)

    return False