refactor graph

This commit is contained in:
takatost 2024-07-07 23:08:45 +08:00
parent fed068ac2e
commit 1adaf42f9d
4 changed files with 265 additions and 31 deletions

View File

@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.entities.run_condition import RunCondition
@ -10,15 +9,15 @@ class RunConditionHandler(ABC):
@abstractmethod
def check(self,
          source_node_id: str,
          target_node_id: str,
          graph: "Graph") -> bool:
    """
    Check if the condition can be executed

    Concrete handlers decide whether the edge from the source node to the
    target node may be followed, using the state carried by the graph.

    :param source_node_id: source node id
    :param target_node_id: target node id
    :param graph: graph
    :return: bool
    :raises NotImplementedError: always, subclasses must override this method
    """
    raise NotImplementedError

View File

@ -1,25 +1,29 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
class BranchIdentifyRunConditionHandler(RunConditionHandler):
    # Compares the branch identifier configured on the run condition with the
    # edge source handle recorded in the source node's run result.

    def check(self,
              source_node_id: str,
              target_node_id: str,
              graph: "Graph") -> bool:
        """
        Check if the condition can be executed

        :param source_node_id: source node id
        :param target_node_id: target node id
        :param graph: graph
        :return: True only when the source node has a recorded result whose
            edge source handle equals the configured branch identifier
        :raises Exception: when the run condition has no branch identifier
        """
        # self.condition is presumably assigned by the base handler (not visible here)
        if not self.condition.branch_identify:
            raise Exception("Branch identify is required")

        # results are keyed by node id; a missing entry means the source node
        # has no recorded result, so no branch can match
        node_route_result = graph.run_state.node_route_results.get(source_node_id)
        if not node_route_result:
            return False

        if not node_route_result.edge_source_handle:
            return False

        return self.condition.branch_identify == node_route_result.edge_source_handle

View File

@ -1,19 +1,18 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.utils.condition.processor import ConditionProcessor
class ConditionRunConditionHandlerHandler(RunConditionHandler):
def check(self,
graph_node: "GraphNode",
graph_runtime_state: "GraphRuntimeState",
predecessor_node_result: NodeRunResult) -> bool:
source_node_id: str,
target_node_id: str,
graph: "Graph") -> bool:
"""
Check if the condition can be executed
:param graph_node: graph node
:param graph_runtime_state: graph runtime state
:param predecessor_node_result: predecessor node result
:param source_node_id: source node id
:param target_node_id: target node id
:param graph: graph
:return: bool
"""
if not self.condition.conditions:
@ -22,7 +21,7 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler):
# process condition
condition_processor = ConditionProcessor()
compare_result, _ = condition_processor.process(
variable_pool=graph_runtime_state.variable_pool,
variable_pool=graph.run_state.variable_pool,
logical_operator="and",
conditions=self.condition.conditions
)

View File

