refactor graph

This commit is contained in:
takatost 2024-07-06 03:18:02 +08:00
parent 1b6cd975f3
commit 03f56a05eb

View File

@ -1,218 +1,224 @@
from typing import Optional
from typing import Optional, cast
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.run_condition import RunCondition
# NOTE(review): from here down to the `run_state` field, lines belonging to
# different model declarations appear interleaved (for example the GraphNode
# header is immediately followed by the GraphEdge header, and two different
# descriptions follow `run_condition`). Presumably this is a rendered diff
# whose +/- markers were lost — confirm field ownership against the repository.
class GraphNode(BaseModel):
id: str
"""node id"""
class GraphEdge(BaseModel):
source_node_id: str
"""source node id"""
parent_id: Optional[str] = None
"""parent node id, e.g. iteration/loop"""
predecessor_node_id: Optional[str] = None
"""predecessor node id"""
descendant_node_ids: list[str] = []
"""descendant node ids"""
target_node_id: str
"""target node id"""
run_condition: Optional[RunCondition] = None
"""condition to run the node"""
"""condition to run the edge"""
node_config: dict
"""original node config"""
source_edge_config: Optional[dict] = None
"""original source edge config"""
class GraphStateRoute(BaseModel):
route_id: str
"""route id"""
sub_graph: Optional["Graph"] = None
"""sub graph of the node, e.g. iteration/loop sub graph"""
node_id: str
"""node id"""
# Appends node_id to descendant_node_ids unless it is already listed.
def add_child(self, node_id: str) -> None:
if node_id not in self.descendant_node_ids:
self.descendant_node_ids.append(node_id)
# Returns None when run_condition is unset; otherwise returns whatever
# ConditionManager.get_condition_handler yields for that run condition.
def get_run_condition_handler(self) -> Optional[RunConditionHandler]:
"""
Get run condition handler
class GraphState(BaseModel):
routes: dict[str, list[GraphStateRoute]] = Field(default_factory=dict)
"""graph state routes (route_id: run_result)"""
:return: run condition handler
"""
if not self.run_condition:
return None
variable_pool: VariablePool
"""variable pool"""
return ConditionManager.get_condition_handler(
run_condition=self.run_condition
)
node_route_results: dict[str, NodeRunResult] = Field(default_factory=dict)
"""node results in route (route_id: run_result)"""
class Graph(BaseModel):
graph_nodes: dict[str, GraphNode] = Field(default_factory=dict)
"""graph nodes"""
root_node_id: str
"""root node id of the graph"""
root_node: GraphNode
"""root node of the graph"""
node_ids: list[str] = Field(default_factory=list)
"""graph node ids"""
# After-validation hook: stores root_node in graph_nodes under its own id.
@model_validator(mode='after')
def add_root_node(cls, values):
root_node = values.root_node
values.graph_nodes[root_node.id] = root_node
return values
edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict)
"""graph edge mapping"""
run_state: GraphState
"""graph run state"""
@classmethod
def init(cls,
         graph_config: dict,
         variable_pool: VariablePool,
         root_node_id: Optional[str] = None) -> "Graph":
    """
    Init graph

    Builds the edge mapping from the 'edges' list of the config, resolves the
    root node and collects every node id reachable from it.

    :param graph_config: graph config; its 'edges' and 'nodes' lists are read
    :param variable_pool: variable pool stored on the graph run state
    :param root_node_id: root node id; when omitted, the START type node
        among the nodes without incoming edges is used
    :return: graph
    :raises ValueError: if the config has no nodes, or the root node cannot
        be resolved to a node that has no incoming edge
    """
    # edge configs
    edge_configs = graph_config.get('edges')
    if edge_configs is None:
        edge_configs = []

    edge_configs = cast(list, edge_configs)

    # reorganize edges mapping
    edge_mapping: dict[str, list[GraphEdge]] = {}
    target_node_ids: set[str] = set()
    for edge_config in edge_configs:
        source_node_id = edge_config.get('source')
        target_node_id = edge_config.get('target')
        # Validate both ends BEFORE touching the mapping, so a half-defined
        # edge never leaves an empty entry behind (an empty entry would make
        # the source look like it has outgoing edges).
        if not source_node_id or not target_node_id:
            continue

        target_node_ids.add(target_node_id)

        # parse run condition
        run_condition = None
        source_handle = edge_config.get('sourceHandle')
        if source_handle:
            run_condition = RunCondition(
                type='branch_identify',
                branch_identify=source_handle
            )

        edge_mapping.setdefault(source_node_id, []).append(
            GraphEdge(
                source_node_id=source_node_id,
                target_node_id=target_node_id,
                run_condition=run_condition
            )
        )

    # node configs
    node_configs = graph_config.get('nodes')
    if not node_configs:
        raise ValueError("Graph must have at least one node")

    node_configs = cast(list, node_configs)

    # fetch nodes that have no predecessor node
    root_node_configs = [
        node_config for node_config in node_configs
        if node_config.get('id') and node_config.get('id') not in target_node_ids
    ]
    root_node_ids = [node_config.get('id') for node_config in root_node_configs]

    # fetch root node
    if not root_node_id:
        # if no root node id, use the START type node as root node.
        # Take the node's *id*, not the node config dict itself: the check
        # below compares against a list of ids.
        root_node_id = next(
            (node_config.get('id') for node_config in root_node_configs
             if node_config.get('data', {}).get('type', '') == NodeType.START.value),
            None
        )

    if not root_node_id or root_node_id not in root_node_ids:
        raise ValueError(f"Root node id {root_node_id} not found in the graph")

    # fetch all node ids from root node
    node_ids = [root_node_id]
    cls._recursively_add_node_ids(
        node_ids=node_ids,
        edge_mapping=edge_mapping,
        node_id=root_node_id
    )

    # init graph
    return cls(
        root_node_id=root_node_id,
        node_ids=node_ids,
        edge_mapping=edge_mapping,
        run_state=GraphState(
            variable_pool=variable_pool
        )
    )
