optimize graph

This commit is contained in:
takatost 2024-07-02 21:53:41 +08:00
parent 8375517ccd
commit 0f19b2a986
23 changed files with 454 additions and 193 deletions

View File

@ -0,0 +1,28 @@
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph
from core.workflow.nodes.base_node import UserFrom
from models.workflow import WorkflowType
class WorkflowRuntimeState(BaseModel):
tenant_id: str
app_id: str
workflow_id: str
workflow_type: WorkflowType
user_id: str
user_from: UserFrom
variable_pool: VariablePool
invoke_from: InvokeFrom
graph: Graph
call_depth: int
start_at: float
total_tokens: int = 0
node_run_steps: int = 0
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)

View File

@ -0,0 +1,24 @@
from abc import ABC, abstractmethod
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.entities.run_condition import RunCondition
class RunConditionHandler(ABC):
def __init__(self, condition: RunCondition):
self.condition = condition
@abstractmethod
def check(self,
graph_node: "GraphNode",
graph_runtime_state: "GraphRuntimeState",
predecessor_node_result: NodeRunResult) -> bool:
"""
Check if the condition can be executed
:param graph_node: graph node
:param graph_runtime_state: graph runtime state
:param predecessor_node_result: predecessor node result
:return: bool
"""
raise NotImplementedError

View File

@ -0,0 +1,25 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
class BranchIdentifyRunConditionHandler(RunConditionHandler):
def check(self,
graph_node: "GraphNode",
graph_runtime_state: "GraphRuntimeState",
predecessor_node_result: NodeRunResult) -> bool:
"""
Check if the condition can be executed
:param graph_node: graph node
:param graph_runtime_state: graph runtime state
:param predecessor_node_result: predecessor node result
:return: bool
"""
if not self.condition.branch_identify:
raise Exception("Branch identify is required")
if not predecessor_node_result.edge_source_handle:
return False
return self.condition.branch_identify == predecessor_node_result.edge_source_handle

View File

@ -0,0 +1,31 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.utils.condition.processor import ConditionProcessor
class ConditionRunConditionHandlerHandler(RunConditionHandler):
def check(self,
graph_node: "GraphNode",
graph_runtime_state: "GraphRuntimeState",
predecessor_node_result: NodeRunResult) -> bool:
"""
Check if the condition can be executed
:param graph_node: graph node
:param graph_runtime_state: graph runtime state
:param predecessor_node_result: predecessor node result
:return: bool
"""
if not self.condition.conditions:
return True
# process condition
condition_processor = ConditionProcessor()
compare_result, _ = condition_processor.process(
variable_pool=graph_runtime_state.variable_pool,
logical_operator="and",
conditions=self.condition.conditions
)
return compare_result

View File

@ -0,0 +1,19 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
from core.workflow.graph_engine.entities.run_condition import RunCondition
class ConditionManager:
@staticmethod
def get_condition_handler(run_condition: RunCondition) -> RunConditionHandler:
"""
Get condition handler
:param run_condition: run condition
:return: condition handler
"""
if run_condition.type == "branch_identify":
return BranchIdentifyRunConditionHandler(run_condition)
else:
return ConditionRunConditionHandlerHandler(run_condition)

View File