@ -1,9 +1,11 @@
import uuid
from typing import Optional, cast
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.run_condition import RunCondition
@ -18,6 +20,14 @@ class GraphEdge(BaseModel):
"""condition to run the edge"""
# Describes one parallel section of a graph and, optionally, the parallel
# section that contains it.
class GraphParallel(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    """parallel id, a freshly generated random uuid string by default"""

    parent_parallel_id: Optional[str] = Field(default=None)
    """id of the enclosing parallel, None when there is no parent"""
class GraphStateRoute(BaseModel):
route_id: str
"""route id"""
@ -28,13 +38,21 @@ class GraphStateRoute(BaseModel):
# Run state carried by a graph: the recorded routes, the variable pool and
# the per-node run results.
class GraphState(BaseModel):
    routes: dict[str, list[GraphStateRoute]] = Field(default_factory=dict)
    """graph state routes (source_node_id: routes)"""

    variable_pool: VariablePool
    """variable pool"""

    node_route_results: dict[str, NodeRunResult] = Field(default_factory=dict)
    """node results in route (node_id: run_result)"""
# A node selected to run next, together with the parallel it belongs to (if any).
class NextGraphNode(BaseModel):
    node_id: str
    """id of the node to run next"""

    parallel: Optional[GraphParallel] = Field(default=None)
    """parallel the node belongs to, None when it is not inside a parallel"""
class Graph(BaseModel):
@ -45,7 +63,13 @@ class Graph(BaseModel):
"""graph node ids"""
edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict)
"""graph edge mapping"""
"""graph edge mapping (source node id: edges)"""
parallel_mapping: dict[str, GraphParallel] = Field(default_factory=dict)
"""graph parallel mapping (parallel id: parallel)"""
node_parallel_mapping: dict[str, str] = Field(default_factory=dict)
"""graph node parallel mapping (node id: parallel id)"""
run_state: GraphState
"""graph run state"""
@ -139,6 +163,16 @@ class Graph(BaseModel):
node_id=root_node_id
)
# init parallel mapping
parallel_mapping: dict[str, GraphParallel] = {}
node_parallel_mapping: dict[str, str] = {}
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
start_node_id=root_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping
)
# init graph
graph = cls(
root_node_id=root_node_id,
@ -146,7 +180,9 @@ class Graph(BaseModel):
edge_mapping=edge_mapping,
run_state=GraphState(
variable_pool=variable_pool
)
),
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping
)
return graph
@ -173,12 +209,48 @@ class Graph(BaseModel):
edge_mapping=edge_mapping,
node_id=graph_edge.target_node_id
)
def next_node_ids(self) -> list[NextGraphNode]:
    """
    Get next node ids

    Returns the root node when no route has been recorded yet. Otherwise it
    collects the outgoing edges of every node recorded in the root node's
    routes, keeps the edges whose target has no entry in the routes yet,
    drops the edges whose run condition check fails, and attaches the
    parallel of each remaining target node when it is mapped to one.

    :return: next graph nodes, possibly empty
    """
    run_state = self.run_state

    # nothing has been routed yet: start from the root node
    if not run_state.routes:
        return [NextGraphNode(node_id=self.root_node_id)]

    # edges leaving routed nodes whose target has not been routed yet;
    # .get() keeps this from raising KeyError when routes holds entries
    # but none of them is keyed by the root node
    route_final_graph_edges: list[GraphEdge] = []
    for route in run_state.routes.get(self.root_node_id, []):
        graph_edges = self.edge_mapping.get(route.node_id)
        if not graph_edges:
            continue

        route_final_graph_edges.extend(
            edge for edge in graph_edges
            if edge.target_node_id not in run_state.routes
        )

    next_graph_nodes: list[NextGraphNode] = []
    for route_final_graph_edge in route_final_graph_edges:
        node_id = route_final_graph_edge.target_node_id

        # check condition: an edge without a run condition is always followed
        if route_final_graph_edge.run_condition:
            result = ConditionManager.get_condition_handler(
                run_condition=route_final_graph_edge.run_condition
            ).check(
                source_node_id=route_final_graph_edge.source_node_id,
                target_node_id=node_id,
                graph=self
            )

            if not result:
                continue

        parallel = None
        if node_id in self.node_parallel_mapping:
            parallel = self.parallel_mapping[self.node_parallel_mapping[node_id]]

        next_graph_nodes.append(NextGraphNode(node_id=node_id, parallel=parallel))

    return next_graph_nodes
def add_extra_edge(self, source_node_id: str,
target_node_id: str,
@ -222,3 +294,163 @@ class Graph(BaseModel):
leaf_node_ids.append(node_id)
return leaf_node_ids
@classmethod
def _recursively_add_parallels(cls,
                               edge_mapping: dict[str, list[GraphEdge]],
                               start_node_id: str,
                               parallel_mapping: dict[str, GraphParallel],
                               node_parallel_mapping: dict[str, str]) -> None:
    """
    Recursively add parallel ids

    Walks the graph from start_node_id along every outgoing edge. When a node
    has more than one outgoing edge and at least one of them qualifies as a
    parallel branch head, a new GraphParallel is registered and every node id
    found inside its branches is mapped to it. Both mappings are filled in
    place.

    :param edge_mapping: edge mapping
    :param start_node_id: start from node id
    :param parallel_mapping: parallel mapping
    :param node_parallel_mapping: node parallel mapping
    """
    target_node_edges = edge_mapping.get(start_node_id, [])
    if len(target_node_edges) > 1:
        # fetch all node ids in current parallels
        # NOTE(review): only edges that carry a run condition are treated as
        # parallel branch heads here - confirm this filter is intended
        parallel_node_ids = [graph_edge.target_node_id
                             for graph_edge in target_node_edges
                             if graph_edge.run_condition is not None]

        if parallel_node_ids:
            # when every branch head already belongs to a parallel, the new
            # parallel is nested inside the one of the first branch head
            parent_parallel_id = None
            if all(node_id in node_parallel_mapping for node_id in parallel_node_ids):
                parent_parallel_id = node_parallel_mapping[parallel_node_ids[0]]

            parallel = GraphParallel(parent_parallel_id=parent_parallel_id)
            parallel_mapping[parallel.id] = parallel

            in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
                edge_mapping=edge_mapping,
                parallel_node_ids=parallel_node_ids
            )

            # map every node of every branch, not only the branch heads
            # (the dict keys); each branch list already starts with its head
            for branch_node_ids in in_branch_node_ids.values():
                for node_id in branch_node_ids:
                    node_parallel_mapping[node_id] = parallel.id

    for graph_edge in target_node_edges:
        cls._recursively_add_parallels(
            edge_mapping=edge_mapping,
            start_node_id=graph_edge.target_node_id,
            parallel_mapping=parallel_mapping,
            node_parallel_mapping=node_parallel_mapping
        )
