mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-17 15:35:58 +08:00
optimize graph
This commit is contained in:
parent
8375517ccd
commit
0f19b2a986
28
api/core/workflow/entities/workflow_runtime_state.py
Normal file
28
api/core/workflow/entities/workflow_runtime_state.py
Normal file
@ -0,0 +1,28 @@
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
class WorkflowRuntimeState(BaseModel):
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
workflow_id: str
|
||||
workflow_type: WorkflowType
|
||||
user_id: str
|
||||
user_from: UserFrom
|
||||
variable_pool: VariablePool
|
||||
invoke_from: InvokeFrom
|
||||
graph: Graph
|
||||
call_depth: int
|
||||
start_at: float
|
||||
|
||||
total_tokens: int = 0
|
||||
node_run_steps: int = 0
|
||||
|
||||
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)
|
0
api/core/workflow/graph_engine/__init__.py
Normal file
0
api/core/workflow/graph_engine/__init__.py
Normal file
@ -0,0 +1,24 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
|
||||
|
||||
class RunConditionHandler(ABC):
|
||||
def __init__(self, condition: RunCondition):
|
||||
self.condition = condition
|
||||
|
||||
@abstractmethod
|
||||
def check(self,
|
||||
graph_node: "GraphNode",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
predecessor_node_result: NodeRunResult) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_node: graph node
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param predecessor_node_result: predecessor node result
|
||||
:return: bool
|
||||
"""
|
||||
raise NotImplementedError
|
@ -0,0 +1,25 @@
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
|
||||
|
||||
class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
||||
|
||||
def check(self,
|
||||
graph_node: "GraphNode",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
predecessor_node_result: NodeRunResult) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_node: graph node
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param predecessor_node_result: predecessor node result
|
||||
:return: bool
|
||||
"""
|
||||
if not self.condition.branch_identify:
|
||||
raise Exception("Branch identify is required")
|
||||
|
||||
if not predecessor_node_result.edge_source_handle:
|
||||
return False
|
||||
|
||||
return self.condition.branch_identify == predecessor_node_result.edge_source_handle
|
@ -0,0 +1,31 @@
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
|
||||
|
||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||
def check(self,
|
||||
graph_node: "GraphNode",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
predecessor_node_result: NodeRunResult) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_node: graph node
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param predecessor_node_result: predecessor node result
|
||||
:return: bool
|
||||
"""
|
||||
if not self.condition.conditions:
|
||||
return True
|
||||
|
||||
# process condition
|
||||
condition_processor = ConditionProcessor()
|
||||
compare_result, _ = condition_processor.process(
|
||||
variable_pool=graph_runtime_state.variable_pool,
|
||||
logical_operator="and",
|
||||
conditions=self.condition.conditions
|
||||
)
|
||||
|
||||
return compare_result
|
||||
|
@ -0,0 +1,19 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
|
||||
|
||||
class ConditionManager:
|
||||
@staticmethod
|
||||
def get_condition_handler(run_condition: RunCondition) -> RunConditionHandler:
|
||||
"""
|
||||
Get condition handler
|
||||
|
||||
:param run_condition: run condition
|
||||
:return: condition handler
|
||||
"""
|
||||
if run_condition.type == "branch_identify":
|
||||
return BranchIdentifyRunConditionHandler(run_condition)
|
||||
else:
|
||||
return ConditionRunConditionHandlerHandler(run_condition)
|
0
api/core/workflow/graph_engine/entities/__init__.py
Normal file
0
api/core/workflow/graph_engine/entities/__init__.py
Normal file
@ -1,20 +1,10 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Literal, Optional
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class RunCondition(BaseModel):
|
||||
type: Literal["branch_identify", "condition"]
|
||||
"""condition type"""
|
||||
|
||||
branch_identify: Optional[str] = None
|
||||
"""branch identify, required when type is branch_identify"""
|
||||
|
||||
conditions: Optional[list[Condition]] = None
|
||||
"""conditions to run the node, required when type is condition"""
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
|
||||
|
||||
class GraphNode(BaseModel):
|
||||
@ -33,9 +23,6 @@ class GraphNode(BaseModel):
|
||||
run_condition: Optional[RunCondition] = None
|
||||
"""condition to run the node"""
|
||||
|
||||
run_condition_callback: Optional[Callable] = Field(None, exclude=True)
|
||||
"""condition function check if the node can be executed, translated from run_conditions, not serialized"""
|
||||
|
||||
node_config: dict
|
||||
"""original node config"""
|
||||
|
||||
@ -48,9 +35,22 @@ class GraphNode(BaseModel):
|
||||
def add_child(self, node_id: str) -> None:
|
||||
self.descendant_node_ids.append(node_id)
|
||||
|
||||
def get_run_condition_handler(self) -> Optional[RunConditionHandler]:
|
||||
"""
|
||||
Get run condition handler
|
||||
|
||||
:return: run condition handler
|
||||
"""
|
||||
if not self.run_condition:
|
||||
return None
|
||||
|
||||
return ConditionManager.get_condition_handler(
|
||||
run_condition=self.run_condition
|
||||
)
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
graph_nodes: dict[str, GraphNode] = {}
|
||||
graph_nodes: dict[str, GraphNode] = Field(default_factory=dict)
|
||||
"""graph nodes"""
|
||||
|
||||
root_node: GraphNode
|
||||
@ -65,8 +65,12 @@ class Graph(BaseModel):
|
||||
:param run_condition: run condition when root node parent is iteration/loop
|
||||
:return: graph
|
||||
"""
|
||||
node_id = root_node_config.get('id')
|
||||
if not node_id:
|
||||
raise ValueError("Graph root node id is required")
|
||||
|
||||
root_node = GraphNode(
|
||||
id=root_node_config.get('id'),
|
||||
id=node_id,
|
||||
parent_id=root_node_config.get('parentId'),
|
||||
node_config=root_node_config,
|
||||
run_condition=run_condition
|
||||
@ -74,15 +78,14 @@ class Graph(BaseModel):
|
||||
|
||||
graph = cls(root_node=root_node)
|
||||
|
||||
# TODO parse run_condition to run_condition_callback
|
||||
|
||||
graph.add_graph_node(graph.root_node)
|
||||
graph._add_graph_node(graph.root_node)
|
||||
return graph
|
||||
|
||||
def add_edge(self, edge_config: dict,
|
||||
source_node_config: dict,
|
||||
target_node_config: dict,
|
||||
target_node_sub_graph: Optional["Graph"] = None) -> None:
|
||||
target_node_sub_graph: Optional["Graph"] = None,
|
||||
run_condition: Optional[RunCondition] = None) -> None:
|
||||
"""
|
||||
Add edge to the graph
|
||||
|
||||
@ -90,6 +93,7 @@ class Graph(BaseModel):
|
||||
:param source_node_config: source node config
|
||||
:param target_node_config: target node config
|
||||
:param target_node_sub_graph: sub graph
|
||||
:param run_condition: run condition
|
||||
"""
|
||||
source_node_id = source_node_config.get('id')
|
||||
if not source_node_id:
|
||||
@ -105,48 +109,25 @@ class Graph(BaseModel):
|
||||
source_node = self.graph_nodes[source_node_id]
|
||||
source_node.add_child(target_node_id)
|
||||
|
||||
# if run_conditions:
|
||||
# run_condition_callback = lambda: all()
|
||||
|
||||
|
||||
if target_node_id not in self.graph_nodes:
|
||||
run_condition = None # todo
|
||||
run_condition_callback = None # todo
|
||||
|
||||
target_graph_node = GraphNode(
|
||||
id=target_node_id,
|
||||
parent_id=source_node_config.get('parentId'),
|
||||
predecessor_node_id=source_node_id,
|
||||
node_config=target_node_config,
|
||||
run_condition=run_condition,
|
||||
run_condition_callback=run_condition_callback,
|
||||
source_edge_config=edge_config,
|
||||
sub_graph=target_node_sub_graph
|
||||
)
|
||||
|
||||
self.add_graph_node(target_graph_node)
|
||||
self._add_graph_node(target_graph_node)
|
||||
else:
|
||||
target_node = self.graph_nodes[target_node_id]
|
||||
target_node.predecessor_node_id = source_node_id
|
||||
target_node.run_conditions = run_conditions
|
||||
target_node.run_condition_callback = run_condition_callback
|
||||
target_node.run_condition = run_condition
|
||||
target_node.source_edge_config = edge_config
|
||||
target_node.sub_graph = target_node_sub_graph
|
||||
|
||||
def add_graph_node(self, graph_node: GraphNode) -> None:
|
||||
"""
|
||||
Add graph node to the graph
|
||||
|
||||
:param graph_node: graph node
|
||||
"""
|
||||
if graph_node.id in self.graph_nodes:
|
||||
return
|
||||
|
||||
if len(self.graph_nodes) == 0:
|
||||
self.root_node = graph_node
|
||||
|
||||
self.graph_nodes[graph_node.id] = graph_node
|
||||
|
||||
def get_root_node(self) -> Optional[GraphNode]:
|
||||
"""
|
||||
Get root node of the graph
|
||||
@ -169,14 +150,28 @@ class Graph(BaseModel):
|
||||
if not graph_node.descendant_node_ids:
|
||||
return None
|
||||
|
||||
descendants_graph = Graph()
|
||||
descendants_graph.add_graph_node(graph_node)
|
||||
descendants_graph = Graph(root_node=graph_node)
|
||||
descendants_graph._add_graph_node(graph_node)
|
||||
|
||||
for child_node_id in graph_node.descendant_node_ids:
|
||||
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
|
||||
|
||||
return descendants_graph
|
||||
|
||||
def _add_graph_node(self, graph_node: GraphNode) -> None:
|
||||
"""
|
||||
Add graph node to the graph
|
||||
|
||||
:param graph_node: graph node
|
||||
"""
|
||||
if graph_node.id in self.graph_nodes:
|
||||
return
|
||||
|
||||
if len(self.graph_nodes) == 0:
|
||||
self.root_node = graph_node
|
||||
|
||||
self.graph_nodes[graph_node.id] = graph_node
|
||||
|
||||
def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None:
|
||||
"""
|
||||
Add descendants graph nodes
|
||||
@ -188,7 +183,7 @@ class Graph(BaseModel):
|
||||
return
|
||||
|
||||
graph_node = self.graph_nodes[node_id]
|
||||
descendants_graph.add_graph_node(graph_node)
|
||||
descendants_graph._add_graph_node(graph_node)
|
||||
|
||||
for child_node_id in graph_node.descendant_node_ids:
|
||||
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
|
@ -0,0 +1,26 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
user_id: str
|
||||
user_from: UserFrom
|
||||
invoke_from: InvokeFrom
|
||||
call_depth: int
|
||||
|
||||
graph: Graph
|
||||
variable_pool: VariablePool
|
||||
|
||||
start_at: Optional[float] = None
|
||||
total_tokens: int = 0
|
||||
node_run_steps: int = 0
|
||||
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)
|
16
api/core/workflow/graph_engine/entities/run_condition.py
Normal file
16
api/core/workflow/graph_engine/entities/run_condition.py
Normal file
@ -0,0 +1,16 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class RunCondition(BaseModel):
|
||||
type: Literal["branch_identify", "condition"]
|
||||
"""condition type"""
|
||||
|
||||
branch_identify: Optional[str] = None
|
||||
"""branch identify like: sourceHandle, required when type is branch_identify"""
|
||||
|
||||
conditions: Optional[list[Condition]] = None
|
||||
"""conditions to run the node, required when type is condition"""
|
@ -1,55 +1,11 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph import Graph, GraphNode
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
class RuntimeNode(BaseModel):
|
||||
class Status(Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""random id for current runtime node"""
|
||||
|
||||
graph_node: GraphNode
|
||||
"""graph node"""
|
||||
|
||||
node_run_result: Optional[NodeRunResult] = None
|
||||
"""node run result"""
|
||||
|
||||
status: Status = Status.PENDING
|
||||
"""node status"""
|
||||
|
||||
start_at: Optional[datetime] = None
|
||||
"""start time"""
|
||||
|
||||
paused_at: Optional[datetime] = None
|
||||
"""paused time"""
|
||||
|
||||
finished_at: Optional[datetime] = None
|
||||
"""finished time"""
|
||||
|
||||
failed_reason: Optional[str] = None
|
||||
"""failed reason"""
|
||||
|
||||
paused_by: Optional[str] = None
|
||||
"""paused by"""
|
||||
|
||||
predecessor_runtime_node_id: Optional[str] = None
|
||||
"""predecessor runtime node id"""
|
||||
from core.workflow.graph_engine.entities.runtime_node import RuntimeNode
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class RuntimeGraph(BaseModel):
|
||||
@ -80,22 +36,3 @@ class RuntimeGraph(BaseModel):
|
||||
runtime_node.status = RuntimeNode.Status.PAUSED
|
||||
runtime_node.paused_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
runtime_node.paused_by = paused_by
|
||||
|
||||
|
||||
class WorkflowRuntimeState(BaseModel):
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
workflow_id: str
|
||||
workflow_type: WorkflowType
|
||||
user_id: str
|
||||
user_from: UserFrom
|
||||
variable_pool: VariablePool
|
||||
invoke_from: InvokeFrom
|
||||
graph: Graph
|
||||
call_depth: int
|
||||
start_at: float
|
||||
|
||||
total_tokens: int = 0
|
||||
node_run_steps: int = 0
|
||||
|
||||
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)
|
48
api/core/workflow/graph_engine/entities/runtime_node.py
Normal file
48
api/core/workflow/graph_engine/entities/runtime_node.py
Normal file
@ -0,0 +1,48 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.graph_engine.entities.graph import GraphNode
|
||||
|
||||
|
||||
class RuntimeNode(BaseModel):
|
||||
class Status(Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""random id for current runtime node"""
|
||||
|
||||
graph_node: GraphNode
|
||||
"""graph node"""
|
||||
|
||||
node_run_result: Optional[NodeRunResult] = None
|
||||
"""node run result"""
|
||||
|
||||
status: Status = Status.PENDING
|
||||
"""node status"""
|
||||
|
||||
start_at: Optional[datetime] = None
|
||||
"""start time"""
|
||||
|
||||
paused_at: Optional[datetime] = None
|
||||
"""paused time"""
|
||||
|
||||
finished_at: Optional[datetime] = None
|
||||
"""finished time"""
|
||||
|
||||
failed_reason: Optional[str] = None
|
||||
"""failed reason"""
|
||||
|
||||
paused_by: Optional[str] = None
|
||||
"""paused by"""
|
||||
|
||||
predecessor_runtime_node_id: Optional[str] = None
|
||||
"""predecessor runtime node id"""
|
45
api/core/workflow/graph_engine/graph_engine.py
Normal file
45
api/core/workflow/graph_engine/graph_engine.py
Normal file
@ -0,0 +1,45 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from flask import current_app
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
|
||||
|
||||
class GraphEngine:
|
||||
def __init__(self, tenant_id: str,
|
||||
app_id: str,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
call_depth: int,
|
||||
graph: Graph,
|
||||
variable_pool: VariablePool,
|
||||
callbacks: list[BaseWorkflowCallback]) -> None:
|
||||
self.graph_runtime_state = GraphRuntimeState(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool
|
||||
)
|
||||
|
||||
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
|
||||
self.max_execution_steps = cast(int, max_execution_steps)
|
||||
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
|
||||
self.max_execution_time = cast(int, max_execution_time)
|
||||
|
||||
self.callbacks = callbacks
|
||||
|
||||
def run(self) -> Generator:
|
||||
self.graph_runtime_state.start_at = time.perf_counter()
|
||||
pass
|
@ -7,6 +7,7 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.iterable_node import IterableNodeMixin
|
||||
|
||||
|
||||
class UserFrom(Enum):
|
||||
@ -149,7 +150,8 @@ class BaseNode(ABC):
|
||||
"""
|
||||
return self._node_type
|
||||
|
||||
class BaseIterationNode(BaseNode):
|
||||
|
||||
class BaseIterationNode(BaseNode, IterableNodeMixin):
|
||||
@abstractmethod
|
||||
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
|
||||
"""
|
||||
|
14
api/core/workflow/nodes/iterable_node.py
Normal file
14
api/core/workflow/nodes/iterable_node.py
Normal file
@ -0,0 +1,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class IterableNodeMixin(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
|
||||
"""
|
||||
Get conditions.
|
||||
"""
|
||||
raise NotImplementedError
|
@ -1,4 +1,4 @@
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState
|
||||
@ -6,6 +6,7 @@ from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseIterationNode
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@ -117,3 +118,19 @@ class IterationNode(BaseIterationNode):
|
||||
return {
|
||||
'input_selector': node_data.iterator_selector,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
|
||||
"""
|
||||
Get conditions.
|
||||
"""
|
||||
node_id = node_config.get('id')
|
||||
if not node_id:
|
||||
return []
|
||||
|
||||
return [Condition(
|
||||
variable_selector=[node_id, 'index'],
|
||||
comparison_operator="≤",
|
||||
value_type="value_selector",
|
||||
value_selector=node_config.get('data', {}).get('iterator_selector')
|
||||
)]
|
||||
|
@ -1,7 +1,10 @@
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseIterationNode
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class LoopNode(BaseIterationNode):
|
||||
@ -18,3 +21,21 @@ class LoopNode(BaseIterationNode):
|
||||
"""
|
||||
Get next iteration start node id based on the graph.
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
|
||||
"""
|
||||
Get conditions.
|
||||
"""
|
||||
node_id = node_config.get('id')
|
||||
if not node_id:
|
||||
return []
|
||||
|
||||
# TODO waiting for implementation
|
||||
return [Condition(
|
||||
variable_selector=[node_id, 'index'],
|
||||
comparison_operator="≤",
|
||||
value_type="value_selector",
|
||||
value_selector=[]
|
||||
)]
|
||||
|
@ -14,4 +14,6 @@ class Condition(BaseModel):
|
||||
# for number
|
||||
"=", "≠", ">", "<", "≥", "≤", "null", "not null"
|
||||
]
|
||||
value_type: Literal["string", "value_selector"] = "string"
|
||||
value: Optional[str] = None
|
||||
value_selector: Optional[list[str]] = None
|
||||
|
@ -1,22 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState
|
||||
from core.workflow.graph import GraphNode
|
||||
|
||||
|
||||
def source_handle_condition_func(workflow_runtime_state: WorkflowRuntimeState,
|
||||
graph_node: GraphNode,
|
||||
# TODO cycle_state optional
|
||||
predecessor_node_run_result: Optional[NodeRunResult] = None) -> bool:
|
||||
if not graph_node.source_edge_config:
|
||||
return False
|
||||
|
||||
if not graph_node.source_edge_config.get('sourceHandle'):
|
||||
return True
|
||||
|
||||
source_handle = predecessor_node_run_result.edge_source_handle \
|
||||
if predecessor_node_run_result else None
|
||||
|
||||
return (source_handle is not None
|
||||
and graph_node.source_edge_config.get('sourceHandle') == source_handle)
|
@ -24,7 +24,12 @@ class ConditionProcessor:
|
||||
variable_selector=condition.variable_selector
|
||||
)
|
||||
|
||||
expected_value = condition.value
|
||||
if condition.value_type == "value_selector":
|
||||
expected_value = variable_pool.get_variable_value(
|
||||
variable_selector=condition.value_selector
|
||||
)
|
||||
else:
|
||||
expected_value = condition.value
|
||||
|
||||
input_conditions.append({
|
||||
"actual_value": actual_value,
|
||||
@ -208,7 +213,7 @@ class ConditionProcessor:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||
"""
|
||||
Assert equal
|
||||
:param actual_value: actual value
|
||||
@ -230,7 +235,7 @@ class ConditionProcessor:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||
"""
|
||||
Assert not equal
|
||||
:param actual_value: actual value
|
||||
@ -252,7 +257,7 @@ class ConditionProcessor:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||
"""
|
||||
Assert greater than
|
||||
:param actual_value: actual value
|
||||
@ -274,7 +279,7 @@ class ConditionProcessor:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||
"""
|
||||
Assert less than
|
||||
:param actual_value: actual value
|
||||
@ -296,7 +301,7 @@ class ConditionProcessor:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||
"""
|
||||
Assert greater than or equal
|
||||
:param actual_value: actual value
|
||||
@ -318,7 +323,7 @@ class ConditionProcessor:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||
"""
|
||||
Assert less than or equal
|
||||
:param actual_value: actual value
|
||||
|
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from flask import current_app
|
||||
@ -12,15 +13,18 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
|
||||
from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState
|
||||
from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from core.workflow.nodes.iterable_node import IterableNodeMixin
|
||||
from core.workflow.nodes.iteration.entities import IterationState
|
||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
@ -36,7 +40,6 @@ from extensions.ext_database import db
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowType,
|
||||
)
|
||||
|
||||
node_classes = {
|
||||
@ -60,7 +63,7 @@ node_classes = {
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowEngineManager:
|
||||
class WorkflowEntry:
|
||||
def run_workflow(self, workflow: Workflow,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
@ -69,7 +72,7 @@ class WorkflowEngineManager:
|
||||
user_inputs: dict,
|
||||
system_inputs: dict[SystemVariable, Any],
|
||||
call_depth: int = 0,
|
||||
variable_pool: Optional[VariablePool] = None) -> None:
|
||||
variable_pool: Optional[VariablePool] = None) -> Generator:
|
||||
"""
|
||||
:param workflow: Workflow instance
|
||||
:param user_id: user id
|
||||
@ -102,25 +105,25 @@ class WorkflowEngineManager:
|
||||
user_inputs=user_inputs
|
||||
)
|
||||
|
||||
# fetch max call depth
|
||||
workflow_call_max_depth = current_app.config.get("WORKFLOW_CALL_MAX_DEPTH")
|
||||
workflow_call_max_depth = cast(int, workflow_call_max_depth)
|
||||
if call_depth > workflow_call_max_depth:
|
||||
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
|
||||
# init graph
|
||||
graph = self._init_graph(
|
||||
graph_config=graph_config
|
||||
)
|
||||
|
||||
# init workflow runtime state
|
||||
workflow_runtime_state = WorkflowRuntimeState(
|
||||
if not graph:
|
||||
raise ValueError('graph not found in workflow')
|
||||
|
||||
# init workflow run state
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType.value_of(workflow.type),
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
variable_pool=variable_pool,
|
||||
invoke_from=invoke_from,
|
||||
graph=graph_config,
|
||||
call_depth=call_depth,
|
||||
start_at=time.perf_counter()
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
# init workflow run
|
||||
@ -130,11 +133,7 @@ class WorkflowEngineManager:
|
||||
|
||||
try:
|
||||
# run workflow
|
||||
self._run_workflow(
|
||||
graph_config=graph_config,
|
||||
workflow_runtime_state=workflow_runtime_state,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
rst = graph_engine.run()
|
||||
except WorkflowRunFailedError as e:
|
||||
self._workflow_run_failed(
|
||||
error=e.error,
|
||||
@ -151,6 +150,8 @@ class WorkflowEngineManager:
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
return rst
|
||||
|
||||
def _init_graph(self, graph_config: dict, root_node_id: Optional[str] = None) -> Optional[Graph]:
|
||||
"""
|
||||
Initialize graph
|
||||
@ -259,30 +260,59 @@ class WorkflowEngineManager:
|
||||
|
||||
sub_graph: Optional[Graph] = None
|
||||
target_node_type: NodeType = NodeType.value_of(target_node_config.get('data', {}).get('type'))
|
||||
if target_node_type and target_node_type in [IterationNode.node_type, NodeType.LOOP]:
|
||||
target_node_cls = None
|
||||
if target_node_type:
|
||||
target_node_cls = node_classes.get(target_node_type)
|
||||
if not target_node_cls:
|
||||
raise Exception(f'Node class not found for node type: {target_node_type}')
|
||||
|
||||
if target_node_cls and issubclass(target_node_cls, IterableNodeMixin):
|
||||
# find iteration/loop sub nodes that have no predecessor node
|
||||
sub_graph_root_node_config = None
|
||||
for root_node_config in root_node_configs:
|
||||
if root_node_config.get('parentId') == target_node_id:
|
||||
# create sub graph
|
||||
sub_graph = Graph.init(
|
||||
root_node_config=root_node_config
|
||||
)
|
||||
|
||||
self._recursively_add_edges(
|
||||
graph=sub_graph,
|
||||
source_node_config=root_node_config,
|
||||
edges_mapping=edges_mapping,
|
||||
nodes_mapping=nodes_mapping,
|
||||
root_node_configs=root_node_configs
|
||||
)
|
||||
sub_graph_root_node_config = root_node_config
|
||||
break
|
||||
|
||||
if sub_graph_root_node_config:
|
||||
# create sub graph run condition
|
||||
iterable_node_cls: IterableNodeMixin = cast(IterableNodeMixin, target_node_cls)
|
||||
sub_graph_run_condition = RunCondition(
|
||||
type='condition',
|
||||
conditions=iterable_node_cls.get_conditions(
|
||||
node_config=target_node_config
|
||||
)
|
||||
)
|
||||
|
||||
# create sub graph
|
||||
sub_graph = Graph.init(
|
||||
root_node_config=sub_graph_root_node_config,
|
||||
run_condition=sub_graph_run_condition
|
||||
)
|
||||
|
||||
self._recursively_add_edges(
|
||||
graph=sub_graph,
|
||||
source_node_config=sub_graph_root_node_config,
|
||||
edges_mapping=edges_mapping,
|
||||
nodes_mapping=nodes_mapping,
|
||||
root_node_configs=root_node_configs
|
||||
)
|
||||
|
||||
# parse run condition
|
||||
run_condition = None
|
||||
if edge_config.get('sourceHandle'):
|
||||
run_condition = RunCondition(
|
||||
type='branch_identify',
|
||||
branch_identify=edge_config.get('sourceHandle')
|
||||
)
|
||||
|
||||
# add edge
|
||||
graph.add_edge(
|
||||
edge_config=edge_config,
|
||||
source_node_config=source_node_config,
|
||||
target_node_config=target_node_config,
|
||||
target_node_sub_graph=sub_graph,
|
||||
run_condition=run_condition
|
||||
)
|
||||
|
||||
# recursively add edges
|
@ -1,5 +1,4 @@
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
|
||||
|
||||
def test__init_graph():
|
||||
@ -217,18 +216,17 @@ def test__init_graph():
|
||||
],
|
||||
}
|
||||
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
graph = workflow_engine_manager._init_graph(
|
||||
workflow_entry = WorkflowEntry()
|
||||
graph = workflow_entry._init_graph(
|
||||
graph_config=graph_config
|
||||
)
|
||||
|
||||
assert graph.root_node.id == "1717222650545"
|
||||
assert graph.root_node.source_edge_config is None
|
||||
assert graph.root_node.target_edge_config is not None
|
||||
assert graph.root_node.descendant_node_ids == ["1719481290322"]
|
||||
|
||||
assert graph.graph_nodes.get("1719481290322") is not None
|
||||
assert len(graph.graph_nodes.get("1719481290322").descendant_node_ids) == 2
|
||||
|
||||
assert graph.graph_nodes.get("llm").run_condition_callback is not None
|
||||
assert graph.graph_nodes.get("1719481315734").run_condition_callback is not None
|
||||
assert graph.graph_nodes.get("llm").run_condition is not None
|
||||
assert graph.graph_nodes.get("1719481315734").run_condition is not None
|
Loading…
x
Reference in New Issue
Block a user