mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 05:25:54 +08:00
refactor graph
This commit is contained in:
parent
1b6cd975f3
commit
03f56a05eb
@ -1,218 +1,224 @@
|
|||||||
from typing import Optional
|
from typing import Optional, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||||
|
|
||||||
|
|
||||||
class GraphNode(BaseModel):
|
class GraphEdge(BaseModel):
|
||||||
id: str
|
source_node_id: str
|
||||||
"""node id"""
|
"""source node id"""
|
||||||
|
|
||||||
parent_id: Optional[str] = None
|
target_node_id: str
|
||||||
"""parent node id, e.g. iteration/loop"""
|
"""target node id"""
|
||||||
|
|
||||||
predecessor_node_id: Optional[str] = None
|
|
||||||
"""predecessor node id"""
|
|
||||||
|
|
||||||
descendant_node_ids: list[str] = []
|
|
||||||
"""descendant node ids"""
|
|
||||||
|
|
||||||
run_condition: Optional[RunCondition] = None
|
run_condition: Optional[RunCondition] = None
|
||||||
"""condition to run the node"""
|
"""condition to run the edge"""
|
||||||
|
|
||||||
node_config: dict
|
|
||||||
"""original node config"""
|
|
||||||
|
|
||||||
source_edge_config: Optional[dict] = None
|
class GraphStateRoute(BaseModel):
|
||||||
"""original source edge config"""
|
route_id: str
|
||||||
|
"""route id"""
|
||||||
|
|
||||||
sub_graph: Optional["Graph"] = None
|
node_id: str
|
||||||
"""sub graph of the node, e.g. iteration/loop sub graph"""
|
"""node id"""
|
||||||
|
|
||||||
def add_child(self, node_id: str) -> None:
|
|
||||||
if node_id not in self.descendant_node_ids:
|
|
||||||
self.descendant_node_ids.append(node_id)
|
|
||||||
|
|
||||||
def get_run_condition_handler(self) -> Optional[RunConditionHandler]:
|
class GraphState(BaseModel):
|
||||||
"""
|
routes: dict[str, list[GraphStateRoute]] = Field(default_factory=dict)
|
||||||
Get run condition handler
|
"""graph state routes (route_id: run_result)"""
|
||||||
|
|
||||||
:return: run condition handler
|
variable_pool: VariablePool
|
||||||
"""
|
"""variable pool"""
|
||||||
if not self.run_condition:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return ConditionManager.get_condition_handler(
|
node_route_results: dict[str, NodeRunResult] = Field(default_factory=dict)
|
||||||
run_condition=self.run_condition
|
"""node results in route (route_id: run_result)"""
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Graph(BaseModel):
|
class Graph(BaseModel):
|
||||||
graph_nodes: dict[str, GraphNode] = Field(default_factory=dict)
|
root_node_id: str
|
||||||
"""graph nodes"""
|
"""root node id of the graph"""
|
||||||
|
|
||||||
root_node: GraphNode
|
node_ids: list[str] = Field(default_factory=list)
|
||||||
"""root node of the graph"""
|
"""graph node ids"""
|
||||||
|
|
||||||
@model_validator(mode='after')
|
edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict)
|
||||||
def add_root_node(cls, values):
|
"""graph edge mapping"""
|
||||||
root_node = values.root_node
|
|
||||||
values.graph_nodes[root_node.id] = root_node
|
run_state: GraphState
|
||||||
return values
|
"""graph run state"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init(cls, root_node_config: dict, run_condition: Optional[RunCondition] = None) -> "Graph":
|
def init(cls,
|
||||||
|
graph_config: dict,
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
root_node_id: Optional[str] = None) -> "Graph":
|
||||||
"""
|
"""
|
||||||
Init graph
|
Init graph
|
||||||
|
|
||||||
:param root_node_config: root node config
|
:param graph_config: graph config
|
||||||
:param run_condition: run condition when root node parent is iteration/loop
|
:param variable_pool: variable pool
|
||||||
|
:param root_node_id: root node id
|
||||||
:return: graph
|
:return: graph
|
||||||
"""
|
"""
|
||||||
node_id = root_node_config.get('id')
|
# edge configs
|
||||||
if not node_id:
|
edge_configs = graph_config.get('edges')
|
||||||
raise ValueError("Graph root node id is required")
|
if edge_configs is None:
|
||||||
|
edge_configs = []
|
||||||
|
|
||||||
root_node = GraphNode(
|
edge_configs = cast(list, edge_configs)
|
||||||
id=node_id,
|
|
||||||
parent_id=root_node_config.get('parentId'),
|
# reorganize edges mapping
|
||||||
node_config=root_node_config,
|
edge_mapping: dict[str, list[GraphEdge]] = {}
|
||||||
|
target_edge_ids = set()
|
||||||
|
for edge_config in edge_configs:
|
||||||
|
source_node_id = edge_config.get('source')
|
||||||
|
if not source_node_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if source_node_id not in edge_mapping:
|
||||||
|
edge_mapping[source_node_id] = []
|
||||||
|
|
||||||
|
target_node_id = edge_config.get('target')
|
||||||
|
if not target_node_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
target_edge_ids.add(target_node_id)
|
||||||
|
|
||||||
|
# parse run condition
|
||||||
|
run_condition = None
|
||||||
|
if edge_config.get('sourceHandle'):
|
||||||
|
run_condition = RunCondition(
|
||||||
|
type='branch_identify',
|
||||||
|
branch_identify=edge_config.get('sourceHandle')
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_edge = GraphEdge(
|
||||||
|
source_node_id=source_node_id,
|
||||||
|
target_node_id=edge_config.get('target'),
|
||||||
|
run_condition=run_condition
|
||||||
|
)
|
||||||
|
|
||||||
|
edge_mapping[source_node_id].append(graph_edge)
|
||||||
|
|
||||||
|
# node configs
|
||||||
|
node_configs = graph_config.get('nodes')
|
||||||
|
if not node_configs:
|
||||||
|
raise ValueError("Graph must have at least one node")
|
||||||
|
|
||||||
|
node_configs = cast(list, node_configs)
|
||||||
|
|
||||||
|
# fetch nodes that have no predecessor node
|
||||||
|
root_node_configs = []
|
||||||
|
for node_config in node_configs:
|
||||||
|
node_id = node_config.get('id')
|
||||||
|
if not node_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if node_id not in target_edge_ids:
|
||||||
|
root_node_configs.append(node_config)
|
||||||
|
|
||||||
|
root_node_ids = [node_config.get('id') for node_config in root_node_configs]
|
||||||
|
|
||||||
|
# fetch root node
|
||||||
|
if not root_node_id:
|
||||||
|
# if no root node id, use the START type node as root node
|
||||||
|
root_node_id = next((node_config for node_config in root_node_configs
|
||||||
|
if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)
|
||||||
|
|
||||||
|
if not root_node_id or root_node_id not in root_node_ids:
|
||||||
|
raise ValueError(f"Root node id {root_node_id} not found in the graph")
|
||||||
|
|
||||||
|
# fetch all node ids from root node
|
||||||
|
node_ids = [root_node_id]
|
||||||
|
cls._recursively_add_node_ids(
|
||||||
|
node_ids=node_ids,
|
||||||
|
edge_mapping=edge_mapping,
|
||||||
|
node_id=root_node_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# init graph
|
||||||
|
graph = cls(
|
||||||
|
root_node_id=root_node_id,
|
||||||
|
node_ids=node_ids,
|
||||||
|
edge_mapping=edge_mapping,
|
||||||
|
run_state=GraphState(
|
||||||
|
variable_pool=variable_pool
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _recursively_add_node_ids(cls,
|
||||||
|
node_ids: list[str],
|
||||||
|
edge_mapping: dict[str, list[GraphEdge]],
|
||||||
|
node_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Recursively add node ids
|
||||||
|
|
||||||
|
:param node_ids: node ids
|
||||||
|
:param edge_mapping: edge mapping
|
||||||
|
:param node_id: node id
|
||||||
|
"""
|
||||||
|
for graph_edge in edge_mapping.get(node_id, []):
|
||||||
|
if graph_edge.target_node_id in node_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
node_ids.append(graph_edge.target_node_id)
|
||||||
|
cls._recursively_add_node_ids(
|
||||||
|
node_ids=node_ids,
|
||||||
|
edge_mapping=edge_mapping,
|
||||||
|
node_id=graph_edge.target_node_id
|
||||||
|
)
|
||||||
|
def next_node_ids(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
Get next node ids
|
||||||
|
"""
|
||||||
|
# todo
|
||||||
|
return []
|
||||||
|
|
||||||
|
def add_extra_edge(self, source_node_id: str,
|
||||||
|
target_node_id: str,
|
||||||
|
run_condition: Optional[RunCondition] = None) -> None:
|
||||||
|
"""
|
||||||
|
Add extra edge to the graph
|
||||||
|
|
||||||
|
:param source_node_id: source node id
|
||||||
|
:param target_node_id: target node id
|
||||||
|
:param run_condition: run condition
|
||||||
|
"""
|
||||||
|
if source_node_id not in self.node_ids or target_node_id not in self.node_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
if source_node_id not in self.edge_mapping:
|
||||||
|
self.edge_mapping[source_node_id] = []
|
||||||
|
|
||||||
|
if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]:
|
||||||
|
return
|
||||||
|
|
||||||
|
graph_edge = GraphEdge(
|
||||||
|
source_node_id=source_node_id,
|
||||||
|
target_node_id=target_node_id,
|
||||||
run_condition=run_condition
|
run_condition=run_condition
|
||||||
)
|
)
|
||||||
|
|
||||||
return cls(root_node=root_node)
|
self.edge_mapping[source_node_id].append(graph_edge)
|
||||||
|
|
||||||
def add_edge(self, edge_config: dict,
|
def get_leaf_node_ids(self) -> list[str]:
|
||||||
source_node_config: dict,
|
|
||||||
target_node_config: dict,
|
|
||||||
target_node_sub_graph: Optional["Graph"] = None,
|
|
||||||
run_condition: Optional[RunCondition] = None) -> None:
|
|
||||||
"""
|
"""
|
||||||
Add edge to the graph
|
Get leaf node ids of the graph
|
||||||
|
|
||||||
:param edge_config: edge config
|
:return: leaf node ids
|
||||||
:param source_node_config: source node config
|
|
||||||
:param target_node_config: target node config
|
|
||||||
:param target_node_sub_graph: sub graph
|
|
||||||
:param run_condition: run condition
|
|
||||||
"""
|
"""
|
||||||
source_node_id = source_node_config.get('id')
|
leaf_node_ids = []
|
||||||
if not source_node_id:
|
for node_id in self.node_ids:
|
||||||
return
|
if node_id not in self.edge_mapping:
|
||||||
|
leaf_node_ids.append(node_id)
|
||||||
|
elif (len(self.edge_mapping[node_id]) == 1
|
||||||
|
and self.edge_mapping[node_id][0].target_node_id == self.root_node_id):
|
||||||
|
leaf_node_ids.append(node_id)
|
||||||
|
|
||||||
if source_node_id not in self.graph_nodes:
|
return leaf_node_ids
|
||||||
return
|
|
||||||
|
|
||||||
target_node_id = target_node_config.get('id')
|
|
||||||
if not target_node_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
source_node = self.graph_nodes.get(source_node_id)
|
|
||||||
if not source_node:
|
|
||||||
return
|
|
||||||
|
|
||||||
source_node.add_child(target_node_id)
|
|
||||||
|
|
||||||
if target_node_id not in self.graph_nodes:
|
|
||||||
target_graph_node = GraphNode(
|
|
||||||
id=target_node_id,
|
|
||||||
parent_id=source_node_config.get('parentId'),
|
|
||||||
predecessor_node_id=source_node_id,
|
|
||||||
node_config=target_node_config,
|
|
||||||
run_condition=run_condition,
|
|
||||||
source_edge_config=edge_config,
|
|
||||||
sub_graph=target_node_sub_graph
|
|
||||||
)
|
|
||||||
|
|
||||||
self.add_graph_node(target_graph_node)
|
|
||||||
else:
|
|
||||||
target_node = self.graph_nodes.get(target_node_id)
|
|
||||||
if not target_node:
|
|
||||||
return
|
|
||||||
|
|
||||||
target_node.predecessor_node_id = source_node_id
|
|
||||||
target_node.run_condition = run_condition
|
|
||||||
target_node.source_edge_config = edge_config
|
|
||||||
target_node.sub_graph = target_node_sub_graph
|
|
||||||
|
|
||||||
def get_leaf_nodes(self) -> list[GraphNode]:
|
|
||||||
"""
|
|
||||||
Get leaf nodes of the graph
|
|
||||||
|
|
||||||
:return: leaf nodes
|
|
||||||
"""
|
|
||||||
leaf_nodes = []
|
|
||||||
for node_id, graph_node in self.graph_nodes.items():
|
|
||||||
if (
|
|
||||||
not graph_node.descendant_node_ids # has no child
|
|
||||||
or # or has only one child and the child is the root node
|
|
||||||
(
|
|
||||||
graph_node.descendant_node_ids
|
|
||||||
and graph_node.descendant_node_ids[0] == self.root_node.id
|
|
||||||
)
|
|
||||||
):
|
|
||||||
leaf_nodes.append(graph_node)
|
|
||||||
|
|
||||||
return leaf_nodes
|
|
||||||
|
|
||||||
def get_descendant_graphs(self, node_id: str) -> list["Graph"]:
|
|
||||||
"""
|
|
||||||
Get descendant graphs of the specific node
|
|
||||||
|
|
||||||
:param node_id: node id
|
|
||||||
:return: descendant graphs
|
|
||||||
"""
|
|
||||||
if node_id not in self.graph_nodes:
|
|
||||||
return []
|
|
||||||
|
|
||||||
graph_node = self.graph_nodes.get(node_id)
|
|
||||||
if not graph_node or not graph_node.descendant_node_ids:
|
|
||||||
return []
|
|
||||||
|
|
||||||
descendant_graphs: list[Graph] = []
|
|
||||||
for descendant_node_id in graph_node.descendant_node_ids:
|
|
||||||
descendant_graph_node = self.graph_nodes.get(descendant_node_id)
|
|
||||||
if not descendant_graph_node:
|
|
||||||
continue
|
|
||||||
|
|
||||||
descendants_graph = Graph(root_node=descendant_graph_node)
|
|
||||||
for sub_descendant_node_id in descendant_graph_node.descendant_node_ids:
|
|
||||||
descendants_graph.add_descendants_graph_nodes(self, sub_descendant_node_id)
|
|
||||||
|
|
||||||
descendant_graphs.append(descendants_graph)
|
|
||||||
|
|
||||||
return descendant_graphs
|
|
||||||
|
|
||||||
def add_graph_node(self, graph_node: GraphNode) -> None:
|
|
||||||
"""
|
|
||||||
Add graph node to the graph
|
|
||||||
|
|
||||||
:param graph_node: graph node
|
|
||||||
"""
|
|
||||||
if graph_node.id in self.graph_nodes:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.graph_nodes[graph_node.id] = graph_node
|
|
||||||
|
|
||||||
def add_descendants_graph_nodes(self, predecessor_graph: "Graph", node_id: str) -> None:
|
|
||||||
"""
|
|
||||||
Add descendants graph nodes
|
|
||||||
|
|
||||||
:param predecessor_graph: predecessor graph
|
|
||||||
:param node_id: node id
|
|
||||||
"""
|
|
||||||
if node_id not in predecessor_graph.graph_nodes:
|
|
||||||
return
|
|
||||||
|
|
||||||
graph_node = predecessor_graph.graph_nodes.get(node_id)
|
|
||||||
if not graph_node:
|
|
||||||
return
|
|
||||||
|
|
||||||
if graph_node.id not in self.graph_nodes:
|
|
||||||
self.add_graph_node(graph_node)
|
|
||||||
|
|
||||||
for child_node_id in graph_node.descendant_node_ids:
|
|
||||||
self.add_descendants_graph_nodes(predecessor_graph, child_node_id)
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user