diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index aa32002668..e456471bd9 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -1,218 +1,224 @@ -from typing import Optional +from typing import Optional, cast -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field -from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler -from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.run_condition import RunCondition -class GraphNode(BaseModel): - id: str - """node id""" +class GraphEdge(BaseModel): + source_node_id: str + """source node id""" - parent_id: Optional[str] = None - """parent node id, e.g. iteration/loop""" - - predecessor_node_id: Optional[str] = None - """predecessor node id""" - - descendant_node_ids: list[str] = [] - """descendant node ids""" + target_node_id: str + """target node id""" run_condition: Optional[RunCondition] = None - """condition to run the node""" + """condition to run the edge""" - node_config: dict - """original node config""" - source_edge_config: Optional[dict] = None - """original source edge config""" +class GraphStateRoute(BaseModel): + route_id: str + """route id""" - sub_graph: Optional["Graph"] = None - """sub graph of the node, e.g. iteration/loop sub graph""" + node_id: str + """node id""" - def add_child(self, node_id: str) -> None: - if node_id not in self.descendant_node_ids: - self.descendant_node_ids.append(node_id) - def get_run_condition_handler(self) -> Optional[RunConditionHandler]: - """ - Get run condition handler +class GraphState(BaseModel): + routes: dict[str, list[GraphStateRoute]] = Field(default_factory=dict) + """graph state routes (route_id: run_result)""" - :return: run condition handler - """ - if not self.run_condition: - return None + variable_pool: VariablePool + """variable pool""" - return ConditionManager.get_condition_handler( - run_condition=self.run_condition - ) + node_route_results: dict[str, NodeRunResult] = Field(default_factory=dict) + """node results in route (route_id: run_result)""" class Graph(BaseModel): - graph_nodes: dict[str, GraphNode] = Field(default_factory=dict) - """graph nodes""" + root_node_id: str + """root node id of the graph""" - root_node: GraphNode - """root node of the graph""" + node_ids: list[str] = Field(default_factory=list) + """graph node ids""" - @model_validator(mode='after') - def add_root_node(cls, values): - root_node = values.root_node - values.graph_nodes[root_node.id] = root_node - return values + edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict) + """graph edge mapping""" + + run_state: GraphState + """graph run state""" @classmethod - def init(cls, root_node_config: dict, run_condition: Optional[RunCondition] = None) -> "Graph": + def init(cls, + graph_config: dict, + variable_pool: VariablePool, + root_node_id: Optional[str] = None) -> "Graph": """ Init graph - :param root_node_config: root node config - :param run_condition: run condition when root node parent is iteration/loop + :param graph_config: graph config + :param variable_pool: variable pool + :param root_node_id: root node id :return: graph """ - node_id = root_node_config.get('id') - if not node_id: - raise ValueError("Graph root node id is required") + # edge configs + edge_configs = graph_config.get('edges') + if edge_configs is None: + edge_configs = [] - root_node = GraphNode( - id=node_id, - parent_id=root_node_config.get('parentId'), - node_config=root_node_config, + edge_configs = cast(list, edge_configs) + + # reorganize edges mapping + edge_mapping: dict[str, list[GraphEdge]] = {} + target_edge_ids = set() + for edge_config in edge_configs: + source_node_id = edge_config.get('source') + if not source_node_id: + continue + + if source_node_id not in edge_mapping: + edge_mapping[source_node_id] = [] + + target_node_id = edge_config.get('target') + if not target_node_id: + continue + + target_edge_ids.add(target_node_id) + + # parse run condition + run_condition = None + if edge_config.get('sourceHandle'): + run_condition = RunCondition( + type='branch_identify', + branch_identify=edge_config.get('sourceHandle') + ) + + graph_edge = GraphEdge( + source_node_id=source_node_id, + target_node_id=edge_config.get('target'), + run_condition=run_condition + ) + + edge_mapping[source_node_id].append(graph_edge) + + # node configs + node_configs = graph_config.get('nodes') + if not node_configs: + raise ValueError("Graph must have at least one node") + + node_configs = cast(list, node_configs) + + # fetch nodes that have no predecessor node + root_node_configs = [] + for node_config in node_configs: + node_id = node_config.get('id') + if not node_id: + continue + + if node_id not in target_edge_ids: + root_node_configs.append(node_config) + + root_node_ids = [node_config.get('id') for node_config in root_node_configs] + + # fetch root node + if not root_node_id: + # if no root node id, use the START type node as root node + root_node_id = next((node_config for node_config in root_node_configs + if node_config.get('data', {}).get('type', '') == NodeType.START.value), None) + + if not root_node_id or root_node_id not in root_node_ids: + raise ValueError(f"Root node id {root_node_id} not found in the graph") + + # fetch all node ids from root node + node_ids = [root_node_id] + cls._recursively_add_node_ids( + node_ids=node_ids, + edge_mapping=edge_mapping, + node_id=root_node_id + ) + + # init graph + graph = cls( + root_node_id=root_node_id, + node_ids=node_ids, + edge_mapping=edge_mapping, + run_state=GraphState( + variable_pool=variable_pool + ) + ) + + return graph + + @classmethod + def _recursively_add_node_ids(cls, + node_ids: list[str], + edge_mapping: dict[str, list[GraphEdge]], + node_id: str) -> None: + """ + Recursively add node ids + + :param node_ids: node ids + :param edge_mapping: edge mapping + :param node_id: node id + """ + for graph_edge in edge_mapping.get(node_id, []): + if graph_edge.target_node_id in node_ids: + continue + + node_ids.append(graph_edge.target_node_id) + cls._recursively_add_node_ids( + node_ids=node_ids, + edge_mapping=edge_mapping, + node_id=graph_edge.target_node_id + ) + def next_node_ids(self) -> list[str]: + """ + Get next node ids + """ + # todo + return [] + + def add_extra_edge(self, source_node_id: str, + target_node_id: str, + run_condition: Optional[RunCondition] = None) -> None: + """ + Add extra edge to the graph + + :param source_node_id: source node id + :param target_node_id: target node id + :param run_condition: run condition + """ + if source_node_id not in self.node_ids or target_node_id not in self.node_ids: + return + + if source_node_id not in self.edge_mapping: + self.edge_mapping[source_node_id] = [] + + if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]: + return + + graph_edge = GraphEdge( + source_node_id=source_node_id, + target_node_id=target_node_id, run_condition=run_condition ) - return cls(root_node=root_node) + self.edge_mapping[source_node_id].append(graph_edge) - def add_edge(self, edge_config: dict, - source_node_config: dict, - target_node_config: dict, - target_node_sub_graph: Optional["Graph"] = None, - run_condition: Optional[RunCondition] = None) -> None: + def get_leaf_node_ids(self) -> list[str]: """ - Add edge to the graph + Get leaf node ids of the graph - :param edge_config: edge config - :param source_node_config: source node config - :param target_node_config: target node config - :param target_node_sub_graph: sub graph - :param run_condition: run condition + :return: leaf node ids """ - source_node_id = source_node_config.get('id') - if not source_node_id: - return + leaf_node_ids = [] + for node_id in self.node_ids: + if node_id not in self.edge_mapping: + leaf_node_ids.append(node_id) + elif (len(self.edge_mapping[node_id]) == 1 + and self.edge_mapping[node_id][0].target_node_id == self.root_node_id): + leaf_node_ids.append(node_id) - if source_node_id not in self.graph_nodes: - return - - target_node_id = target_node_config.get('id') - if not target_node_id: - return - - source_node = self.graph_nodes.get(source_node_id) - if not source_node: - return - - source_node.add_child(target_node_id) - - if target_node_id not in self.graph_nodes: - target_graph_node = GraphNode( - id=target_node_id, - parent_id=source_node_config.get('parentId'), - predecessor_node_id=source_node_id, - node_config=target_node_config, - run_condition=run_condition, - source_edge_config=edge_config, - sub_graph=target_node_sub_graph - ) - - self.add_graph_node(target_graph_node) - else: - target_node = self.graph_nodes.get(target_node_id) - if not target_node: - return - - target_node.predecessor_node_id = source_node_id - target_node.run_condition = run_condition - target_node.source_edge_config = edge_config - target_node.sub_graph = target_node_sub_graph - - def get_leaf_nodes(self) -> list[GraphNode]: - """ - Get leaf nodes of the graph - - :return: leaf nodes - """ - leaf_nodes = [] - for node_id, graph_node in self.graph_nodes.items(): - if ( - not graph_node.descendant_node_ids # has no child - or # or has only one child and the child is the root node - ( - graph_node.descendant_node_ids - and graph_node.descendant_node_ids[0] == self.root_node.id - ) - ): - leaf_nodes.append(graph_node) - - return leaf_nodes - - def get_descendant_graphs(self, node_id: str) -> list["Graph"]: - """ - Get descendant graphs of the specific node - - :param node_id: node id - :return: descendant graphs - """ - if node_id not in self.graph_nodes: - return [] - - graph_node = self.graph_nodes.get(node_id) - if not graph_node or not graph_node.descendant_node_ids: - return [] - - descendant_graphs: list[Graph] = [] - for descendant_node_id in graph_node.descendant_node_ids: - descendant_graph_node = self.graph_nodes.get(descendant_node_id) - if not descendant_graph_node: - continue - - descendants_graph = Graph(root_node=descendant_graph_node) - for sub_descendant_node_id in descendant_graph_node.descendant_node_ids: - descendants_graph.add_descendants_graph_nodes(self, sub_descendant_node_id) - - descendant_graphs.append(descendants_graph) - - return descendant_graphs - - def add_graph_node(self, graph_node: GraphNode) -> None: - """ - Add graph node to the graph - - :param graph_node: graph node - """ - if graph_node.id in self.graph_nodes: - return - - self.graph_nodes[graph_node.id] = graph_node - - def add_descendants_graph_nodes(self, predecessor_graph: "Graph", node_id: str) -> None: - """ - Add descendants graph nodes - - :param predecessor_graph: predecessor graph - :param node_id: node id - """ - if node_id not in predecessor_graph.graph_nodes: - return - - graph_node = predecessor_graph.graph_nodes.get(node_id) - if not graph_node: - return - - if graph_node.id not in self.graph_nodes: - self.add_graph_node(graph_node) - - for child_node_id in graph_node.descendant_node_ids: - self.add_descendants_graph_nodes(predecessor_graph, child_node_id) + return leaf_node_ids