mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-19 02:15:54 +08:00
refactor graph
This commit is contained in:
parent
fed068ac2e
commit
1adaf42f9d
@ -1,6 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
|
||||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||||
|
|
||||||
|
|
||||||
@ -10,15 +9,15 @@ class RunConditionHandler(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def check(self,
|
def check(self,
|
||||||
graph_node: "GraphNode",
|
source_node_id: str,
|
||||||
graph_runtime_state: "GraphRuntimeState",
|
target_node_id: str,
|
||||||
predecessor_node_result: NodeRunResult) -> bool:
|
graph: "Graph") -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the condition can be executed
|
Check if the condition can be executed
|
||||||
|
|
||||||
:param graph_node: graph node
|
:param source_node_id: source node id
|
||||||
:param graph_runtime_state: graph runtime state
|
:param target_node_id: target node id
|
||||||
:param predecessor_node_result: predecessor node result
|
:param graph: graph
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -1,25 +1,29 @@
|
|||||||
from core.workflow.entities.node_entities import NodeRunResult
|
|
||||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||||
|
|
||||||
|
|
||||||
class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
||||||
|
|
||||||
def check(self,
|
def check(self,
|
||||||
graph_node: "GraphNode",
|
source_node_id: str,
|
||||||
graph_runtime_state: "GraphRuntimeState",
|
target_node_id: str,
|
||||||
predecessor_node_result: NodeRunResult) -> bool:
|
graph: "Graph") -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the condition can be executed
|
Check if the condition can be executed
|
||||||
|
|
||||||
:param graph_node: graph node
|
:param source_node_id: source node id
|
||||||
:param graph_runtime_state: graph runtime state
|
:param target_node_id: target node id
|
||||||
:param predecessor_node_result: predecessor node result
|
:param graph: graph
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
if not self.condition.branch_identify:
|
if not self.condition.branch_identify:
|
||||||
raise Exception("Branch identify is required")
|
raise Exception("Branch identify is required")
|
||||||
|
|
||||||
if not predecessor_node_result.edge_source_handle:
|
run_state = graph.run_state
|
||||||
|
node_route_result = run_state.node_route_results.get(source_node_id)
|
||||||
|
if not node_route_result:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return self.condition.branch_identify == predecessor_node_result.edge_source_handle
|
if not node_route_result.edge_source_handle:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.condition.branch_identify == node_route_result.edge_source_handle
|
||||||
|
@ -1,19 +1,18 @@
|
|||||||
from core.workflow.entities.node_entities import NodeRunResult
|
|
||||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||||
|
|
||||||
|
|
||||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||||
def check(self,
|
def check(self,
|
||||||
graph_node: "GraphNode",
|
source_node_id: str,
|
||||||
graph_runtime_state: "GraphRuntimeState",
|
target_node_id: str,
|
||||||
predecessor_node_result: NodeRunResult) -> bool:
|
graph: "Graph") -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the condition can be executed
|
Check if the condition can be executed
|
||||||
|
|
||||||
:param graph_node: graph node
|
:param source_node_id: source node id
|
||||||
:param graph_runtime_state: graph runtime state
|
:param target_node_id: target node id
|
||||||
:param predecessor_node_result: predecessor node result
|
:param graph: graph
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
if not self.condition.conditions:
|
if not self.condition.conditions:
|
||||||
@ -22,7 +21,7 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
|||||||
# process condition
|
# process condition
|
||||||
condition_processor = ConditionProcessor()
|
condition_processor = ConditionProcessor()
|
||||||
compare_result, _ = condition_processor.process(
|
compare_result, _ = condition_processor.process(
|
||||||
variable_pool=graph_runtime_state.variable_pool,
|
variable_pool=graph.run_state.variable_pool,
|
||||||
logical_operator="and",
|
logical_operator="and",
|
||||||
conditions=self.condition.conditions
|
conditions=self.condition.conditions
|
||||||
)
|
)
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
|
import uuid
|
||||||
from typing import Optional, cast
|
from typing import Optional, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||||
|
|
||||||
|
|
||||||
@ -18,6 +20,14 @@ class GraphEdge(BaseModel):
|
|||||||
"""condition to run the edge"""
|
"""condition to run the edge"""
|
||||||
|
|
||||||
|
|
||||||
|
class GraphParallel(BaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
"""random uuid parallel id"""
|
||||||
|
|
||||||
|
parent_parallel_id: Optional[str] = None
|
||||||
|
"""parent parallel id if exists"""
|
||||||
|
|
||||||
|
|
||||||
class GraphStateRoute(BaseModel):
|
class GraphStateRoute(BaseModel):
|
||||||
route_id: str
|
route_id: str
|
||||||
"""route id"""
|
"""route id"""
|
||||||
@ -28,13 +38,21 @@ class GraphStateRoute(BaseModel):
|
|||||||
|
|
||||||
class GraphState(BaseModel):
|
class GraphState(BaseModel):
|
||||||
routes: dict[str, list[GraphStateRoute]] = Field(default_factory=dict)
|
routes: dict[str, list[GraphStateRoute]] = Field(default_factory=dict)
|
||||||
"""graph state routes (route_id: run_result)"""
|
"""graph state routes (source_node_id: routes)"""
|
||||||
|
|
||||||
variable_pool: VariablePool
|
variable_pool: VariablePool
|
||||||
"""variable pool"""
|
"""variable pool"""
|
||||||
|
|
||||||
node_route_results: dict[str, NodeRunResult] = Field(default_factory=dict)
|
node_route_results: dict[str, NodeRunResult] = Field(default_factory=dict)
|
||||||
"""node results in route (route_id: run_result)"""
|
"""node results in route (node_id: run_result)"""
|
||||||
|
|
||||||
|
|
||||||
|
class NextGraphNode(BaseModel):
|
||||||
|
node_id: str
|
||||||
|
"""next node id"""
|
||||||
|
|
||||||
|
parallel: Optional[GraphParallel] = None
|
||||||
|
"""parallel"""
|
||||||
|
|
||||||
|
|
||||||
class Graph(BaseModel):
|
class Graph(BaseModel):
|
||||||
@ -45,7 +63,13 @@ class Graph(BaseModel):
|
|||||||
"""graph node ids"""
|
"""graph node ids"""
|
||||||
|
|
||||||
edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict)
|
edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict)
|
||||||
"""graph edge mapping"""
|
"""graph edge mapping (source node id: edges)"""
|
||||||
|
|
||||||
|
parallel_mapping: dict[str, GraphParallel] = Field(default_factory=dict)
|
||||||
|
"""graph parallel mapping (parallel id: parallel)"""
|
||||||
|
|
||||||
|
node_parallel_mapping: dict[str, str] = Field(default_factory=dict)
|
||||||
|
"""graph node parallel mapping (node id: parallel id)"""
|
||||||
|
|
||||||
run_state: GraphState
|
run_state: GraphState
|
||||||
"""graph run state"""
|
"""graph run state"""
|
||||||
@ -139,6 +163,16 @@ class Graph(BaseModel):
|
|||||||
node_id=root_node_id
|
node_id=root_node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# init parallel mapping
|
||||||
|
parallel_mapping: dict[str, GraphParallel] = {}
|
||||||
|
node_parallel_mapping: dict[str, str] = {}
|
||||||
|
cls._recursively_add_parallels(
|
||||||
|
edge_mapping=edge_mapping,
|
||||||
|
start_node_id=root_node_id,
|
||||||
|
parallel_mapping=parallel_mapping,
|
||||||
|
node_parallel_mapping=node_parallel_mapping
|
||||||
|
)
|
||||||
|
|
||||||
# init graph
|
# init graph
|
||||||
graph = cls(
|
graph = cls(
|
||||||
root_node_id=root_node_id,
|
root_node_id=root_node_id,
|
||||||
@ -146,7 +180,9 @@ class Graph(BaseModel):
|
|||||||
edge_mapping=edge_mapping,
|
edge_mapping=edge_mapping,
|
||||||
run_state=GraphState(
|
run_state=GraphState(
|
||||||
variable_pool=variable_pool
|
variable_pool=variable_pool
|
||||||
)
|
),
|
||||||
|
parallel_mapping=parallel_mapping,
|
||||||
|
node_parallel_mapping=node_parallel_mapping
|
||||||
)
|
)
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
@ -173,12 +209,48 @@ class Graph(BaseModel):
|
|||||||
edge_mapping=edge_mapping,
|
edge_mapping=edge_mapping,
|
||||||
node_id=graph_edge.target_node_id
|
node_id=graph_edge.target_node_id
|
||||||
)
|
)
|
||||||
def next_node_ids(self) -> list[str]:
|
|
||||||
|
def next_node_ids(self) -> list[NextGraphNode]:
|
||||||
"""
|
"""
|
||||||
Get next node ids
|
Get next node ids
|
||||||
"""
|
"""
|
||||||
# todo
|
# get current node ids in state
|
||||||
return []
|
if not self.run_state.routes:
|
||||||
|
return [NextGraphNode(node_id=self.root_node_id)]
|
||||||
|
|
||||||
|
route_final_graph_edges: list[GraphEdge] = []
|
||||||
|
for route in self.run_state.routes[self.root_node_id]:
|
||||||
|
graph_edges = self.edge_mapping.get(route.node_id)
|
||||||
|
if not graph_edges:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for edge in graph_edges:
|
||||||
|
if edge.target_node_id not in self.run_state.routes:
|
||||||
|
route_final_graph_edges.append(edge)
|
||||||
|
|
||||||
|
next_graph_nodes = []
|
||||||
|
for route_final_graph_edge in route_final_graph_edges:
|
||||||
|
node_id = route_final_graph_edge.target_node_id
|
||||||
|
# check condition
|
||||||
|
if route_final_graph_edge.run_condition:
|
||||||
|
result = ConditionManager.get_condition_handler(
|
||||||
|
run_condition=route_final_graph_edge.run_condition
|
||||||
|
).check(
|
||||||
|
source_node_id=route_final_graph_edge.source_node_id,
|
||||||
|
target_node_id=route_final_graph_edge.target_node_id,
|
||||||
|
graph=self
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
continue
|
||||||
|
|
||||||
|
parallel = None
|
||||||
|
if route_final_graph_edge.target_node_id in self.node_parallel_mapping:
|
||||||
|
parallel = self.parallel_mapping[self.node_parallel_mapping[node_id]]
|
||||||
|
|
||||||
|
next_graph_nodes.append(NextGraphNode(node_id=node_id, parallel=parallel))
|
||||||
|
|
||||||
|
return next_graph_nodes
|
||||||
|
|
||||||
def add_extra_edge(self, source_node_id: str,
|
def add_extra_edge(self, source_node_id: str,
|
||||||
target_node_id: str,
|
target_node_id: str,
|
||||||
@ -222,3 +294,163 @@ class Graph(BaseModel):
|
|||||||
leaf_node_ids.append(node_id)
|
leaf_node_ids.append(node_id)
|
||||||
|
|
||||||
return leaf_node_ids
|
return leaf_node_ids
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _recursively_add_parallels(cls,
|
||||||
|
edge_mapping: dict[str, list[GraphEdge]],
|
||||||
|
start_node_id: str,
|
||||||
|
parallel_mapping: dict[str, GraphParallel],
|
||||||
|
node_parallel_mapping: dict[str, str]) -> None:
|
||||||
|
"""
|
||||||
|
Recursively add parallel ids
|
||||||
|
|
||||||
|
:param edge_mapping: edge mapping
|
||||||
|
:param start_node_id: start from node id
|
||||||
|
:param parallel_mapping: parallel mapping
|
||||||
|
:param node_parallel_mapping: node parallel mapping
|
||||||
|
"""
|
||||||
|
target_node_edges = edge_mapping.get(start_node_id, [])
|
||||||
|
if len(target_node_edges) > 1:
|
||||||
|
# fetch all node ids in current parallels
|
||||||
|
parallel_node_ids = [graph_edge.target_node_id
|
||||||
|
for graph_edge in target_node_edges if graph_edge.run_condition is not None]
|
||||||
|
|
||||||
|
# any target node id in node_parallel_mapping
|
||||||
|
if parallel_node_ids:
|
||||||
|
# all parallel_node_ids in node_parallel_mapping
|
||||||
|
parent_parallel_id = None
|
||||||
|
if all(node_id in node_parallel_mapping for node_id in parallel_node_ids):
|
||||||
|
parent_parallel_id = node_parallel_mapping[parallel_node_ids[0]]
|
||||||
|
|
||||||
|
parallel = GraphParallel(parent_parallel_id=parent_parallel_id)
|
||||||
|
parallel_mapping[parallel.id] = parallel
|
||||||
|
|
||||||
|
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
|
||||||
|
edge_mapping=edge_mapping,
|
||||||
|
parallel_node_ids=parallel_node_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
node_parallel_mapping.update({node_id: parallel.id for node_id in in_branch_node_ids})
|
||||||
|
|
||||||
|
for graph_edge in target_node_edges:
|
||||||
|
cls._recursively_add_parallels(
|
||||||
|
edge_mapping=edge_mapping,
|
||||||
|
start_node_id=graph_edge.target_node_id,
|
||||||
|
parallel_mapping=parallel_mapping,
|
||||||
|
node_parallel_mapping=node_parallel_mapping
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _recursively_add_parallel_node_ids(cls,
|
||||||
|
branch_node_ids: list[str],
|
||||||
|
edge_mapping: dict[str, list[GraphEdge]],
|
||||||
|
merge_node_id: str,
|
||||||
|
start_node_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Recursively add node ids
|
||||||
|
|
||||||
|
:param branch_node_ids: in branch node ids
|
||||||
|
:param edge_mapping: edge mapping
|
||||||
|
:param merge_node_id: merge node id
|
||||||
|
:param start_node_id: start node id
|
||||||
|
"""
|
||||||
|
for graph_edge in edge_mapping.get(start_node_id, []):
|
||||||
|
if (graph_edge.target_node_id != merge_node_id
|
||||||
|
and graph_edge.target_node_id not in branch_node_ids):
|
||||||
|
branch_node_ids.append(graph_edge.target_node_id)
|
||||||
|
cls._recursively_add_parallel_node_ids(
|
||||||
|
branch_node_ids=branch_node_ids,
|
||||||
|
edge_mapping=edge_mapping,
|
||||||
|
merge_node_id=merge_node_id,
|
||||||
|
start_node_id=graph_edge.target_node_id
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _fetch_all_node_ids_in_parallels(cls,
|
||||||
|
edge_mapping: dict[str, list[GraphEdge]],
|
||||||
|
parallel_node_ids: list[str]) -> dict[str, list[str]]:
|
||||||
|
"""
|
||||||
|
Fetch all node ids in parallels
|
||||||
|
"""
|
||||||
|
routes_node_ids: dict[str, list[str]] = {}
|
||||||
|
for parallel_node_id in parallel_node_ids:
|
||||||
|
routes_node_ids[parallel_node_id] = []
|
||||||
|
|
||||||
|
# fetch routes node ids
|
||||||
|
cls._recursively_fetch_routes(
|
||||||
|
edge_mapping=edge_mapping,
|
||||||
|
start_node_id=parallel_node_id,
|
||||||
|
routes_node_ids=routes_node_ids[parallel_node_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
# fetch leaf node ids from routes node ids
|
||||||
|
leaf_node_ids: dict[str, list[str]] = {}
|
||||||
|
merge_branch_node_ids: dict[str, list[str]] = {}
|
||||||
|
for branch_node_id, node_ids in routes_node_ids.items():
|
||||||
|
for node_id in node_ids:
|
||||||
|
if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0:
|
||||||
|
if branch_node_id not in leaf_node_ids:
|
||||||
|
leaf_node_ids[branch_node_id] = []
|
||||||
|
|
||||||
|
leaf_node_ids[branch_node_id].append(node_id)
|
||||||
|
|
||||||
|
for branch_node_id2, inner_route2 in routes_node_ids.items():
|
||||||
|
if branch_node_id != branch_node_id2 and node_id in inner_route2:
|
||||||
|
if node_id not in merge_branch_node_ids:
|
||||||
|
merge_branch_node_ids[node_id] = []
|
||||||
|
|
||||||
|
merge_branch_node_ids[node_id].append(branch_node_id2)
|
||||||
|
|
||||||
|
# sorted merge_branch_node_ids by branch_node_ids length desc
|
||||||
|
merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))
|
||||||
|
|
||||||
|
branches_merge_node_ids: dict[str, str] = {}
|
||||||
|
for node_id, branch_node_ids in merge_branch_node_ids.items():
|
||||||
|
if len(branch_node_ids) <= 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for branch_node_id in branch_node_ids:
|
||||||
|
if branch_node_id in branches_merge_node_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
branches_merge_node_ids[branch_node_id] = node_id
|
||||||
|
|
||||||
|
in_branch_node_ids: dict[str, list[str]] = {}
|
||||||
|
for branch_node_id, node_ids in routes_node_ids.items():
|
||||||
|
in_branch_node_ids[branch_node_id] = [branch_node_id]
|
||||||
|
if branch_node_id not in branches_merge_node_ids:
|
||||||
|
# all node ids in current branch is in this thread
|
||||||
|
in_branch_node_ids[branch_node_id].extend(node_ids)
|
||||||
|
else:
|
||||||
|
merge_node_id = branches_merge_node_ids[branch_node_id]
|
||||||
|
# fetch all node ids from branch_node_id and merge_node_id
|
||||||
|
cls._recursively_add_parallel_node_ids(
|
||||||
|
branch_node_ids=in_branch_node_ids[branch_node_id],
|
||||||
|
edge_mapping=edge_mapping,
|
||||||
|
merge_node_id=merge_node_id,
|
||||||
|
start_node_id=branch_node_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return in_branch_node_ids
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _recursively_fetch_routes(cls,
|
||||||
|
edge_mapping: dict[str, list[GraphEdge]],
|
||||||
|
start_node_id: str,
|
||||||
|
routes_node_ids: list[str]) -> None:
|
||||||
|
"""
|
||||||
|
Recursively fetch route
|
||||||
|
"""
|
||||||
|
if start_node_id not in edge_mapping:
|
||||||
|
return
|
||||||
|
|
||||||
|
for graph_edge in edge_mapping[start_node_id]:
|
||||||
|
# find next node ids
|
||||||
|
if graph_edge.target_node_id not in routes_node_ids:
|
||||||
|
routes_node_ids.append(graph_edge.target_node_id)
|
||||||
|
|
||||||
|
cls._recursively_fetch_routes(
|
||||||
|
edge_mapping=edge_mapping,
|
||||||
|
start_node_id=graph_edge.target_node_id,
|
||||||
|
routes_node_ids=routes_node_ids
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user