From 77e62f7fee21d9ede4dc97012ebcee0576cb91ff Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 30 Aug 2024 18:55:21 +0800 Subject: [PATCH] fix(workflow): run node in multi parallel bugs --- .../workflow/graph_engine/entities/graph.py | 111 ++++++++++++++++-- 1 file changed, 99 insertions(+), 12 deletions(-) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 028483cf6e..9a79e7e630 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -179,6 +179,15 @@ class Graph(BaseModel): node_parallel_mapping=node_parallel_mapping ) + # Check if it exceeds N layers of parallel + for parallel in parallel_mapping.values(): + if parallel.parent_parallel_id: + cls._check_exceed_parallel_limit( + parallel_mapping=parallel_mapping, + level_limit=3, + parent_parallel_id=parallel.parent_parallel_id + ) + # init answer stream generate routes answer_stream_generate_routes = AnswerStreamGeneratorRouter.init( node_id_config_mapping=node_id_config_mapping, @@ -315,11 +324,11 @@ class Graph(BaseModel): target_node_edges = edge_mapping.get(start_node_id, []) if len(target_node_edges) > 1: # fetch all node ids in current parallels - parallel_node_ids = [] + parallel_branch_node_ids = [] condition_edge_mappings = {} for graph_edge in target_node_edges: if graph_edge.run_condition is None: - parallel_node_ids.append(graph_edge.target_node_id) + parallel_branch_node_ids.append(graph_edge.target_node_id) else: condition_hash = graph_edge.run_condition.hash if not condition_hash in condition_edge_mappings: @@ -330,13 +339,13 @@ class Graph(BaseModel): for _, graph_edges in condition_edge_mappings.items(): if len(graph_edges) > 1: for graph_edge in graph_edges: - parallel_node_ids.append(graph_edge.target_node_id) + parallel_branch_node_ids.append(graph_edge.target_node_id) # any target node id in node_parallel_mapping - if parallel_node_ids: - # all parallel_node_ids in node_parallel_mapping + if parallel_branch_node_ids: + # all parallel_branch_node_ids in node_parallel_mapping parent_parallel_id = None - for node_id in parallel_node_ids: + for node_id in parallel_branch_node_ids: if node_id in node_parallel_mapping: parent_parallel_id = node_parallel_mapping[node_id] break @@ -356,7 +365,7 @@ class Graph(BaseModel): in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( edge_mapping=edge_mapping, - parallel_node_ids=parallel_node_ids + parallel_branch_node_ids=parallel_branch_node_ids ) # collect all branches node ids @@ -408,6 +417,33 @@ class Graph(BaseModel): node_parallel_mapping=node_parallel_mapping ) + @classmethod + def _check_exceed_parallel_limit( + cls, + parallel_mapping: dict[str, GraphParallel], + level_limit: int, + parent_parallel_id: str, + current_level: int = 1 + ) -> None: + """ + Check if it exceeds N layers of parallel + """ + parent_parallel = parallel_mapping.get(parent_parallel_id) + if not parent_parallel: + return + + current_level += 1 + if current_level > level_limit: + raise ValueError(f"Exceeds {level_limit} layers of parallel") + + if parent_parallel.parent_parallel_id: + cls._check_exceed_parallel_limit( + parallel_mapping=parallel_mapping, + level_limit=level_limit, + parent_parallel_id=parent_parallel.parent_parallel_id, + current_level=current_level + ) + @classmethod def _recursively_add_parallel_node_ids(cls, branch_node_ids: list[str], @@ -436,19 +472,19 @@ class Graph(BaseModel): @classmethod def _fetch_all_node_ids_in_parallels(cls, edge_mapping: dict[str, list[GraphEdge]], - parallel_node_ids: list[str]) -> dict[str, list[str]]: + parallel_branch_node_ids: list[str]) -> dict[str, list[str]]: """ Fetch all node ids in parallels """ routes_node_ids: dict[str, list[str]] = {} - for parallel_node_id in parallel_node_ids: - routes_node_ids[parallel_node_id] = [parallel_node_id] + for parallel_branch_node_id in parallel_branch_node_ids: + routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id] # fetch routes node ids cls._recursively_fetch_routes( edge_mapping=edge_mapping, - start_node_id=parallel_node_id, - routes_node_ids=routes_node_ids[parallel_node_id] + start_node_id=parallel_branch_node_id, + routes_node_ids=routes_node_ids[parallel_branch_node_id] ) # fetch leaf node ids from routes node ids @@ -472,6 +508,31 @@ class Graph(BaseModel): # sorted merge_branch_node_ids by branch_node_ids length desc merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True)) + duplicate_end_node_ids = {} + for node_id, branch_node_ids in merge_branch_node_ids.items(): + for node_id2, branch_node_ids2 in merge_branch_node_ids.items(): + if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2): + if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids: + duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids + + for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items(): + if node_id not in merge_branch_node_ids or node_id2 not in branch_node_ids: + continue + + # check which node is after + if cls._is_node2_after_node1( + node1_id=node_id, + node2_id=node_id2, + edge_mapping=edge_mapping + ): + del merge_branch_node_ids[node_id] + elif cls._is_node2_after_node1( + node1_id=node_id2, + node2_id=node_id, + edge_mapping=edge_mapping + ): + del merge_branch_node_ids[node_id2] + branches_merge_node_ids: dict[str, str] = {} for node_id, branch_node_ids in merge_branch_node_ids.items(): if len(branch_node_ids) <= 1: @@ -526,3 +587,29 @@ class Graph(BaseModel): start_node_id=graph_edge.target_node_id, routes_node_ids=routes_node_ids ) + + @classmethod + def _is_node2_after_node1( + cls, + node1_id: str, + node2_id: str, + edge_mapping: dict[str, list[GraphEdge]] + ) -> bool: + """ + is node2 after node1 + """ + if node1_id not in edge_mapping: + return False + + for graph_edge in edge_mapping[node1_id]: + if graph_edge.target_node_id == node2_id: + return True + + if cls._is_node2_after_node1( + node1_id=graph_edge.target_node_id, + node2_id=node2_id, + edge_mapping=edge_mapping + ): + return True + + return False \ No newline at end of file