@classmethod
def _recursively_add_parallel_node_ids(cls,
                                       branch_node_ids: list[str],
                                       edge_mapping: dict[str, list[GraphEdge]],
                                       merge_node_id: str,
                                       start_node_id: str) -> None:
    """
    Recursively add node ids

    Follows the outgoing edges of start_node_id and appends every node id
    met on the way to branch_node_ids (in place), never stepping onto the
    merge node and never adding a node id twice.

    :param branch_node_ids: in branch node ids
    :param edge_mapping: edge mapping
    :param merge_node_id: merge node id
    :param start_node_id: start node id
    """
    for outgoing_edge in edge_mapping.get(start_node_id, []):
        next_node_id = outgoing_edge.target_node_id

        # stop at the merge node and skip nodes that were already collected
        if next_node_id == merge_node_id or next_node_id in branch_node_ids:
            continue

        branch_node_ids.append(next_node_id)
        cls._recursively_add_parallel_node_ids(
            branch_node_ids=branch_node_ids,
            edge_mapping=edge_mapping,
            merge_node_id=merge_node_id,
            start_node_id=next_node_id
        )
@classmethod
def _fetch_all_node_ids_in_parallels(cls,
                                     edge_mapping: dict[str, list[GraphEdge]],
                                     parallel_node_ids: list[str]) -> dict[str, list[str]]:
    """
    Fetch all node ids in parallels

    For every branch head in parallel_node_ids, works out which node ids
    belong to that branch. A branch that shares a downstream node with
    another branch is cut off at that shared (merge) node; a branch that
    shares nothing keeps every node reachable from its head.

    :param edge_mapping: edge mapping
    :param parallel_node_ids: node ids of the branch heads
    :return: mapping of branch head node id to the node ids in that branch,
        each list starting with the branch head itself
    """
    # node ids reachable from each branch head (the head itself not included)
    routes_node_ids: dict[str, list[str]] = {}
    for parallel_node_id in parallel_node_ids:
        routes_node_ids[parallel_node_id] = []

        # fetch routes node ids
        cls._recursively_fetch_routes(
            edge_mapping=edge_mapping,
            start_node_id=parallel_node_id,
            routes_node_ids=routes_node_ids[parallel_node_id]
        )

    # node id -> heads of the other branches that can reach the same node
    merge_branch_node_ids: dict[str, list[str]] = {}
    for branch_node_id, node_ids in routes_node_ids.items():
        for node_id in node_ids:
            for branch_node_id2, inner_route2 in routes_node_ids.items():
                if branch_node_id != branch_node_id2 and node_id in inner_route2:
                    merge_branch_node_ids.setdefault(node_id, []).append(branch_node_id2)

    # sorted merge_branch_node_ids by branch_node_ids length desc, so the
    # candidates with the longest list of sharing branches are assigned first
    merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))

    # branch head -> its merge node; the first candidate assigned wins
    branches_merge_node_ids: dict[str, str] = {}
    for node_id, branch_node_ids in merge_branch_node_ids.items():
        if len(branch_node_ids) <= 1:
            continue

        for branch_node_id in branch_node_ids:
            branches_merge_node_ids.setdefault(branch_node_id, node_id)

    in_branch_node_ids: dict[str, list[str]] = {}
    for branch_node_id, node_ids in routes_node_ids.items():
        in_branch_node_ids[branch_node_id] = [branch_node_id]
        if branch_node_id not in branches_merge_node_ids:
            # all node ids in current branch is in this thread
            in_branch_node_ids[branch_node_id].extend(node_ids)
        else:
            merge_node_id = branches_merge_node_ids[branch_node_id]

            # fetch all node ids from branch_node_id and merge_node_id
            cls._recursively_add_parallel_node_ids(
                branch_node_ids=in_branch_node_ids[branch_node_id],
                edge_mapping=edge_mapping,
                merge_node_id=merge_node_id,
                start_node_id=branch_node_id
            )

    return in_branch_node_ids
@classmethod
def _recursively_fetch_routes(cls,
                              edge_mapping: dict[str, list[GraphEdge]],
                              start_node_id: str,
                              routes_node_ids: list[str]) -> None:
    """
    Recursively fetch route

    Appends to routes_node_ids (in place) every node id reachable from
    start_node_id through edge_mapping, each one once and in the order it is
    first met. The start node itself is not added.

    :param edge_mapping: edge mapping
    :param start_node_id: node id to start from
    :param routes_node_ids: list collecting the reachable node ids
    """
    # a node without an entry in the edge mapping has nothing to follow
    for outgoing_edge in edge_mapping.get(start_node_id, ()):
        next_node_id = outgoing_edge.target_node_id
        if next_node_id in routes_node_ids:
            # already collected through another path
            continue

        routes_node_ids.append(next_node_id)
        cls._recursively_fetch_routes(
            edge_mapping=edge_mapping,
            start_node_id=next_node_id,
            routes_node_ids=routes_node_ids
        )