@ -1,20 +1,10 @@
from collections.abc import Callable from typing import Optional
from typing import Literal, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.workflow.utils.condition.entities import Condition from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.run_condition import RunCondition
class RunCondition(BaseModel):
type: Literal["branch_identify", "condition"]
"""condition type"""
branch_identify: Optional[str] = None
"""branch identify, required when type is branch_identify"""
conditions: Optional[list[Condition]] = None
"""conditions to run the node, required when type is condition"""
class GraphNode(BaseModel): class GraphNode(BaseModel):
@ -33,9 +23,6 @@ class GraphNode(BaseModel):
run_condition: Optional[RunCondition] = None run_condition: Optional[RunCondition] = None
"""condition to run the node""" """condition to run the node"""
run_condition_callback: Optional[Callable] = Field(None, exclude=True)
"""condition function check if the node can be executed, translated from run_conditions, not serialized"""
node_config: dict node_config: dict
"""original node config""" """original node config"""
@ -48,9 +35,22 @@ class GraphNode(BaseModel):
def add_child(self, node_id: str) -> None: def add_child(self, node_id: str) -> None:
self.descendant_node_ids.append(node_id) self.descendant_node_ids.append(node_id)
def get_run_condition_handler(self) -> Optional[RunConditionHandler]:
"""
Get run condition handler
:return: run condition handler
"""
if not self.run_condition:
return None
return ConditionManager.get_condition_handler(
run_condition=self.run_condition
)
class Graph(BaseModel): class Graph(BaseModel):
graph_nodes: dict[str, GraphNode] = {} graph_nodes: dict[str, GraphNode] = Field(default_factory=dict)
"""graph nodes""" """graph nodes"""
root_node: GraphNode root_node: GraphNode
@ -65,8 +65,12 @@ class Graph(BaseModel):
:param run_condition: run condition when root node parent is iteration/loop :param run_condition: run condition when root node parent is iteration/loop
:return: graph :return: graph
""" """
node_id = root_node_config.get('id')
if not node_id:
raise ValueError("Graph root node id is required")
root_node = GraphNode( root_node = GraphNode(
id=root_node_config.get('id'), id=node_id,
parent_id=root_node_config.get('parentId'), parent_id=root_node_config.get('parentId'),
node_config=root_node_config, node_config=root_node_config,
run_condition=run_condition run_condition=run_condition
@ -74,15 +78,14 @@ class Graph(BaseModel):
graph = cls(root_node=root_node) graph = cls(root_node=root_node)
# TODO parse run_condition to run_condition_callback graph._add_graph_node(graph.root_node)
graph.add_graph_node(graph.root_node)
return graph return graph
def add_edge(self, edge_config: dict, def add_edge(self, edge_config: dict,
source_node_config: dict, source_node_config: dict,
target_node_config: dict, target_node_config: dict,
target_node_sub_graph: Optional["Graph"] = None) -> None: target_node_sub_graph: Optional["Graph"] = None,
run_condition: Optional[RunCondition] = None) -> None:
""" """
Add edge to the graph Add edge to the graph
@ -90,6 +93,7 @@ class Graph(BaseModel):
:param source_node_config: source node config :param source_node_config: source node config
:param target_node_config: target node config :param target_node_config: target node config
:param target_node_sub_graph: sub graph :param target_node_sub_graph: sub graph
:param run_condition: run condition
""" """
source_node_id = source_node_config.get('id') source_node_id = source_node_config.get('id')
if not source_node_id: if not source_node_id:
@ -105,48 +109,25 @@ class Graph(BaseModel):
source_node = self.graph_nodes[source_node_id] source_node = self.graph_nodes[source_node_id]
source_node.add_child(target_node_id) source_node.add_child(target_node_id)
# if run_conditions:
# run_condition_callback = lambda: all()
if target_node_id not in self.graph_nodes: if target_node_id not in self.graph_nodes:
run_condition = None # todo
run_condition_callback = None # todo
target_graph_node = GraphNode( target_graph_node = GraphNode(
id=target_node_id, id=target_node_id,
parent_id=source_node_config.get('parentId'), parent_id=source_node_config.get('parentId'),
predecessor_node_id=source_node_id, predecessor_node_id=source_node_id,
node_config=target_node_config, node_config=target_node_config,
run_condition=run_condition, run_condition=run_condition,
run_condition_callback=run_condition_callback,
source_edge_config=edge_config, source_edge_config=edge_config,
sub_graph=target_node_sub_graph sub_graph=target_node_sub_graph
) )
self.add_graph_node(target_graph_node) self._add_graph_node(target_graph_node)
else: else:
target_node = self.graph_nodes[target_node_id] target_node = self.graph_nodes[target_node_id]
target_node.predecessor_node_id = source_node_id target_node.predecessor_node_id = source_node_id
target_node.run_conditions = run_conditions target_node.run_condition = run_condition
target_node.run_condition_callback = run_condition_callback
target_node.source_edge_config = edge_config target_node.source_edge_config = edge_config
target_node.sub_graph = target_node_sub_graph target_node.sub_graph = target_node_sub_graph
def add_graph_node(self, graph_node: GraphNode) -> None:
"""
Add graph node to the graph
:param graph_node: graph node
"""
if graph_node.id in self.graph_nodes:
return
if len(self.graph_nodes) == 0:
self.root_node = graph_node
self.graph_nodes[graph_node.id] = graph_node
def get_root_node(self) -> Optional[GraphNode]: def get_root_node(self) -> Optional[GraphNode]:
""" """
Get root node of the graph Get root node of the graph
@ -169,14 +150,28 @@ class Graph(BaseModel):
if not graph_node.descendant_node_ids: if not graph_node.descendant_node_ids:
return None return None
descendants_graph = Graph() descendants_graph = Graph(root_node=graph_node)
descendants_graph.add_graph_node(graph_node) descendants_graph._add_graph_node(graph_node)
for child_node_id in graph_node.descendant_node_ids: for child_node_id in graph_node.descendant_node_ids:
self._add_descendants_graph_nodes(descendants_graph, child_node_id) self._add_descendants_graph_nodes(descendants_graph, child_node_id)
return descendants_graph return descendants_graph
def _add_graph_node(self, graph_node: GraphNode) -> None:
"""
Add graph node to the graph
:param graph_node: graph node
"""
if graph_node.id in self.graph_nodes:
return
if len(self.graph_nodes) == 0:
self.root_node = graph_node
self.graph_nodes[graph_node.id] = graph_node
def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None: def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None:
""" """
Add descendants graph nodes Add descendants graph nodes
@ -188,7 +183,7 @@ class Graph(BaseModel):
return return
graph_node = self.graph_nodes[node_id] graph_node = self.graph_nodes[node_id]
descendants_graph.add_graph_node(graph_node) descendants_graph._add_graph_node(graph_node)
for child_node_id in graph_node.descendant_node_ids: for child_node_id in graph_node.descendant_node_ids:
self._add_descendants_graph_nodes(descendants_graph, child_node_id) self._add_descendants_graph_nodes(descendants_graph, child_node_id)

View File

@ -0,0 +1,26 @@
from typing import Optional
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph
from core.workflow.nodes.base_node import UserFrom
class GraphRuntimeState(BaseModel):
tenant_id: str
app_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
call_depth: int
graph: Graph
variable_pool: VariablePool
start_at: Optional[float] = None
total_tokens: int = 0
node_run_steps: int = 0
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)

View File

@ -0,0 +1,16 @@
from typing import Literal, Optional
from pydantic import BaseModel
from core.workflow.utils.condition.entities import Condition
class RunCondition(BaseModel):
type: Literal["branch_identify", "condition"]
"""condition type"""
branch_identify: Optional[str] = None
"""branch identify like: sourceHandle, required when type is branch_identify"""
conditions: Optional[list[Condition]] = None
"""conditions to run the node, required when type is condition"""

View File

@ -1,55 +1,11 @@
import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from enum import Enum
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.runtime_node import RuntimeNode
from core.workflow.graph import Graph, GraphNode from models.workflow import WorkflowNodeExecutionStatus
from core.workflow.nodes.base_node import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
class RuntimeNode(BaseModel):
class Status(Enum):
PENDING = "pending"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""random id for current runtime node"""
graph_node: GraphNode
"""graph node"""
node_run_result: Optional[NodeRunResult] = None
"""node run result"""
status: Status = Status.PENDING
"""node status"""
start_at: Optional[datetime] = None
"""start time"""
paused_at: Optional[datetime] = None
"""paused time"""
finished_at: Optional[datetime] = None
"""finished time"""
failed_reason: Optional[str] = None
"""failed reason"""
paused_by: Optional[str] = None
"""paused by"""
predecessor_runtime_node_id: Optional[str] = None
"""predecessor runtime node id"""
class RuntimeGraph(BaseModel): class RuntimeGraph(BaseModel):
@ -80,22 +36,3 @@ class RuntimeGraph(BaseModel):
runtime_node.status = RuntimeNode.Status.PAUSED runtime_node.status = RuntimeNode.Status.PAUSED
runtime_node.paused_at = datetime.now(timezone.utc).replace(tzinfo=None) runtime_node.paused_at = datetime.now(timezone.utc).replace(tzinfo=None)
runtime_node.paused_by = paused_by runtime_node.paused_by = paused_by
class WorkflowRuntimeState(BaseModel):
tenant_id: str
app_id: str
workflow_id: str
workflow_type: WorkflowType
user_id: str
user_from: UserFrom
variable_pool: VariablePool
invoke_from: InvokeFrom
graph: Graph
call_depth: int
start_at: float
total_tokens: int = 0
node_run_steps: int = 0
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)

View File

@ -0,0 +1,48 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.entities.graph import GraphNode
class RuntimeNode(BaseModel):
class Status(Enum):
PENDING = "pending"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""random id for current runtime node"""
graph_node: GraphNode
"""graph node"""
node_run_result: Optional[NodeRunResult] = None
"""node run result"""
status: Status = Status.PENDING
"""node status"""
start_at: Optional[datetime] = None
"""start time"""
paused_at: Optional[datetime] = None
"""paused time"""
finished_at: Optional[datetime] = None
"""finished time"""
failed_reason: Optional[str] = None
"""failed reason"""
paused_by: Optional[str] = None
"""paused by"""
predecessor_runtime_node_id: Optional[str] = None
"""predecessor runtime node id"""

View File

@ -0,0 +1,45 @@
import time
from collections.abc import Generator
from typing import cast
from flask import current_app
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.base_node import UserFrom
class GraphEngine:
def __init__(self, tenant_id: str,
app_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int,
graph: Graph,
variable_pool: VariablePool,
callbacks: list[BaseWorkflowCallback]) -> None:
self.graph_runtime_state = GraphRuntimeState(
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth,
graph=graph,
variable_pool=variable_pool
)
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
self.max_execution_steps = cast(int, max_execution_steps)
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
self.max_execution_time = cast(int, max_execution_time)
self.callbacks = callbacks
def run(self) -> Generator:
self.graph_runtime_state.start_at = time.perf_counter()
pass

View File

@ -7,6 +7,7 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.iterable_node import IterableNodeMixin
class UserFrom(Enum): class UserFrom(Enum):
@ -39,7 +40,7 @@ class BaseNode(ABC):
user_id: str user_id: str
user_from: UserFrom user_from: UserFrom
invoke_from: InvokeFrom invoke_from: InvokeFrom
workflow_call_depth: int workflow_call_depth: int
node_id: str node_id: str
@ -149,7 +150,8 @@ class BaseNode(ABC):
""" """
return self._node_type return self._node_type
class BaseIterationNode(BaseNode):
class BaseIterationNode(BaseNode, IterableNodeMixin):
@abstractmethod @abstractmethod
def _run(self, variable_pool: VariablePool) -> BaseIterationState: def _run(self, variable_pool: VariablePool) -> BaseIterationState:
""" """
@ -174,7 +176,7 @@ class BaseIterationNode(BaseNode):
:return: next node id :return: next node id
""" """
return self._get_next_iteration(variable_pool, state) return self._get_next_iteration(variable_pool, state)
@abstractmethod @abstractmethod
def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
""" """

View File

@ -0,0 +1,14 @@
from abc import ABC, abstractmethod
from typing import Any
from core.workflow.utils.condition.entities import Condition
class IterableNodeMixin(ABC):
@classmethod
@abstractmethod
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
"""
Get conditions.
"""
raise NotImplementedError

View File

@ -1,4 +1,4 @@
from typing import cast from typing import Any, cast
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseIterationState from core.workflow.entities.base_node_data_entities import BaseIterationState
@ -6,6 +6,7 @@ from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode from core.workflow.nodes.base_node import BaseIterationNode
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
from core.workflow.utils.condition.entities import Condition
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
@ -116,4 +117,20 @@ class IterationNode(BaseIterationNode):
""" """
return { return {
'input_selector': node_data.iterator_selector, 'input_selector': node_data.iterator_selector,
} }
@classmethod
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
"""
Get conditions.
"""
node_id = node_config.get('id')
if not node_id:
return []
return [Condition(
variable_selector=[node_id, 'index'],
comparison_operator="",
value_type="value_selector",
value_selector=node_config.get('data', {}).get('iterator_selector')
)]

View File

@ -1,7 +1,10 @@
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode from core.workflow.nodes.base_node import BaseIterationNode
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
from core.workflow.utils.condition.entities import Condition
class LoopNode(BaseIterationNode): class LoopNode(BaseIterationNode):
@ -18,3 +21,21 @@ class LoopNode(BaseIterationNode):
""" """
Get next iteration start node id based on the graph. Get next iteration start node id based on the graph.
""" """
pass
@classmethod
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
"""
Get conditions.
"""
node_id = node_config.get('id')
if not node_id:
return []
# TODO waiting for implementation
return [Condition(
variable_selector=[node_id, 'index'],
comparison_operator="",
value_type="value_selector",
value_selector=[]
)]

View File

@ -14,4 +14,6 @@ class Condition(BaseModel):
# for number # for number
"=", "", ">", "<", "", "", "null", "not null" "=", "", ">", "<", "", "", "null", "not null"
] ]
value_type: Literal["string", "value_selector"] = "string"
value: Optional[str] = None value: Optional[str] = None
value_selector: Optional[list[str]] = None

