mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 05:25:54 +08:00
optimize graph
This commit is contained in:
parent
8375517ccd
commit
0f19b2a986
28
api/core/workflow/entities/workflow_runtime_state.py
Normal file
28
api/core/workflow/entities/workflow_runtime_state.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
|
from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph
|
||||||
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
|
from models.workflow import WorkflowType
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowRuntimeState(BaseModel):
|
||||||
|
tenant_id: str
|
||||||
|
app_id: str
|
||||||
|
workflow_id: str
|
||||||
|
workflow_type: WorkflowType
|
||||||
|
user_id: str
|
||||||
|
user_from: UserFrom
|
||||||
|
variable_pool: VariablePool
|
||||||
|
invoke_from: InvokeFrom
|
||||||
|
graph: Graph
|
||||||
|
call_depth: int
|
||||||
|
start_at: float
|
||||||
|
|
||||||
|
total_tokens: int = 0
|
||||||
|
node_run_steps: int = 0
|
||||||
|
|
||||||
|
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)
|
0
api/core/workflow/graph_engine/__init__.py
Normal file
0
api/core/workflow/graph_engine/__init__.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
|
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||||
|
|
||||||
|
|
||||||
|
class RunConditionHandler(ABC):
|
||||||
|
def __init__(self, condition: RunCondition):
|
||||||
|
self.condition = condition
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def check(self,
|
||||||
|
graph_node: "GraphNode",
|
||||||
|
graph_runtime_state: "GraphRuntimeState",
|
||||||
|
predecessor_node_result: NodeRunResult) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the condition can be executed
|
||||||
|
|
||||||
|
:param graph_node: graph node
|
||||||
|
:param graph_runtime_state: graph runtime state
|
||||||
|
:param predecessor_node_result: predecessor node result
|
||||||
|
:return: bool
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
@ -0,0 +1,25 @@
|
|||||||
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
|
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||||
|
|
||||||
|
|
||||||
|
class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
||||||
|
|
||||||
|
def check(self,
|
||||||
|
graph_node: "GraphNode",
|
||||||
|
graph_runtime_state: "GraphRuntimeState",
|
||||||
|
predecessor_node_result: NodeRunResult) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the condition can be executed
|
||||||
|
|
||||||
|
:param graph_node: graph node
|
||||||
|
:param graph_runtime_state: graph runtime state
|
||||||
|
:param predecessor_node_result: predecessor node result
|
||||||
|
:return: bool
|
||||||
|
"""
|
||||||
|
if not self.condition.branch_identify:
|
||||||
|
raise Exception("Branch identify is required")
|
||||||
|
|
||||||
|
if not predecessor_node_result.edge_source_handle:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.condition.branch_identify == predecessor_node_result.edge_source_handle
|
@ -0,0 +1,31 @@
|
|||||||
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
|
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||||
|
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||||
|
def check(self,
|
||||||
|
graph_node: "GraphNode",
|
||||||
|
graph_runtime_state: "GraphRuntimeState",
|
||||||
|
predecessor_node_result: NodeRunResult) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the condition can be executed
|
||||||
|
|
||||||
|
:param graph_node: graph node
|
||||||
|
:param graph_runtime_state: graph runtime state
|
||||||
|
:param predecessor_node_result: predecessor node result
|
||||||
|
:return: bool
|
||||||
|
"""
|
||||||
|
if not self.condition.conditions:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# process condition
|
||||||
|
condition_processor = ConditionProcessor()
|
||||||
|
compare_result, _ = condition_processor.process(
|
||||||
|
variable_pool=graph_runtime_state.variable_pool,
|
||||||
|
logical_operator="and",
|
||||||
|
conditions=self.condition.conditions
|
||||||
|
)
|
||||||
|
|
||||||
|
return compare_result
|
||||||
|
|
@ -0,0 +1,19 @@
|
|||||||
|
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||||
|
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
|
||||||
|
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
|
||||||
|
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionManager:
|
||||||
|
@staticmethod
|
||||||
|
def get_condition_handler(run_condition: RunCondition) -> RunConditionHandler:
|
||||||
|
"""
|
||||||
|
Get condition handler
|
||||||
|
|
||||||
|
:param run_condition: run condition
|
||||||
|
:return: condition handler
|
||||||
|
"""
|
||||||
|
if run_condition.type == "branch_identify":
|
||||||
|
return BranchIdentifyRunConditionHandler(run_condition)
|
||||||
|
else:
|
||||||
|
return ConditionRunConditionHandlerHandler(run_condition)
|
0
api/core/workflow/graph_engine/entities/__init__.py
Normal file
0
api/core/workflow/graph_engine/entities/__init__.py
Normal file
@ -1,20 +1,10 @@
|
|||||||
from collections.abc import Callable
|
from typing import Optional
|
||||||
from typing import Literal, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.workflow.utils.condition.entities import Condition
|
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||||
|
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||||
|
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||||
class RunCondition(BaseModel):
|
|
||||||
type: Literal["branch_identify", "condition"]
|
|
||||||
"""condition type"""
|
|
||||||
|
|
||||||
branch_identify: Optional[str] = None
|
|
||||||
"""branch identify, required when type is branch_identify"""
|
|
||||||
|
|
||||||
conditions: Optional[list[Condition]] = None
|
|
||||||
"""conditions to run the node, required when type is condition"""
|
|
||||||
|
|
||||||
|
|
||||||
class GraphNode(BaseModel):
|
class GraphNode(BaseModel):
|
||||||
@ -33,9 +23,6 @@ class GraphNode(BaseModel):
|
|||||||
run_condition: Optional[RunCondition] = None
|
run_condition: Optional[RunCondition] = None
|
||||||
"""condition to run the node"""
|
"""condition to run the node"""
|
||||||
|
|
||||||
run_condition_callback: Optional[Callable] = Field(None, exclude=True)
|
|
||||||
"""condition function check if the node can be executed, translated from run_conditions, not serialized"""
|
|
||||||
|
|
||||||
node_config: dict
|
node_config: dict
|
||||||
"""original node config"""
|
"""original node config"""
|
||||||
|
|
||||||
@ -48,9 +35,22 @@ class GraphNode(BaseModel):
|
|||||||
def add_child(self, node_id: str) -> None:
|
def add_child(self, node_id: str) -> None:
|
||||||
self.descendant_node_ids.append(node_id)
|
self.descendant_node_ids.append(node_id)
|
||||||
|
|
||||||
|
def get_run_condition_handler(self) -> Optional[RunConditionHandler]:
|
||||||
|
"""
|
||||||
|
Get run condition handler
|
||||||
|
|
||||||
|
:return: run condition handler
|
||||||
|
"""
|
||||||
|
if not self.run_condition:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return ConditionManager.get_condition_handler(
|
||||||
|
run_condition=self.run_condition
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Graph(BaseModel):
|
class Graph(BaseModel):
|
||||||
graph_nodes: dict[str, GraphNode] = {}
|
graph_nodes: dict[str, GraphNode] = Field(default_factory=dict)
|
||||||
"""graph nodes"""
|
"""graph nodes"""
|
||||||
|
|
||||||
root_node: GraphNode
|
root_node: GraphNode
|
||||||
@ -65,8 +65,12 @@ class Graph(BaseModel):
|
|||||||
:param run_condition: run condition when root node parent is iteration/loop
|
:param run_condition: run condition when root node parent is iteration/loop
|
||||||
:return: graph
|
:return: graph
|
||||||
"""
|
"""
|
||||||
|
node_id = root_node_config.get('id')
|
||||||
|
if not node_id:
|
||||||
|
raise ValueError("Graph root node id is required")
|
||||||
|
|
||||||
root_node = GraphNode(
|
root_node = GraphNode(
|
||||||
id=root_node_config.get('id'),
|
id=node_id,
|
||||||
parent_id=root_node_config.get('parentId'),
|
parent_id=root_node_config.get('parentId'),
|
||||||
node_config=root_node_config,
|
node_config=root_node_config,
|
||||||
run_condition=run_condition
|
run_condition=run_condition
|
||||||
@ -74,15 +78,14 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
graph = cls(root_node=root_node)
|
graph = cls(root_node=root_node)
|
||||||
|
|
||||||
# TODO parse run_condition to run_condition_callback
|
graph._add_graph_node(graph.root_node)
|
||||||
|
|
||||||
graph.add_graph_node(graph.root_node)
|
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
def add_edge(self, edge_config: dict,
|
def add_edge(self, edge_config: dict,
|
||||||
source_node_config: dict,
|
source_node_config: dict,
|
||||||
target_node_config: dict,
|
target_node_config: dict,
|
||||||
target_node_sub_graph: Optional["Graph"] = None) -> None:
|
target_node_sub_graph: Optional["Graph"] = None,
|
||||||
|
run_condition: Optional[RunCondition] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Add edge to the graph
|
Add edge to the graph
|
||||||
|
|
||||||
@ -90,6 +93,7 @@ class Graph(BaseModel):
|
|||||||
:param source_node_config: source node config
|
:param source_node_config: source node config
|
||||||
:param target_node_config: target node config
|
:param target_node_config: target node config
|
||||||
:param target_node_sub_graph: sub graph
|
:param target_node_sub_graph: sub graph
|
||||||
|
:param run_condition: run condition
|
||||||
"""
|
"""
|
||||||
source_node_id = source_node_config.get('id')
|
source_node_id = source_node_config.get('id')
|
||||||
if not source_node_id:
|
if not source_node_id:
|
||||||
@ -105,48 +109,25 @@ class Graph(BaseModel):
|
|||||||
source_node = self.graph_nodes[source_node_id]
|
source_node = self.graph_nodes[source_node_id]
|
||||||
source_node.add_child(target_node_id)
|
source_node.add_child(target_node_id)
|
||||||
|
|
||||||
# if run_conditions:
|
|
||||||
# run_condition_callback = lambda: all()
|
|
||||||
|
|
||||||
|
|
||||||
if target_node_id not in self.graph_nodes:
|
if target_node_id not in self.graph_nodes:
|
||||||
run_condition = None # todo
|
|
||||||
run_condition_callback = None # todo
|
|
||||||
|
|
||||||
target_graph_node = GraphNode(
|
target_graph_node = GraphNode(
|
||||||
id=target_node_id,
|
id=target_node_id,
|
||||||
parent_id=source_node_config.get('parentId'),
|
parent_id=source_node_config.get('parentId'),
|
||||||
predecessor_node_id=source_node_id,
|
predecessor_node_id=source_node_id,
|
||||||
node_config=target_node_config,
|
node_config=target_node_config,
|
||||||
run_condition=run_condition,
|
run_condition=run_condition,
|
||||||
run_condition_callback=run_condition_callback,
|
|
||||||
source_edge_config=edge_config,
|
source_edge_config=edge_config,
|
||||||
sub_graph=target_node_sub_graph
|
sub_graph=target_node_sub_graph
|
||||||
)
|
)
|
||||||
|
|
||||||
self.add_graph_node(target_graph_node)
|
self._add_graph_node(target_graph_node)
|
||||||
else:
|
else:
|
||||||
target_node = self.graph_nodes[target_node_id]
|
target_node = self.graph_nodes[target_node_id]
|
||||||
target_node.predecessor_node_id = source_node_id
|
target_node.predecessor_node_id = source_node_id
|
||||||
target_node.run_conditions = run_conditions
|
target_node.run_condition = run_condition
|
||||||
target_node.run_condition_callback = run_condition_callback
|
|
||||||
target_node.source_edge_config = edge_config
|
target_node.source_edge_config = edge_config
|
||||||
target_node.sub_graph = target_node_sub_graph
|
target_node.sub_graph = target_node_sub_graph
|
||||||
|
|
||||||
def add_graph_node(self, graph_node: GraphNode) -> None:
|
|
||||||
"""
|
|
||||||
Add graph node to the graph
|
|
||||||
|
|
||||||
:param graph_node: graph node
|
|
||||||
"""
|
|
||||||
if graph_node.id in self.graph_nodes:
|
|
||||||
return
|
|
||||||
|
|
||||||
if len(self.graph_nodes) == 0:
|
|
||||||
self.root_node = graph_node
|
|
||||||
|
|
||||||
self.graph_nodes[graph_node.id] = graph_node
|
|
||||||
|
|
||||||
def get_root_node(self) -> Optional[GraphNode]:
|
def get_root_node(self) -> Optional[GraphNode]:
|
||||||
"""
|
"""
|
||||||
Get root node of the graph
|
Get root node of the graph
|
||||||
@ -169,14 +150,28 @@ class Graph(BaseModel):
|
|||||||
if not graph_node.descendant_node_ids:
|
if not graph_node.descendant_node_ids:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
descendants_graph = Graph()
|
descendants_graph = Graph(root_node=graph_node)
|
||||||
descendants_graph.add_graph_node(graph_node)
|
descendants_graph._add_graph_node(graph_node)
|
||||||
|
|
||||||
for child_node_id in graph_node.descendant_node_ids:
|
for child_node_id in graph_node.descendant_node_ids:
|
||||||
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
|
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
|
||||||
|
|
||||||
return descendants_graph
|
return descendants_graph
|
||||||
|
|
||||||
|
def _add_graph_node(self, graph_node: GraphNode) -> None:
|
||||||
|
"""
|
||||||
|
Add graph node to the graph
|
||||||
|
|
||||||
|
:param graph_node: graph node
|
||||||
|
"""
|
||||||
|
if graph_node.id in self.graph_nodes:
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(self.graph_nodes) == 0:
|
||||||
|
self.root_node = graph_node
|
||||||
|
|
||||||
|
self.graph_nodes[graph_node.id] = graph_node
|
||||||
|
|
||||||
def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None:
|
def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Add descendants graph nodes
|
Add descendants graph nodes
|
||||||
@ -188,7 +183,7 @@ class Graph(BaseModel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
graph_node = self.graph_nodes[node_id]
|
graph_node = self.graph_nodes[node_id]
|
||||||
descendants_graph.add_graph_node(graph_node)
|
descendants_graph._add_graph_node(graph_node)
|
||||||
|
|
||||||
for child_node_id in graph_node.descendant_node_ids:
|
for child_node_id in graph_node.descendant_node_ids:
|
||||||
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
|
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
|
@ -0,0 +1,26 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
|
from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph
|
||||||
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
|
|
||||||
|
|
||||||
|
class GraphRuntimeState(BaseModel):
|
||||||
|
tenant_id: str
|
||||||
|
app_id: str
|
||||||
|
user_id: str
|
||||||
|
user_from: UserFrom
|
||||||
|
invoke_from: InvokeFrom
|
||||||
|
call_depth: int
|
||||||
|
|
||||||
|
graph: Graph
|
||||||
|
variable_pool: VariablePool
|
||||||
|
|
||||||
|
start_at: Optional[float] = None
|
||||||
|
total_tokens: int = 0
|
||||||
|
node_run_steps: int = 0
|
||||||
|
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)
|
16
api/core/workflow/graph_engine/entities/run_condition.py
Normal file
16
api/core/workflow/graph_engine/entities/run_condition.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.workflow.utils.condition.entities import Condition
|
||||||
|
|
||||||
|
|
||||||
|
class RunCondition(BaseModel):
|
||||||
|
type: Literal["branch_identify", "condition"]
|
||||||
|
"""condition type"""
|
||||||
|
|
||||||
|
branch_identify: Optional[str] = None
|
||||||
|
"""branch identify like: sourceHandle, required when type is branch_identify"""
|
||||||
|
|
||||||
|
conditions: Optional[list[Condition]] = None
|
||||||
|
"""conditions to run the node, required when type is condition"""
|
@ -1,55 +1,11 @@
|
|||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from enum import Enum
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.graph_engine.entities.runtime_node import RuntimeNode
|
||||||
from core.workflow.graph import Graph, GraphNode
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
from core.workflow.nodes.base_node import UserFrom
|
|
||||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
|
||||||
|
|
||||||
|
|
||||||
class RuntimeNode(BaseModel):
|
|
||||||
class Status(Enum):
|
|
||||||
PENDING = "pending"
|
|
||||||
RUNNING = "running"
|
|
||||||
SUCCESS = "success"
|
|
||||||
FAILED = "failed"
|
|
||||||
PAUSED = "paused"
|
|
||||||
|
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
|
||||||
"""random id for current runtime node"""
|
|
||||||
|
|
||||||
graph_node: GraphNode
|
|
||||||
"""graph node"""
|
|
||||||
|
|
||||||
node_run_result: Optional[NodeRunResult] = None
|
|
||||||
"""node run result"""
|
|
||||||
|
|
||||||
status: Status = Status.PENDING
|
|
||||||
"""node status"""
|
|
||||||
|
|
||||||
start_at: Optional[datetime] = None
|
|
||||||
"""start time"""
|
|
||||||
|
|
||||||
paused_at: Optional[datetime] = None
|
|
||||||
"""paused time"""
|
|
||||||
|
|
||||||
finished_at: Optional[datetime] = None
|
|
||||||
"""finished time"""
|
|
||||||
|
|
||||||
failed_reason: Optional[str] = None
|
|
||||||
"""failed reason"""
|
|
||||||
|
|
||||||
paused_by: Optional[str] = None
|
|
||||||
"""paused by"""
|
|
||||||
|
|
||||||
predecessor_runtime_node_id: Optional[str] = None
|
|
||||||
"""predecessor runtime node id"""
|
|
||||||
|
|
||||||
|
|
||||||
class RuntimeGraph(BaseModel):
|
class RuntimeGraph(BaseModel):
|
||||||
@ -80,22 +36,3 @@ class RuntimeGraph(BaseModel):
|
|||||||
runtime_node.status = RuntimeNode.Status.PAUSED
|
runtime_node.status = RuntimeNode.Status.PAUSED
|
||||||
runtime_node.paused_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
runtime_node.paused_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
runtime_node.paused_by = paused_by
|
runtime_node.paused_by = paused_by
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRuntimeState(BaseModel):
|
|
||||||
tenant_id: str
|
|
||||||
app_id: str
|
|
||||||
workflow_id: str
|
|
||||||
workflow_type: WorkflowType
|
|
||||||
user_id: str
|
|
||||||
user_from: UserFrom
|
|
||||||
variable_pool: VariablePool
|
|
||||||
invoke_from: InvokeFrom
|
|
||||||
graph: Graph
|
|
||||||
call_depth: int
|
|
||||||
start_at: float
|
|
||||||
|
|
||||||
total_tokens: int = 0
|
|
||||||
node_run_steps: int = 0
|
|
||||||
|
|
||||||
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)
|
|
48
api/core/workflow/graph_engine/entities/runtime_node.py
Normal file
48
api/core/workflow/graph_engine/entities/runtime_node.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
|
from core.workflow.graph_engine.entities.graph import GraphNode
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeNode(BaseModel):
|
||||||
|
class Status(Enum):
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
SUCCESS = "success"
|
||||||
|
FAILED = "failed"
|
||||||
|
PAUSED = "paused"
|
||||||
|
|
||||||
|
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
"""random id for current runtime node"""
|
||||||
|
|
||||||
|
graph_node: GraphNode
|
||||||
|
"""graph node"""
|
||||||
|
|
||||||
|
node_run_result: Optional[NodeRunResult] = None
|
||||||
|
"""node run result"""
|
||||||
|
|
||||||
|
status: Status = Status.PENDING
|
||||||
|
"""node status"""
|
||||||
|
|
||||||
|
start_at: Optional[datetime] = None
|
||||||
|
"""start time"""
|
||||||
|
|
||||||
|
paused_at: Optional[datetime] = None
|
||||||
|
"""paused time"""
|
||||||
|
|
||||||
|
finished_at: Optional[datetime] = None
|
||||||
|
"""finished time"""
|
||||||
|
|
||||||
|
failed_reason: Optional[str] = None
|
||||||
|
"""failed reason"""
|
||||||
|
|
||||||
|
paused_by: Optional[str] = None
|
||||||
|
"""paused by"""
|
||||||
|
|
||||||
|
predecessor_runtime_node_id: Optional[str] = None
|
||||||
|
"""predecessor runtime node id"""
|
45
api/core/workflow/graph_engine/graph_engine.py
Normal file
45
api/core/workflow/graph_engine/graph_engine.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import time
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from flask import current_app
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
|
|
||||||
|
|
||||||
|
class GraphEngine:
|
||||||
|
def __init__(self, tenant_id: str,
|
||||||
|
app_id: str,
|
||||||
|
user_id: str,
|
||||||
|
user_from: UserFrom,
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
call_depth: int,
|
||||||
|
graph: Graph,
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
callbacks: list[BaseWorkflowCallback]) -> None:
|
||||||
|
self.graph_runtime_state = GraphRuntimeState(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
app_id=app_id,
|
||||||
|
user_id=user_id,
|
||||||
|
user_from=user_from,
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
call_depth=call_depth,
|
||||||
|
graph=graph,
|
||||||
|
variable_pool=variable_pool
|
||||||
|
)
|
||||||
|
|
||||||
|
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
|
||||||
|
self.max_execution_steps = cast(int, max_execution_steps)
|
||||||
|
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
|
||||||
|
self.max_execution_time = cast(int, max_execution_time)
|
||||||
|
|
||||||
|
self.callbacks = callbacks
|
||||||
|
|
||||||
|
def run(self) -> Generator:
|
||||||
|
self.graph_runtime_state.start_at = time.perf_counter()
|
||||||
|
pass
|
@ -7,6 +7,7 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
|||||||
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.nodes.iterable_node import IterableNodeMixin
|
||||||
|
|
||||||
|
|
||||||
class UserFrom(Enum):
|
class UserFrom(Enum):
|
||||||
@ -39,7 +40,7 @@ class BaseNode(ABC):
|
|||||||
user_id: str
|
user_id: str
|
||||||
user_from: UserFrom
|
user_from: UserFrom
|
||||||
invoke_from: InvokeFrom
|
invoke_from: InvokeFrom
|
||||||
|
|
||||||
workflow_call_depth: int
|
workflow_call_depth: int
|
||||||
|
|
||||||
node_id: str
|
node_id: str
|
||||||
@ -149,7 +150,8 @@ class BaseNode(ABC):
|
|||||||
"""
|
"""
|
||||||
return self._node_type
|
return self._node_type
|
||||||
|
|
||||||
class BaseIterationNode(BaseNode):
|
|
||||||
|
class BaseIterationNode(BaseNode, IterableNodeMixin):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
|
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
|
||||||
"""
|
"""
|
||||||
@ -174,7 +176,7 @@ class BaseIterationNode(BaseNode):
|
|||||||
:return: next node id
|
:return: next node id
|
||||||
"""
|
"""
|
||||||
return self._get_next_iteration(variable_pool, state)
|
return self._get_next_iteration(variable_pool, state)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
|
def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
|
||||||
"""
|
"""
|
||||||
|
14
api/core/workflow/nodes/iterable_node.py
Normal file
14
api/core/workflow/nodes/iterable_node.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.workflow.utils.condition.entities import Condition
|
||||||
|
|
||||||
|
|
||||||
|
class IterableNodeMixin(ABC):
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
|
||||||
|
"""
|
||||||
|
Get conditions.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
@ -1,4 +1,4 @@
|
|||||||
from typing import cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.workflow.entities.base_node_data_entities import BaseIterationState
|
from core.workflow.entities.base_node_data_entities import BaseIterationState
|
||||||
@ -6,6 +6,7 @@ from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
|||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.nodes.base_node import BaseIterationNode
|
from core.workflow.nodes.base_node import BaseIterationNode
|
||||||
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
|
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
|
||||||
|
from core.workflow.utils.condition.entities import Condition
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
@ -116,4 +117,20 @@ class IterationNode(BaseIterationNode):
|
|||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
'input_selector': node_data.iterator_selector,
|
'input_selector': node_data.iterator_selector,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
|
||||||
|
"""
|
||||||
|
Get conditions.
|
||||||
|
"""
|
||||||
|
node_id = node_config.get('id')
|
||||||
|
if not node_id:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [Condition(
|
||||||
|
variable_selector=[node_id, 'index'],
|
||||||
|
comparison_operator="≤",
|
||||||
|
value_type="value_selector",
|
||||||
|
value_selector=node_config.get('data', {}).get('iterator_selector')
|
||||||
|
)]
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.nodes.base_node import BaseIterationNode
|
from core.workflow.nodes.base_node import BaseIterationNode
|
||||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
|
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
|
||||||
|
from core.workflow.utils.condition.entities import Condition
|
||||||
|
|
||||||
|
|
||||||
class LoopNode(BaseIterationNode):
|
class LoopNode(BaseIterationNode):
|
||||||
@ -18,3 +21,21 @@ class LoopNode(BaseIterationNode):
|
|||||||
"""
|
"""
|
||||||
Get next iteration start node id based on the graph.
|
Get next iteration start node id based on the graph.
|
||||||
"""
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
|
||||||
|
"""
|
||||||
|
Get conditions.
|
||||||
|
"""
|
||||||
|
node_id = node_config.get('id')
|
||||||
|
if not node_id:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# TODO waiting for implementation
|
||||||
|
return [Condition(
|
||||||
|
variable_selector=[node_id, 'index'],
|
||||||
|
comparison_operator="≤",
|
||||||
|
value_type="value_selector",
|
||||||
|
value_selector=[]
|
||||||
|
)]
|
||||||
|
@ -14,4 +14,6 @@ class Condition(BaseModel):
|
|||||||
# for number
|
# for number
|
||||||
"=", "≠", ">", "<", "≥", "≤", "null", "not null"
|
"=", "≠", ">", "<", "≥", "≤", "null", "not null"
|
||||||
]
|
]
|
||||||
|
value_type: Literal["string", "value_selector"] = "string"
|
||||||
value: Optional[str] = None
|
value: Optional[str] = None
|
||||||
|
value_selector: Optional[list[str]] = None
|
||||||
|
@ -1,22 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
|
||||||
from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState
|
|
||||||
from core.workflow.graph import GraphNode
|
|
||||||
|
|
||||||
|
|
||||||
def source_handle_condition_func(workflow_runtime_state: WorkflowRuntimeState,
|
|
||||||
graph_node: GraphNode,
|
|
||||||
# TODO cycle_state optional
|
|
||||||
predecessor_node_run_result: Optional[NodeRunResult] = None) -> bool:
|
|
||||||
if not graph_node.source_edge_config:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not graph_node.source_edge_config.get('sourceHandle'):
|
|
||||||
return True
|
|
||||||
|
|
||||||
source_handle = predecessor_node_run_result.edge_source_handle \
|
|
||||||
if predecessor_node_run_result else None
|
|
||||||
|
|
||||||
return (source_handle is not None
|
|
||||||
and graph_node.source_edge_config.get('sourceHandle') == source_handle)
|
|
@ -24,7 +24,12 @@ class ConditionProcessor:
|
|||||||
variable_selector=condition.variable_selector
|
variable_selector=condition.variable_selector
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_value = condition.value
|
if condition.value_type == "value_selector":
|
||||||
|
expected_value = variable_pool.get_variable_value(
|
||||||
|
variable_selector=condition.value_selector
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
expected_value = condition.value
|
||||||
|
|
||||||
input_conditions.append({
|
input_conditions.append({
|
||||||
"actual_value": actual_value,
|
"actual_value": actual_value,
|
||||||
@ -208,7 +213,7 @@ class ConditionProcessor:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||||
"""
|
"""
|
||||||
Assert equal
|
Assert equal
|
||||||
:param actual_value: actual value
|
:param actual_value: actual value
|
||||||
@ -230,7 +235,7 @@ class ConditionProcessor:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||||
"""
|
"""
|
||||||
Assert not equal
|
Assert not equal
|
||||||
:param actual_value: actual value
|
:param actual_value: actual value
|
||||||
@ -252,7 +257,7 @@ class ConditionProcessor:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||||
"""
|
"""
|
||||||
Assert greater than
|
Assert greater than
|
||||||
:param actual_value: actual value
|
:param actual_value: actual value
|
||||||
@ -274,7 +279,7 @@ class ConditionProcessor:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||||
"""
|
"""
|
||||||
Assert less than
|
Assert less than
|
||||||
:param actual_value: actual value
|
:param actual_value: actual value
|
||||||
@ -296,7 +301,7 @@ class ConditionProcessor:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||||
"""
|
"""
|
||||||
Assert greater than or equal
|
Assert greater than or equal
|
||||||
:param actual_value: actual value
|
:param actual_value: actual value
|
||||||
@ -318,7 +323,7 @@ class ConditionProcessor:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||||
"""
|
"""
|
||||||
Assert less than or equal
|
Assert less than or equal
|
||||||
:param actual_value: actual value
|
:param actual_value: actual value
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Generator
|
||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
@ -12,15 +13,18 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
|||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||||
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
|
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
|
||||||
from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState
|
from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
|
||||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
|
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||||
|
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||||
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
|
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
|
||||||
from core.workflow.nodes.code.code_node import CodeNode
|
from core.workflow.nodes.code.code_node import CodeNode
|
||||||
from core.workflow.nodes.end.end_node import EndNode
|
from core.workflow.nodes.end.end_node import EndNode
|
||||||
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
||||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||||
|
from core.workflow.nodes.iterable_node import IterableNodeMixin
|
||||||
from core.workflow.nodes.iteration.entities import IterationState
|
from core.workflow.nodes.iteration.entities import IterationState
|
||||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||||
@ -36,7 +40,6 @@ from extensions.ext_database import db
|
|||||||
from models.workflow import (
|
from models.workflow import (
|
||||||
Workflow,
|
Workflow,
|
||||||
WorkflowNodeExecutionStatus,
|
WorkflowNodeExecutionStatus,
|
||||||
WorkflowType,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
node_classes = {
|
node_classes = {
|
||||||
@ -60,7 +63,7 @@ node_classes = {
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowEngineManager:
|
class WorkflowEntry:
|
||||||
def run_workflow(self, workflow: Workflow,
|
def run_workflow(self, workflow: Workflow,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
user_from: UserFrom,
|
user_from: UserFrom,
|
||||||
@ -69,7 +72,7 @@ class WorkflowEngineManager:
|
|||||||
user_inputs: dict,
|
user_inputs: dict,
|
||||||
system_inputs: dict[SystemVariable, Any],
|
system_inputs: dict[SystemVariable, Any],
|
||||||
call_depth: int = 0,
|
call_depth: int = 0,
|
||||||
variable_pool: Optional[VariablePool] = None) -> None:
|
variable_pool: Optional[VariablePool] = None) -> Generator:
|
||||||
"""
|
"""
|
||||||
:param workflow: Workflow instance
|
:param workflow: Workflow instance
|
||||||
:param user_id: user id
|
:param user_id: user id
|
||||||
@ -102,25 +105,25 @@ class WorkflowEngineManager:
|
|||||||
user_inputs=user_inputs
|
user_inputs=user_inputs
|
||||||
)
|
)
|
||||||
|
|
||||||
# fetch max call depth
|
# init graph
|
||||||
workflow_call_max_depth = current_app.config.get("WORKFLOW_CALL_MAX_DEPTH")
|
graph = self._init_graph(
|
||||||
workflow_call_max_depth = cast(int, workflow_call_max_depth)
|
graph_config=graph_config
|
||||||
if call_depth > workflow_call_max_depth:
|
)
|
||||||
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
|
|
||||||
|
|
||||||
# init workflow runtime state
|
if not graph:
|
||||||
workflow_runtime_state = WorkflowRuntimeState(
|
raise ValueError('graph not found in workflow')
|
||||||
|
|
||||||
|
# init workflow run state
|
||||||
|
graph_engine = GraphEngine(
|
||||||
tenant_id=workflow.tenant_id,
|
tenant_id=workflow.tenant_id,
|
||||||
app_id=workflow.app_id,
|
app_id=workflow.app_id,
|
||||||
workflow_id=workflow.id,
|
|
||||||
workflow_type=WorkflowType.value_of(workflow.type),
|
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
user_from=user_from,
|
user_from=user_from,
|
||||||
variable_pool=variable_pool,
|
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
graph=graph_config,
|
|
||||||
call_depth=call_depth,
|
call_depth=call_depth,
|
||||||
start_at=time.perf_counter()
|
graph=graph,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
callbacks=callbacks
|
||||||
)
|
)
|
||||||
|
|
||||||
# init workflow run
|
# init workflow run
|
||||||
@ -130,11 +133,7 @@ class WorkflowEngineManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# run workflow
|
# run workflow
|
||||||
self._run_workflow(
|
rst = graph_engine.run()
|
||||||
graph_config=graph_config,
|
|
||||||
workflow_runtime_state=workflow_runtime_state,
|
|
||||||
callbacks=callbacks,
|
|
||||||
)
|
|
||||||
except WorkflowRunFailedError as e:
|
except WorkflowRunFailedError as e:
|
||||||
self._workflow_run_failed(
|
self._workflow_run_failed(
|
||||||
error=e.error,
|
error=e.error,
|
||||||
@ -151,6 +150,8 @@ class WorkflowEngineManager:
|
|||||||
callbacks=callbacks
|
callbacks=callbacks
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return rst
|
||||||
|
|
||||||
def _init_graph(self, graph_config: dict, root_node_id: Optional[str] = None) -> Optional[Graph]:
|
def _init_graph(self, graph_config: dict, root_node_id: Optional[str] = None) -> Optional[Graph]:
|
||||||
"""
|
"""
|
||||||
Initialize graph
|
Initialize graph
|
||||||
@ -259,30 +260,59 @@ class WorkflowEngineManager:
|
|||||||
|
|
||||||
sub_graph: Optional[Graph] = None
|
sub_graph: Optional[Graph] = None
|
||||||
target_node_type: NodeType = NodeType.value_of(target_node_config.get('data', {}).get('type'))
|
target_node_type: NodeType = NodeType.value_of(target_node_config.get('data', {}).get('type'))
|
||||||
if target_node_type and target_node_type in [IterationNode.node_type, NodeType.LOOP]:
|
target_node_cls = None
|
||||||
|
if target_node_type:
|
||||||
|
target_node_cls = node_classes.get(target_node_type)
|
||||||
|
if not target_node_cls:
|
||||||
|
raise Exception(f'Node class not found for node type: {target_node_type}')
|
||||||
|
|
||||||
|
if target_node_cls and issubclass(target_node_cls, IterableNodeMixin):
|
||||||
# find iteration/loop sub nodes that have no predecessor node
|
# find iteration/loop sub nodes that have no predecessor node
|
||||||
|
sub_graph_root_node_config = None
|
||||||
for root_node_config in root_node_configs:
|
for root_node_config in root_node_configs:
|
||||||
if root_node_config.get('parentId') == target_node_id:
|
if root_node_config.get('parentId') == target_node_id:
|
||||||
# create sub graph
|
sub_graph_root_node_config = root_node_config
|
||||||
sub_graph = Graph.init(
|
|
||||||
root_node_config=root_node_config
|
|
||||||
)
|
|
||||||
|
|
||||||
self._recursively_add_edges(
|
|
||||||
graph=sub_graph,
|
|
||||||
source_node_config=root_node_config,
|
|
||||||
edges_mapping=edges_mapping,
|
|
||||||
nodes_mapping=nodes_mapping,
|
|
||||||
root_node_configs=root_node_configs
|
|
||||||
)
|
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if sub_graph_root_node_config:
|
||||||
|
# create sub graph run condition
|
||||||
|
iterable_node_cls: IterableNodeMixin = cast(IterableNodeMixin, target_node_cls)
|
||||||
|
sub_graph_run_condition = RunCondition(
|
||||||
|
type='condition',
|
||||||
|
conditions=iterable_node_cls.get_conditions(
|
||||||
|
node_config=target_node_config
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# create sub graph
|
||||||
|
sub_graph = Graph.init(
|
||||||
|
root_node_config=sub_graph_root_node_config,
|
||||||
|
run_condition=sub_graph_run_condition
|
||||||
|
)
|
||||||
|
|
||||||
|
self._recursively_add_edges(
|
||||||
|
graph=sub_graph,
|
||||||
|
source_node_config=sub_graph_root_node_config,
|
||||||
|
edges_mapping=edges_mapping,
|
||||||
|
nodes_mapping=nodes_mapping,
|
||||||
|
root_node_configs=root_node_configs
|
||||||
|
)
|
||||||
|
|
||||||
|
# parse run condition
|
||||||
|
run_condition = None
|
||||||
|
if edge_config.get('sourceHandle'):
|
||||||
|
run_condition = RunCondition(
|
||||||
|
type='branch_identify',
|
||||||
|
branch_identify=edge_config.get('sourceHandle')
|
||||||
|
)
|
||||||
|
|
||||||
# add edge
|
# add edge
|
||||||
graph.add_edge(
|
graph.add_edge(
|
||||||
edge_config=edge_config,
|
edge_config=edge_config,
|
||||||
source_node_config=source_node_config,
|
source_node_config=source_node_config,
|
||||||
target_node_config=target_node_config,
|
target_node_config=target_node_config,
|
||||||
target_node_sub_graph=sub_graph,
|
target_node_sub_graph=sub_graph,
|
||||||
|
run_condition=run_condition
|
||||||
)
|
)
|
||||||
|
|
||||||
# recursively add edges
|
# recursively add edges
|
@ -1,5 +1,4 @@
|
|||||||
from core.workflow.graph import Graph
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
|
||||||
|
|
||||||
|
|
||||||
def test__init_graph():
|
def test__init_graph():
|
||||||
@ -217,18 +216,17 @@ def test__init_graph():
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
workflow_engine_manager = WorkflowEngineManager()
|
workflow_entry = WorkflowEntry()
|
||||||
graph = workflow_engine_manager._init_graph(
|
graph = workflow_entry._init_graph(
|
||||||
graph_config=graph_config
|
graph_config=graph_config
|
||||||
)
|
)
|
||||||
|
|
||||||
assert graph.root_node.id == "1717222650545"
|
assert graph.root_node.id == "1717222650545"
|
||||||
assert graph.root_node.source_edge_config is None
|
assert graph.root_node.source_edge_config is None
|
||||||
assert graph.root_node.target_edge_config is not None
|
|
||||||
assert graph.root_node.descendant_node_ids == ["1719481290322"]
|
assert graph.root_node.descendant_node_ids == ["1719481290322"]
|
||||||
|
|
||||||
assert graph.graph_nodes.get("1719481290322") is not None
|
assert graph.graph_nodes.get("1719481290322") is not None
|
||||||
assert len(graph.graph_nodes.get("1719481290322").descendant_node_ids) == 2
|
assert len(graph.graph_nodes.get("1719481290322").descendant_node_ids) == 2
|
||||||
|
|
||||||
assert graph.graph_nodes.get("llm").run_condition_callback is not None
|
assert graph.graph_nodes.get("llm").run_condition is not None
|
||||||
assert graph.graph_nodes.get("1719481315734").run_condition_callback is not None
|
assert graph.graph_nodes.get("1719481315734").run_condition is not None
|
Loading…
x
Reference in New Issue
Block a user