@classmethod
def _recursively_add_node_ids(cls,
                              node_ids: list[str],
                              edge_mapping: dict[str, list[GraphEdge]],
                              node_id: str) -> None:
    """
    Append to node_ids every node id reachable from node_id

    Ids are appended in depth-first pre-order, following each node's edges in
    list order; ids already present in node_ids are never appended again.
    The walk uses an explicit stack, so long chains of nodes cannot exhaust
    the interpreter's recursion limit.

    :param node_ids: node ids collected so far (extended in place)
    :param edge_mapping: edge mapping (source node id -> outgoing edges)
    :param node_id: node id to start from
    """
    # Set mirror of node_ids for O(1) membership tests.
    seen = set(node_ids)
    # Each stack entry is the iterator over one node's not-yet-followed edges.
    pending = [iter(edge_mapping.get(node_id, []))]
    while pending:
        for graph_edge in pending[-1]:
            target_node_id = graph_edge.target_node_id
            if target_node_id in seen:
                continue

            seen.add(target_node_id)
            node_ids.append(target_node_id)
            # Descend into the new node first; the current iterator resumes
            # where it stopped once that subtree is finished.
            pending.append(iter(edge_mapping.get(target_node_id, [])))
            break
        else:
            # Current node has no unexplored edges left.
            pending.pop()
def next_node_ids(self) -> list[str]:
    """
    Get next node ids

    :return: currently always an empty list (not implemented yet)
    """
    # todo
    upcoming_node_ids: list[str] = []
    return upcoming_node_ids
def add_extra_edge(self, source_node_id: str,
                   target_node_id: str,
                   run_condition: Optional[RunCondition] = None) -> None:
    """
    Add extra edge to the graph

    Nothing happens when either node id is not in node_ids, or when the
    source node already has an edge to the target node.

    :param source_node_id: source node id
    :param target_node_id: target node id
    :param run_condition: run condition
    """
    if source_node_id not in self.node_ids or target_node_id not in self.node_ids:
        return

    outgoing_edges = self.edge_mapping.setdefault(source_node_id, [])

    # Keep at most one edge per (source, target) pair.
    if any(graph_edge.target_node_id == target_node_id for graph_edge in outgoing_edges):
        return

    outgoing_edges.append(
        GraphEdge(
            source_node_id=source_node_id,
            target_node_id=target_node_id,
            run_condition=run_condition
        )
    )