View File

@ -1,22 +0,0 @@
from typing import Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState
from core.workflow.graph import GraphNode
def source_handle_condition_func(workflow_runtime_state: WorkflowRuntimeState,
graph_node: GraphNode,
# TODO cycle_state optional
predecessor_node_run_result: Optional[NodeRunResult] = None) -> bool:
if not graph_node.source_edge_config:
return False
if not graph_node.source_edge_config.get('sourceHandle'):
return True
source_handle = predecessor_node_run_result.edge_source_handle \
if predecessor_node_run_result else None
return (source_handle is not None
and graph_node.source_edge_config.get('sourceHandle') == source_handle)

View File

@ -24,7 +24,12 @@ class ConditionProcessor:
variable_selector=condition.variable_selector variable_selector=condition.variable_selector
) )
expected_value = condition.value if condition.value_type == "value_selector":
expected_value = variable_pool.get_variable_value(
variable_selector=condition.value_selector
)
else:
expected_value = condition.value
input_conditions.append({ input_conditions.append({
"actual_value": actual_value, "actual_value": actual_value,
@ -208,7 +213,7 @@ class ConditionProcessor:
return True return True
return False return False
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
""" """
Assert equal Assert equal
:param actual_value: actual value :param actual_value: actual value
@ -230,7 +235,7 @@ class ConditionProcessor:
return False return False
return True return True
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
""" """
Assert not equal Assert not equal
:param actual_value: actual value :param actual_value: actual value
@ -252,7 +257,7 @@ class ConditionProcessor:
return False return False
return True return True
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
""" """
Assert greater than Assert greater than
:param actual_value: actual value :param actual_value: actual value
@ -274,7 +279,7 @@ class ConditionProcessor:
return False return False
return True return True
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
""" """
Assert less than Assert less than
:param actual_value: actual value :param actual_value: actual value
@ -296,7 +301,7 @@ class ConditionProcessor:
return False return False
return True return True
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
""" """
Assert greater than or equal Assert greater than or equal
:param actual_value: actual value :param actual_value: actual value
@ -318,7 +323,7 @@ class ConditionProcessor:
return False return False
return True return True
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
""" """
Assert less than or equal Assert less than or equal
:param actual_value: actual value :param actual_value: actual value

View File

@ -1,5 +1,6 @@
import logging import logging
import time import time
from collections.abc import Generator
from typing import Any, Optional, cast from typing import Any, Optional, cast
from flask import current_app from flask import current_app
@ -12,15 +13,18 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iterable_node import IterableNodeMixin
from core.workflow.nodes.iteration.entities import IterationState from core.workflow.nodes.iteration.entities import IterationState
from core.workflow.nodes.iteration.iteration_node import IterationNode from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
@ -36,7 +40,6 @@ from extensions.ext_database import db
from models.workflow import ( from models.workflow import (
Workflow, Workflow,
WorkflowNodeExecutionStatus, WorkflowNodeExecutionStatus,
WorkflowType,
) )
node_classes = { node_classes = {
@ -60,7 +63,7 @@ node_classes = {
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WorkflowEngineManager: class WorkflowEntry:
def run_workflow(self, workflow: Workflow, def run_workflow(self, workflow: Workflow,
user_id: str, user_id: str,
user_from: UserFrom, user_from: UserFrom,
@ -69,7 +72,7 @@ class WorkflowEngineManager:
user_inputs: dict, user_inputs: dict,
system_inputs: dict[SystemVariable, Any], system_inputs: dict[SystemVariable, Any],
call_depth: int = 0, call_depth: int = 0,
variable_pool: Optional[VariablePool] = None) -> None: variable_pool: Optional[VariablePool] = None) -> Generator:
""" """
:param workflow: Workflow instance :param workflow: Workflow instance
:param user_id: user id :param user_id: user id
@ -102,25 +105,25 @@ class WorkflowEngineManager:
user_inputs=user_inputs user_inputs=user_inputs
) )
# fetch max call depth # init graph
workflow_call_max_depth = current_app.config.get("WORKFLOW_CALL_MAX_DEPTH") graph = self._init_graph(
workflow_call_max_depth = cast(int, workflow_call_max_depth) graph_config=graph_config
if call_depth > workflow_call_max_depth: )
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
# init workflow runtime state if not graph:
workflow_runtime_state = WorkflowRuntimeState( raise ValueError('graph not found in workflow')
# init workflow run state
graph_engine = GraphEngine(
tenant_id=workflow.tenant_id, tenant_id=workflow.tenant_id,
app_id=workflow.app_id, app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
user_id=user_id, user_id=user_id,
user_from=user_from, user_from=user_from,
variable_pool=variable_pool,
invoke_from=invoke_from, invoke_from=invoke_from,
graph=graph_config,
call_depth=call_depth, call_depth=call_depth,
start_at=time.perf_counter() graph=graph,
variable_pool=variable_pool,
callbacks=callbacks
) )
# init workflow run # init workflow run
@ -130,11 +133,7 @@ class WorkflowEngineManager:
try: try:
# run workflow # run workflow
self._run_workflow( rst = graph_engine.run()
graph_config=graph_config,
workflow_runtime_state=workflow_runtime_state,
callbacks=callbacks,
)
except WorkflowRunFailedError as e: except WorkflowRunFailedError as e:
self._workflow_run_failed( self._workflow_run_failed(
error=e.error, error=e.error,
@ -151,6 +150,8 @@ class WorkflowEngineManager:
callbacks=callbacks callbacks=callbacks
) )
return rst
def _init_graph(self, graph_config: dict, root_node_id: Optional[str] = None) -> Optional[Graph]: def _init_graph(self, graph_config: dict, root_node_id: Optional[str] = None) -> Optional[Graph]:
""" """
Initialize graph Initialize graph
@ -259,30 +260,59 @@ class WorkflowEngineManager:
sub_graph: Optional[Graph] = None sub_graph: Optional[Graph] = None
target_node_type: NodeType = NodeType.value_of(target_node_config.get('data', {}).get('type')) target_node_type: NodeType = NodeType.value_of(target_node_config.get('data', {}).get('type'))
if target_node_type and target_node_type in [IterationNode.node_type, NodeType.LOOP]: target_node_cls = None
if target_node_type:
target_node_cls = node_classes.get(target_node_type)
if not target_node_cls:
raise Exception(f'Node class not found for node type: {target_node_type}')
if target_node_cls and issubclass(target_node_cls, IterableNodeMixin):
# find iteration/loop sub nodes that have no predecessor node # find iteration/loop sub nodes that have no predecessor node
sub_graph_root_node_config = None
for root_node_config in root_node_configs: for root_node_config in root_node_configs:
if root_node_config.get('parentId') == target_node_id: if root_node_config.get('parentId') == target_node_id:
# create sub graph sub_graph_root_node_config = root_node_config
sub_graph = Graph.init(
root_node_config=root_node_config
)
self._recursively_add_edges(
graph=sub_graph,
source_node_config=root_node_config,
edges_mapping=edges_mapping,
nodes_mapping=nodes_mapping,
root_node_configs=root_node_configs
)
break break
if sub_graph_root_node_config:
# create sub graph run condition
iterable_node_cls: IterableNodeMixin = cast(IterableNodeMixin, target_node_cls)
sub_graph_run_condition = RunCondition(
type='condition',
conditions=iterable_node_cls.get_conditions(
node_config=target_node_config
)
)
# create sub graph
sub_graph = Graph.init(
root_node_config=sub_graph_root_node_config,
run_condition=sub_graph_run_condition
)
self._recursively_add_edges(
graph=sub_graph,
source_node_config=sub_graph_root_node_config,
edges_mapping=edges_mapping,
nodes_mapping=nodes_mapping,
root_node_configs=root_node_configs
)
# parse run condition
run_condition = None
if edge_config.get('sourceHandle'):
run_condition = RunCondition(
type='branch_identify',
branch_identify=edge_config.get('sourceHandle')
)
# add edge # add edge
graph.add_edge( graph.add_edge(
edge_config=edge_config, edge_config=edge_config,
source_node_config=source_node_config, source_node_config=source_node_config,
target_node_config=target_node_config, target_node_config=target_node_config,
target_node_sub_graph=sub_graph, target_node_sub_graph=sub_graph,
run_condition=run_condition
) )
# recursively add edges # recursively add edges

View File

@ -1,5 +1,4 @@
from core.workflow.graph import Graph from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_engine_manager import WorkflowEngineManager
def test__init_graph(): def test__init_graph():
@ -217,18 +216,17 @@ def test__init_graph():
], ],
} }
workflow_engine_manager = WorkflowEngineManager() workflow_entry = WorkflowEntry()
graph = workflow_engine_manager._init_graph( graph = workflow_entry._init_graph(
graph_config=graph_config graph_config=graph_config
) )
assert graph.root_node.id == "1717222650545" assert graph.root_node.id == "1717222650545"
assert graph.root_node.source_edge_config is None assert graph.root_node.source_edge_config is None
assert graph.root_node.target_edge_config is not None
assert graph.root_node.descendant_node_ids == ["1719481290322"] assert graph.root_node.descendant_node_ids == ["1719481290322"]
assert graph.graph_nodes.get("1719481290322") is not None assert graph.graph_nodes.get("1719481290322") is not None
assert len(graph.graph_nodes.get("1719481290322").descendant_node_ids) == 2 assert len(graph.graph_nodes.get("1719481290322").descendant_node_ids) == 2
assert graph.graph_nodes.get("llm").run_condition_callback is not None assert graph.graph_nodes.get("llm").run_condition is not None
assert graph.graph_nodes.get("1719481315734").run_condition_callback is not None assert graph.graph_nodes.get("1719481315734").run_condition is not None