diff --git a/api/core/workflow/graph_engine/condition_handlers/base_handler.py b/api/core/workflow/graph_engine/condition_handlers/base_handler.py index 5b96f280dc..05d8f1765f 100644 --- a/api/core/workflow/graph_engine/condition_handlers/base_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/base_handler.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod -from core.workflow.entities.node_entities import NodeRunResult from core.workflow.graph_engine.entities.run_condition import RunCondition @@ -10,15 +9,15 @@ class RunConditionHandler(ABC): @abstractmethod def check(self, - graph_node: "GraphNode", - graph_runtime_state: "GraphRuntimeState", - predecessor_node_result: NodeRunResult) -> bool: + source_node_id: str, + target_node_id: str, + graph: "Graph") -> bool: """ Check if the condition can be executed - :param graph_node: graph node - :param graph_runtime_state: graph runtime state - :param predecessor_node_result: predecessor node result + :param source_node_id: source node id + :param target_node_id: target node id + :param graph: graph :return: bool """ raise NotImplementedError diff --git a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py index 90cc035a4f..000547b5cc 100644 --- a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py @@ -1,25 +1,29 @@ -from core.workflow.entities.node_entities import NodeRunResult from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler class BranchIdentifyRunConditionHandler(RunConditionHandler): def check(self, - graph_node: "GraphNode", - graph_runtime_state: "GraphRuntimeState", - predecessor_node_result: NodeRunResult) -> bool: + source_node_id: str, + target_node_id: str, + graph: "Graph") -> bool: """ Check if the condition can be executed - :param graph_node: graph node - :param graph_runtime_state: graph runtime state - :param predecessor_node_result: predecessor node result + :param source_node_id: source node id + :param target_node_id: target node id + :param graph: graph :return: bool """ if not self.condition.branch_identify: raise Exception("Branch identify is required") - if not predecessor_node_result.edge_source_handle: + run_state = graph.run_state + node_route_result = run_state.node_route_results.get(source_node_id) + if not node_route_result: return False - return self.condition.branch_identify == predecessor_node_result.edge_source_handle + if not node_route_result.edge_source_handle: + return False + + return self.condition.branch_identify == node_route_result.edge_source_handle diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py index cca7fc235e..c71438cf89 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -1,19 +1,18 @@ -from core.workflow.entities.node_entities import NodeRunResult from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler from core.workflow.utils.condition.processor import ConditionProcessor class ConditionRunConditionHandlerHandler(RunConditionHandler): def check(self, - graph_node: "GraphNode", - graph_runtime_state: "GraphRuntimeState", - predecessor_node_result: NodeRunResult) -> bool: + source_node_id: str, + target_node_id: str, + graph: "Graph") -> bool: """ Check if the condition can be executed - :param graph_node: graph node - :param graph_runtime_state: graph runtime state - :param predecessor_node_result: predecessor node result + :param source_node_id: source node id + :param target_node_id: target node id + :param graph: graph :return: bool """ if not self.condition.conditions: @@ -22,7 +21,7 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler): # process condition condition_processor = ConditionProcessor() compare_result, _ = condition_processor.process( - variable_pool=graph_runtime_state.variable_pool, + variable_pool=graph.run_state.variable_pool, logical_operator="and", conditions=self.condition.conditions ) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index e456471bd9..dbff4fba65 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -1,9 +1,11 @@ +import uuid from typing import Optional, cast from pydantic import BaseModel, Field from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.entities.run_condition import RunCondition @@ -18,6 +20,14 @@ class GraphEdge(BaseModel): """condition to run the edge""" +class GraphParallel(BaseModel): + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + """random uuid parallel id""" + + parent_parallel_id: Optional[str] = None + """parent parallel id if exists""" + + class GraphStateRoute(BaseModel): route_id: str """route id""" @@ -28,13 +38,21 @@ class GraphStateRoute(BaseModel): class GraphState(BaseModel): routes: dict[str, list[GraphStateRoute]] = Field(default_factory=dict) - """graph state routes (route_id: run_result)""" + """graph state routes (source_node_id: routes)""" variable_pool: VariablePool """variable pool""" node_route_results: dict[str, NodeRunResult] = Field(default_factory=dict) - """node results in route (route_id: run_result)""" + """node results in route (node_id: run_result)""" + + +class NextGraphNode(BaseModel): + node_id: str + """next node id""" + + parallel: Optional[GraphParallel] = None + """parallel""" class Graph(BaseModel): @@ -45,7 +63,13 @@ class Graph(BaseModel): """graph node ids""" edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict) - """graph edge mapping""" + """graph edge mapping (source node id: edges)""" + + parallel_mapping: dict[str, GraphParallel] = Field(default_factory=dict) + """graph parallel mapping (parallel id: parallel)""" + + node_parallel_mapping: dict[str, str] = Field(default_factory=dict) + """graph node parallel mapping (node id: parallel id)""" run_state: GraphState """graph run state""" @@ -139,6 +163,16 @@ class Graph(BaseModel): node_id=root_node_id ) + # init parallel mapping + parallel_mapping: dict[str, GraphParallel] = {} + node_parallel_mapping: dict[str, str] = {} + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + start_node_id=root_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping + ) + # init graph graph = cls( root_node_id=root_node_id, @@ -146,7 +180,9 @@ class Graph(BaseModel): edge_mapping=edge_mapping, run_state=GraphState( variable_pool=variable_pool - ) + ), + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping ) return graph @@ -173,12 +209,48 @@ class Graph(BaseModel): edge_mapping=edge_mapping, node_id=graph_edge.target_node_id ) - def next_node_ids(self) -> list[str]: + + def next_node_ids(self) -> list[NextGraphNode]: """ Get next node ids """ - # todo - return [] + # get current node ids in state + if not self.run_state.routes: + return [NextGraphNode(node_id=self.root_node_id)] + + route_final_graph_edges: list[GraphEdge] = [] + for route in self.run_state.routes[self.root_node_id]: + graph_edges = self.edge_mapping.get(route.node_id) + if not graph_edges: + continue + + for edge in graph_edges: + if edge.target_node_id not in self.run_state.routes: + route_final_graph_edges.append(edge) + + next_graph_nodes = [] + for route_final_graph_edge in route_final_graph_edges: + node_id = route_final_graph_edge.target_node_id + # check condition + if route_final_graph_edge.run_condition: + result = ConditionManager.get_condition_handler( + run_condition=route_final_graph_edge.run_condition + ).check( + source_node_id=route_final_graph_edge.source_node_id, + target_node_id=route_final_graph_edge.target_node_id, + graph=self + ) + + if not result: + continue + + parallel = None + if route_final_graph_edge.target_node_id in self.node_parallel_mapping: + parallel = self.parallel_mapping[self.node_parallel_mapping[node_id]] + + next_graph_nodes.append(NextGraphNode(node_id=node_id, parallel=parallel)) + + return next_graph_nodes def add_extra_edge(self, source_node_id: str, target_node_id: str, @@ -222,3 +294,163 @@ class Graph(BaseModel): leaf_node_ids.append(node_id) return leaf_node_ids + + @classmethod + def _recursively_add_parallels(cls, + edge_mapping: dict[str, list[GraphEdge]], + start_node_id: str, + parallel_mapping: dict[str, GraphParallel], + node_parallel_mapping: dict[str, str]) -> None: + """ + Recursively add parallel ids + + :param edge_mapping: edge mapping + :param start_node_id: start from node id + :param parallel_mapping: parallel mapping + :param node_parallel_mapping: node parallel mapping + """ + target_node_edges = edge_mapping.get(start_node_id, []) + if len(target_node_edges) > 1: + # fetch all node ids in current parallels + parallel_node_ids = [graph_edge.target_node_id + for graph_edge in target_node_edges if graph_edge.run_condition is not None] + + # any target node id in node_parallel_mapping + if parallel_node_ids: + # all parallel_node_ids in node_parallel_mapping + parent_parallel_id = None + if all(node_id in node_parallel_mapping for node_id in parallel_node_ids): + parent_parallel_id = node_parallel_mapping[parallel_node_ids[0]] + + parallel = GraphParallel(parent_parallel_id=parent_parallel_id) + parallel_mapping[parallel.id] = parallel + + in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( + edge_mapping=edge_mapping, + parallel_node_ids=parallel_node_ids + ) + + node_parallel_mapping.update({node_id: parallel.id for node_id in in_branch_node_ids}) + + for graph_edge in target_node_edges: + cls._recursively_add_parallels( + edge_mapping=edge_mapping, + start_node_id=graph_edge.target_node_id, + parallel_mapping=parallel_mapping, + node_parallel_mapping=node_parallel_mapping + ) + + @classmethod + def _recursively_add_parallel_node_ids(cls, + branch_node_ids: list[str], + edge_mapping: dict[str, list[GraphEdge]], + merge_node_id: str, + start_node_id: str) -> None: + """ + Recursively add node ids + + :param branch_node_ids: in branch node ids + :param edge_mapping: edge mapping + :param merge_node_id: merge node id + :param start_node_id: start node id + """ + for graph_edge in edge_mapping.get(start_node_id, []): + if (graph_edge.target_node_id != merge_node_id + and graph_edge.target_node_id not in branch_node_ids): + branch_node_ids.append(graph_edge.target_node_id) + cls._recursively_add_parallel_node_ids( + branch_node_ids=branch_node_ids, + edge_mapping=edge_mapping, + merge_node_id=merge_node_id, + start_node_id=graph_edge.target_node_id + ) + + @classmethod + def _fetch_all_node_ids_in_parallels(cls, + edge_mapping: dict[str, list[GraphEdge]], + parallel_node_ids: list[str]) -> dict[str, list[str]]: + """ + Fetch all node ids in parallels + """ + routes_node_ids: dict[str, list[str]] = {} + for parallel_node_id in parallel_node_ids: + routes_node_ids[parallel_node_id] = [] + + # fetch routes node ids + cls._recursively_fetch_routes( + edge_mapping=edge_mapping, + start_node_id=parallel_node_id, + routes_node_ids=routes_node_ids[parallel_node_id] + ) + + # fetch leaf node ids from routes node ids + leaf_node_ids: dict[str, list[str]] = {} + merge_branch_node_ids: dict[str, list[str]] = {} + for branch_node_id, node_ids in routes_node_ids.items(): + for node_id in node_ids: + if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0: + if branch_node_id not in leaf_node_ids: + leaf_node_ids[branch_node_id] = [] + + leaf_node_ids[branch_node_id].append(node_id) + + for branch_node_id2, inner_route2 in routes_node_ids.items(): + if branch_node_id != branch_node_id2 and node_id in inner_route2: + if node_id not in merge_branch_node_ids: + merge_branch_node_ids[node_id] = [] + + merge_branch_node_ids[node_id].append(branch_node_id2) + + # sorted merge_branch_node_ids by branch_node_ids length desc + merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True)) + + branches_merge_node_ids: dict[str, str] = {} + for node_id, branch_node_ids in merge_branch_node_ids.items(): + if len(branch_node_ids) <= 1: + continue + + for branch_node_id in branch_node_ids: + if branch_node_id in branches_merge_node_ids: + continue + + branches_merge_node_ids[branch_node_id] = node_id + + in_branch_node_ids: dict[str, list[str]] = {} + for branch_node_id, node_ids in routes_node_ids.items(): + in_branch_node_ids[branch_node_id] = [branch_node_id] + if branch_node_id not in branches_merge_node_ids: + # all node ids in current branch is in this thread + in_branch_node_ids[branch_node_id].extend(node_ids) + else: + merge_node_id = branches_merge_node_ids[branch_node_id] + # fetch all node ids from branch_node_id and merge_node_id + cls._recursively_add_parallel_node_ids( + branch_node_ids=in_branch_node_ids[branch_node_id], + edge_mapping=edge_mapping, + merge_node_id=merge_node_id, + start_node_id=branch_node_id + ) + + return in_branch_node_ids + + @classmethod + def _recursively_fetch_routes(cls, + edge_mapping: dict[str, list[GraphEdge]], + start_node_id: str, + routes_node_ids: list[str]) -> None: + """ + Recursively fetch route + """ + if start_node_id not in edge_mapping: + return + + for graph_edge in edge_mapping[start_node_id]: + # find next node ids + if graph_edge.target_node_id not in routes_node_ids: + routes_node_ids.append(graph_edge.target_node_id) + + cls._recursively_fetch_routes( + edge_mapping=edge_mapping, + start_node_id=graph_edge.target_node_id, + routes_node_ids=routes_node_ids + )