mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 05:55:59 +08:00
refactor graph
This commit is contained in:
parent
fed068ac2e
commit
1adaf42f9d
@ -1,6 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
|
||||
|
||||
@ -10,15 +9,15 @@ class RunConditionHandler(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def check(self,
|
||||
graph_node: "GraphNode",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
predecessor_node_result: NodeRunResult) -> bool:
|
||||
source_node_id: str,
|
||||
target_node_id: str,
|
||||
graph: "Graph") -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_node: graph node
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param predecessor_node_result: predecessor node result
|
||||
:param source_node_id: source node id
|
||||
:param target_node_id: target node id
|
||||
:param graph: graph
|
||||
:return: bool
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
@ -1,25 +1,29 @@
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
|
||||
|
||||
class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
||||
|
||||
def check(self,
|
||||
graph_node: "GraphNode",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
predecessor_node_result: NodeRunResult) -> bool:
|
||||
source_node_id: str,
|
||||
target_node_id: str,
|
||||
graph: "Graph") -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_node: graph node
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param predecessor_node_result: predecessor node result
|
||||
:param source_node_id: source node id
|
||||
:param target_node_id: target node id
|
||||
:param graph: graph
|
||||
:return: bool
|
||||
"""
|
||||
if not self.condition.branch_identify:
|
||||
raise Exception("Branch identify is required")
|
||||
|
||||
if not predecessor_node_result.edge_source_handle:
|
||||
run_state = graph.run_state
|
||||
node_route_result = run_state.node_route_results.get(source_node_id)
|
||||
if not node_route_result:
|
||||
return False
|
||||
|
||||
return self.condition.branch_identify == predecessor_node_result.edge_source_handle
|
||||
if not node_route_result.edge_source_handle:
|
||||
return False
|
||||
|
||||
return self.condition.branch_identify == node_route_result.edge_source_handle
|
||||
|
@ -1,19 +1,18 @@
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
|
||||
|
||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||
def check(self,
|
||||
graph_node: "GraphNode",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
predecessor_node_result: NodeRunResult) -> bool:
|
||||
source_node_id: str,
|
||||
target_node_id: str,
|
||||
graph: "Graph") -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_node: graph node
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param predecessor_node_result: predecessor node result
|
||||
:param source_node_id: source node id
|
||||
:param target_node_id: target node id
|
||||
:param graph: graph
|
||||
:return: bool
|
||||
"""
|
||||
if not self.condition.conditions:
|
||||
@ -22,7 +21,7 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||
# process condition
|
||||
condition_processor = ConditionProcessor()
|
||||
compare_result, _ = condition_processor.process(
|
||||
variable_pool=graph_runtime_state.variable_pool,
|
||||
variable_pool=graph.run_state.variable_pool,
|
||||
logical_operator="and",
|
||||
conditions=self.condition.conditions
|
||||
)
|
||||
|
@ -1,9 +1,11 @@
|
||||
import uuid
|
||||
from typing import Optional, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
|
||||
|
||||
@ -18,6 +20,14 @@ class GraphEdge(BaseModel):
|
||||
"""condition to run the edge"""
|
||||
|
||||
|
||||
class GraphParallel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""random uuid parallel id"""
|
||||
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if exists"""
|
||||
|
||||
|
||||
class GraphStateRoute(BaseModel):
|
||||
route_id: str
|
||||
"""route id"""
|
||||
@ -28,13 +38,21 @@ class GraphStateRoute(BaseModel):
|
||||
|
||||
class GraphState(BaseModel):
|
||||
routes: dict[str, list[GraphStateRoute]] = Field(default_factory=dict)
|
||||
"""graph state routes (route_id: run_result)"""
|
||||
"""graph state routes (source_node_id: routes)"""
|
||||
|
||||
variable_pool: VariablePool
|
||||
"""variable pool"""
|
||||
|
||||
node_route_results: dict[str, NodeRunResult] = Field(default_factory=dict)
|
||||
"""node results in route (route_id: run_result)"""
|
||||
"""node results in route (node_id: run_result)"""
|
||||
|
||||
|
||||
class NextGraphNode(BaseModel):
|
||||
node_id: str
|
||||
"""next node id"""
|
||||
|
||||
parallel: Optional[GraphParallel] = None
|
||||
"""parallel"""
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
@ -45,7 +63,13 @@ class Graph(BaseModel):
|
||||
"""graph node ids"""
|
||||
|
||||
edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict)
|
||||
"""graph edge mapping"""
|
||||
"""graph edge mapping (source node id: edges)"""
|
||||
|
||||
parallel_mapping: dict[str, GraphParallel] = Field(default_factory=dict)
|
||||
"""graph parallel mapping (parallel id: parallel)"""
|
||||
|
||||
node_parallel_mapping: dict[str, str] = Field(default_factory=dict)
|
||||
"""graph node parallel mapping (node id: parallel id)"""
|
||||
|
||||
run_state: GraphState
|
||||
"""graph run state"""
|
||||
@ -139,6 +163,16 @@ class Graph(BaseModel):
|
||||
node_id=root_node_id
|
||||
)
|
||||
|
||||
# init parallel mapping
|
||||
parallel_mapping: dict[str, GraphParallel] = {}
|
||||
node_parallel_mapping: dict[str, str] = {}
|
||||
cls._recursively_add_parallels(
|
||||
edge_mapping=edge_mapping,
|
||||
start_node_id=root_node_id,
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = cls(
|
||||
root_node_id=root_node_id,
|
||||
@ -146,7 +180,9 @@ class Graph(BaseModel):
|
||||
edge_mapping=edge_mapping,
|
||||
run_state=GraphState(
|
||||
variable_pool=variable_pool
|
||||
)
|
||||
),
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping
|
||||
)
|
||||
|
||||
return graph
|
||||
@ -173,12 +209,48 @@ class Graph(BaseModel):
|
||||
edge_mapping=edge_mapping,
|
||||
node_id=graph_edge.target_node_id
|
||||
)
|
||||
def next_node_ids(self) -> list[str]:
|
||||
|
||||
def next_node_ids(self) -> list[NextGraphNode]:
|
||||
"""
|
||||
Get next node ids
|
||||
"""
|
||||
# todo
|
||||
return []
|
||||
# get current node ids in state
|
||||
if not self.run_state.routes:
|
||||
return [NextGraphNode(node_id=self.root_node_id)]
|
||||
|
||||
route_final_graph_edges: list[GraphEdge] = []
|
||||
for route in self.run_state.routes[self.root_node_id]:
|
||||
graph_edges = self.edge_mapping.get(route.node_id)
|
||||
if not graph_edges:
|
||||
continue
|
||||
|
||||
for edge in graph_edges:
|
||||
if edge.target_node_id not in self.run_state.routes:
|
||||
route_final_graph_edges.append(edge)
|
||||
|
||||
next_graph_nodes = []
|
||||
for route_final_graph_edge in route_final_graph_edges:
|
||||
node_id = route_final_graph_edge.target_node_id
|
||||
# check condition
|
||||
if route_final_graph_edge.run_condition:
|
||||
result = ConditionManager.get_condition_handler(
|
||||
run_condition=route_final_graph_edge.run_condition
|
||||
).check(
|
||||
source_node_id=route_final_graph_edge.source_node_id,
|
||||
target_node_id=route_final_graph_edge.target_node_id,
|
||||
graph=self
|
||||
)
|
||||
|
||||
if not result:
|
||||
continue
|
||||
|
||||
parallel = None
|
||||
if route_final_graph_edge.target_node_id in self.node_parallel_mapping:
|
||||
parallel = self.parallel_mapping[self.node_parallel_mapping[node_id]]
|
||||
|
||||
next_graph_nodes.append(NextGraphNode(node_id=node_id, parallel=parallel))
|
||||
|
||||
return next_graph_nodes
|
||||
|
||||
def add_extra_edge(self, source_node_id: str,
|
||||
target_node_id: str,
|
||||
@ -222,3 +294,163 @@ class Graph(BaseModel):
|
||||
leaf_node_ids.append(node_id)
|
||||
|
||||
return leaf_node_ids
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_parallels(cls,
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
start_node_id: str,
|
||||
parallel_mapping: dict[str, GraphParallel],
|
||||
node_parallel_mapping: dict[str, str]) -> None:
|
||||
"""
|
||||
Recursively add parallel ids
|
||||
|
||||
:param edge_mapping: edge mapping
|
||||
:param start_node_id: start from node id
|
||||
:param parallel_mapping: parallel mapping
|
||||
:param node_parallel_mapping: node parallel mapping
|
||||
"""
|
||||
target_node_edges = edge_mapping.get(start_node_id, [])
|
||||
if len(target_node_edges) > 1:
|
||||
# fetch all node ids in current parallels
|
||||
parallel_node_ids = [graph_edge.target_node_id
|
||||
for graph_edge in target_node_edges if graph_edge.run_condition is not None]
|
||||
|
||||
# any target node id in node_parallel_mapping
|
||||
if parallel_node_ids:
|
||||
# all parallel_node_ids in node_parallel_mapping
|
||||
parent_parallel_id = None
|
||||
if all(node_id in node_parallel_mapping for node_id in parallel_node_ids):
|
||||
parent_parallel_id = node_parallel_mapping[parallel_node_ids[0]]
|
||||
|
||||
parallel = GraphParallel(parent_parallel_id=parent_parallel_id)
|
||||
parallel_mapping[parallel.id] = parallel
|
||||
|
||||
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
|
||||
edge_mapping=edge_mapping,
|
||||
parallel_node_ids=parallel_node_ids
|
||||
)
|
||||
|
||||
node_parallel_mapping.update({node_id: parallel.id for node_id in in_branch_node_ids})
|
||||
|
||||
for graph_edge in target_node_edges:
|
||||
cls._recursively_add_parallels(
|
||||
edge_mapping=edge_mapping,
|
||||
start_node_id=graph_edge.target_node_id,
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_parallel_node_ids(cls,
|
||||
branch_node_ids: list[str],
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
merge_node_id: str,
|
||||
start_node_id: str) -> None:
|
||||
"""
|
||||
Recursively add node ids
|
||||
|
||||
:param branch_node_ids: in branch node ids
|
||||
:param edge_mapping: edge mapping
|
||||
:param merge_node_id: merge node id
|
||||
:param start_node_id: start node id
|
||||
"""
|
||||
for graph_edge in edge_mapping.get(start_node_id, []):
|
||||
if (graph_edge.target_node_id != merge_node_id
|
||||
and graph_edge.target_node_id not in branch_node_ids):
|
||||
branch_node_ids.append(graph_edge.target_node_id)
|
||||
cls._recursively_add_parallel_node_ids(
|
||||
branch_node_ids=branch_node_ids,
|
||||
edge_mapping=edge_mapping,
|
||||
merge_node_id=merge_node_id,
|
||||
start_node_id=graph_edge.target_node_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _fetch_all_node_ids_in_parallels(cls,
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
parallel_node_ids: list[str]) -> dict[str, list[str]]:
|
||||
"""
|
||||
Fetch all node ids in parallels
|
||||
"""
|
||||
routes_node_ids: dict[str, list[str]] = {}
|
||||
for parallel_node_id in parallel_node_ids:
|
||||
routes_node_ids[parallel_node_id] = []
|
||||
|
||||
# fetch routes node ids
|
||||
cls._recursively_fetch_routes(
|
||||
edge_mapping=edge_mapping,
|
||||
start_node_id=parallel_node_id,
|
||||
routes_node_ids=routes_node_ids[parallel_node_id]
|
||||
)
|
||||
|
||||
# fetch leaf node ids from routes node ids
|
||||
leaf_node_ids: dict[str, list[str]] = {}
|
||||
merge_branch_node_ids: dict[str, list[str]] = {}
|
||||
for branch_node_id, node_ids in routes_node_ids.items():
|
||||
for node_id in node_ids:
|
||||
if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0:
|
||||
if branch_node_id not in leaf_node_ids:
|
||||
leaf_node_ids[branch_node_id] = []
|
||||
|
||||
leaf_node_ids[branch_node_id].append(node_id)
|
||||
|
||||
for branch_node_id2, inner_route2 in routes_node_ids.items():
|
||||
if branch_node_id != branch_node_id2 and node_id in inner_route2:
|
||||
if node_id not in merge_branch_node_ids:
|
||||
merge_branch_node_ids[node_id] = []
|
||||
|
||||
merge_branch_node_ids[node_id].append(branch_node_id2)
|
||||
|
||||
# sorted merge_branch_node_ids by branch_node_ids length desc
|
||||
merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))
|
||||
|
||||
branches_merge_node_ids: dict[str, str] = {}
|
||||
for node_id, branch_node_ids in merge_branch_node_ids.items():
|
||||
if len(branch_node_ids) <= 1:
|
||||
continue
|
||||
|
||||
for branch_node_id in branch_node_ids:
|
||||
if branch_node_id in branches_merge_node_ids:
|
||||
continue
|
||||
|
||||
branches_merge_node_ids[branch_node_id] = node_id
|
||||
|
||||
in_branch_node_ids: dict[str, list[str]] = {}
|
||||
for branch_node_id, node_ids in routes_node_ids.items():
|
||||
in_branch_node_ids[branch_node_id] = [branch_node_id]
|
||||
if branch_node_id not in branches_merge_node_ids:
|
||||
# all node ids in current branch is in this thread
|
||||
in_branch_node_ids[branch_node_id].extend(node_ids)
|
||||
else:
|
||||
merge_node_id = branches_merge_node_ids[branch_node_id]
|
||||
# fetch all node ids from branch_node_id and merge_node_id
|
||||
cls._recursively_add_parallel_node_ids(
|
||||
branch_node_ids=in_branch_node_ids[branch_node_id],
|
||||
edge_mapping=edge_mapping,
|
||||
merge_node_id=merge_node_id,
|
||||
start_node_id=branch_node_id
|
||||
)
|
||||
|
||||
return in_branch_node_ids
|
||||
|
||||
@classmethod
|
||||
def _recursively_fetch_routes(cls,
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
start_node_id: str,
|
||||
routes_node_ids: list[str]) -> None:
|
||||
"""
|
||||
Recursively fetch route
|
||||
"""
|
||||
if start_node_id not in edge_mapping:
|
||||
return
|
||||
|
||||
for graph_edge in edge_mapping[start_node_id]:
|
||||
# find next node ids
|
||||
if graph_edge.target_node_id not in routes_node_ids:
|
||||
routes_node_ids.append(graph_edge.target_node_id)
|
||||
|
||||
cls._recursively_fetch_routes(
|
||||
edge_mapping=edge_mapping,
|
||||
start_node_id=graph_edge.target_node_id,
|
||||
routes_node_ids=routes_node_ids
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user