# NOTE(review): this span appears to interleave two method definitions —
# add_edge(...) and get_leaf_node_ids() — including their docstrings and
# bodies. Presumably a rendered diff whose +/- markers were lost; confirm
# against the repository which lines belong to which method.
def add_edge(self, edge_config: dict,
source_node_config: dict,
target_node_config: dict,
target_node_sub_graph: Optional["Graph"] = None,
run_condition: Optional[RunCondition] = None) -> None:
def get_leaf_node_ids(self) -> list[str]:
"""
Add edge to the graph
Get leaf node ids of the graph
:param edge_config: edge config
:param source_node_config: source node config
:param target_node_config: target node config
:param target_node_sub_graph: sub graph
:param run_condition: run condition
:return: leaf node ids
"""
# Silently give up when the source node config carries no id.
source_node_id = source_node_config.get('id')
if not source_node_id:
return
# A node id counts as a leaf when it has no entry in edge_mapping, or when
# its single outgoing edge targets root_node_id.
leaf_node_ids = []
for node_id in self.node_ids:
if node_id not in self.edge_mapping:
leaf_node_ids.append(node_id)
elif (len(self.edge_mapping[node_id]) == 1
and self.edge_mapping[node_id][0].target_node_id == self.root_node_id):
leaf_node_ids.append(node_id)
# The source node must already be registered in graph_nodes.
if source_node_id not in self.graph_nodes:
return
target_node_id = target_node_config.get('id')
if not target_node_id:
return
source_node = self.graph_nodes.get(source_node_id)
if not source_node:
return
# Record the target as a descendant of the source node.
source_node.add_child(target_node_id)
# Unknown target: build a GraphNode for it and register it.
# Known target: overwrite its predecessor, run condition, source edge
# config and sub graph with the values from this call.
if target_node_id not in self.graph_nodes:
target_graph_node = GraphNode(
id=target_node_id,
parent_id=source_node_config.get('parentId'),
predecessor_node_id=source_node_id,
node_config=target_node_config,
run_condition=run_condition,
source_edge_config=edge_config,
sub_graph=target_node_sub_graph
)
self.add_graph_node(target_graph_node)
else:
target_node = self.graph_nodes.get(target_node_id)
if not target_node:
return
target_node.predecessor_node_id = source_node_id
target_node.run_condition = run_condition
target_node.source_edge_config = edge_config
target_node.sub_graph = target_node_sub_graph
def get_leaf_nodes(self) -> list[GraphNode]:
    """
    Get leaf nodes of the graph

    A node is a leaf when it has no child, or when its only child is the
    root node (an edge looping back to the start).

    :return: leaf nodes, in graph_nodes order
    """
    root_node_id = self.root_node.id
    return [
        graph_node
        for graph_node in self.graph_nodes.values()
        if not graph_node.descendant_node_ids  # has no child
        # or has exactly one child and that child is the root node
        or graph_node.descendant_node_ids == [root_node_id]
    ]
def get_descendant_graphs(self, node_id: str) -> list["Graph"]:
    """
    Get descendant graphs of the specific node

    One graph is built per direct descendant that is present in graph_nodes;
    that descendant becomes the root node of its graph.

    :param node_id: node id
    :return: descendant graphs (empty if the node is unknown or has no descendant)
    """
    origin_node = self.graph_nodes.get(node_id)
    if not origin_node or not origin_node.descendant_node_ids:
        return []

    graphs: list[Graph] = []
    for child_node_id in origin_node.descendant_node_ids:
        child_node = self.graph_nodes.get(child_node_id)
        if not child_node:
            # descendant id without a registered node: skip it
            continue

        child_graph = Graph(root_node=child_node)
        for grandchild_node_id in child_node.descendant_node_ids:
            child_graph.add_descendants_graph_nodes(self, grandchild_node_id)

        graphs.append(child_graph)

    return graphs
def add_graph_node(self, graph_node: GraphNode) -> None:
    """
    Add graph node to the graph

    A node whose id is already registered is left untouched.

    :param graph_node: graph node
    """
    # setdefault only stores the node when its id is not a key yet
    self.graph_nodes.setdefault(graph_node.id, graph_node)
def add_descendants_graph_nodes(self, predecessor_graph: "Graph", node_id: str) -> None:
    """
    Add descendants graph nodes

    Copies the node identified by node_id, and every node reachable from it
    through descendant_node_ids, from predecessor_graph into this graph.
    Nodes are visited in depth-first pre-order and each node is expanded at
    most once, so graphs containing cycles (e.g. a node whose child is the
    root node) terminate instead of recursing forever.

    :param predecessor_graph: predecessor graph
    :param node_id: node id
    """
    source_nodes = predecessor_graph.graph_nodes
    visited: set[str] = set()
    pending = [node_id]
    while pending:
        current_node_id = pending.pop()
        if current_node_id in visited:
            continue
        visited.add(current_node_id)

        graph_node = source_nodes.get(current_node_id)
        if not graph_node:
            # id not present in the predecessor graph: nothing to copy
            continue

        if graph_node.id not in self.graph_nodes:
            self.add_graph_node(graph_node)

        # Push children reversed so they are popped in their listed order.
        pending.extend(reversed(graph_node.descendant_node_ids))
return leaf_node_ids