From 0f19b2a98648884f61464c1645d303f80c2a30c1 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 2 Jul 2024 21:53:41 +0800 Subject: [PATCH] optimize graph --- .../entities/workflow_runtime_state.py | 28 +++++ api/core/workflow/graph_engine/__init__.py | 0 .../condition_handlers/__init__.py | 0 .../condition_handlers/base_handler.py | 24 +++++ .../branch_identify_handler.py | 25 +++++ .../condition_handlers/condition_handler.py | 31 ++++++ .../condition_handlers/condition_manager.py | 19 ++++ .../graph_engine/entities/__init__.py | 0 .../{ => graph_engine/entities}/graph.py | 97 ++++++++--------- .../entities/graph_runtime_state.py | 26 +++++ .../graph_engine/entities/run_condition.py | 16 +++ .../entities/runtime_graph.py} | 69 +----------- .../graph_engine/entities/runtime_node.py | 48 +++++++++ .../workflow/graph_engine/graph_engine.py | 45 ++++++++ api/core/workflow/nodes/base_node.py | 8 +- api/core/workflow/nodes/iterable_node.py | 14 +++ .../nodes/iteration/iteration_node.py | 21 +++- api/core/workflow/nodes/loop/loop_node.py | 21 ++++ api/core/workflow/utils/condition/entities.py | 2 + api/core/workflow/utils/condition/funcs.py | 22 ---- .../workflow/utils/condition/processor.py | 19 ++-- ...ow_engine_manager.py => workflow_entry.py} | 100 ++++++++++++------ ...gine_manager.py => test_workflow_entry.py} | 12 +-- 23 files changed, 454 insertions(+), 193 deletions(-) create mode 100644 api/core/workflow/entities/workflow_runtime_state.py create mode 100644 api/core/workflow/graph_engine/__init__.py create mode 100644 api/core/workflow/graph_engine/condition_handlers/__init__.py create mode 100644 api/core/workflow/graph_engine/condition_handlers/base_handler.py create mode 100644 api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py create mode 100644 api/core/workflow/graph_engine/condition_handlers/condition_handler.py create mode 100644 api/core/workflow/graph_engine/condition_handlers/condition_manager.py create mode 100644 api/core/workflow/graph_engine/entities/__init__.py rename api/core/workflow/{ => graph_engine/entities}/graph.py (75%) create mode 100644 api/core/workflow/graph_engine/entities/graph_runtime_state.py create mode 100644 api/core/workflow/graph_engine/entities/run_condition.py rename api/core/workflow/{entities/workflow_runtime_state_entities.py => graph_engine/entities/runtime_graph.py} (50%) create mode 100644 api/core/workflow/graph_engine/entities/runtime_node.py create mode 100644 api/core/workflow/graph_engine/graph_engine.py create mode 100644 api/core/workflow/nodes/iterable_node.py delete mode 100644 api/core/workflow/utils/condition/funcs.py rename api/core/workflow/{workflow_engine_manager.py => workflow_entry.py} (94%) rename api/tests/unit_tests/core/workflow/{test_workflow_engine_manager.py => test_workflow_entry.py} (94%) diff --git a/api/core/workflow/entities/workflow_runtime_state.py b/api/core/workflow/entities/workflow_runtime_state.py new file mode 100644 index 0000000000..1ef9bf1571 --- /dev/null +++ b/api/core/workflow/entities/workflow_runtime_state.py @@ -0,0 +1,28 @@ + +from pydantic import BaseModel, Field + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph +from core.workflow.nodes.base_node import UserFrom +from models.workflow import WorkflowType + + +class WorkflowRuntimeState(BaseModel): + tenant_id: str + app_id: str + workflow_id: str + workflow_type: WorkflowType + user_id: str + user_from: UserFrom + variable_pool: VariablePool + invoke_from: InvokeFrom + graph: Graph + call_depth: int + start_at: float + + total_tokens: int = 0 + node_run_steps: int = 0 + + runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph) diff --git a/api/core/workflow/graph_engine/__init__.py b/api/core/workflow/graph_engine/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/graph_engine/condition_handlers/__init__.py b/api/core/workflow/graph_engine/condition_handlers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/graph_engine/condition_handlers/base_handler.py b/api/core/workflow/graph_engine/condition_handlers/base_handler.py new file mode 100644 index 0000000000..5b96f280dc --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/base_handler.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod + +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.graph_engine.entities.run_condition import RunCondition + + +class RunConditionHandler(ABC): + def __init__(self, condition: RunCondition): + self.condition = condition + + @abstractmethod + def check(self, + graph_node: "GraphNode", + graph_runtime_state: "GraphRuntimeState", + predecessor_node_result: NodeRunResult) -> bool: + """ + Check if the condition can be executed + + :param graph_node: graph node + :param graph_runtime_state: graph runtime state + :param predecessor_node_result: predecessor node result + :return: bool + """ + raise NotImplementedError diff --git a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py new file mode 100644 index 0000000000..90cc035a4f --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py @@ -0,0 +1,25 @@ +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler + + +class BranchIdentifyRunConditionHandler(RunConditionHandler): + + def check(self, + graph_node: "GraphNode", + graph_runtime_state: "GraphRuntimeState", + predecessor_node_result: NodeRunResult) -> bool: + """ + Check if the condition can be executed + + :param graph_node: graph node + :param graph_runtime_state: graph runtime state + :param predecessor_node_result: predecessor node result + :return: bool + """ + if not self.condition.branch_identify: + raise Exception("Branch identify is required") + + if not predecessor_node_result.edge_source_handle: + return False + + return self.condition.branch_identify == predecessor_node_result.edge_source_handle diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py new file mode 100644 index 0000000000..cca7fc235e --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -0,0 +1,31 @@ +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.utils.condition.processor import ConditionProcessor + + +class ConditionRunConditionHandlerHandler(RunConditionHandler): + def check(self, + graph_node: "GraphNode", + graph_runtime_state: "GraphRuntimeState", + predecessor_node_result: NodeRunResult) -> bool: + """ + Check if the condition can be executed + + :param graph_node: graph node + :param graph_runtime_state: graph runtime state + :param predecessor_node_result: predecessor node result + :return: bool + """ + if not self.condition.conditions: + return True + + # process condition + condition_processor = ConditionProcessor() + compare_result, _ = condition_processor.process( + variable_pool=graph_runtime_state.variable_pool, + logical_operator="and", + conditions=self.condition.conditions + ) + + return compare_result + diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_manager.py b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py new file mode 100644 index 0000000000..5b3c430418 --- /dev/null +++ b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py @@ -0,0 +1,19 @@ +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler +from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler +from core.workflow.graph_engine.entities.run_condition import RunCondition + + +class ConditionManager: + @staticmethod + def get_condition_handler(run_condition: RunCondition) -> RunConditionHandler: + """ + Get condition handler + + :param run_condition: run condition + :return: condition handler + """ + if run_condition.type == "branch_identify": + return BranchIdentifyRunConditionHandler(run_condition) + else: + return ConditionRunConditionHandlerHandler(run_condition) diff --git a/api/core/workflow/graph_engine/entities/__init__.py b/api/core/workflow/graph_engine/entities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/graph.py b/api/core/workflow/graph_engine/entities/graph.py similarity index 75% rename from api/core/workflow/graph.py rename to api/core/workflow/graph_engine/entities/graph.py index be981c2ea8..7897749687 100644 --- a/api/core/workflow/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -1,20 +1,10 @@ -from collections.abc import Callable -from typing import Literal, Optional +from typing import Optional from pydantic import BaseModel, Field -from core.workflow.utils.condition.entities import Condition - - -class RunCondition(BaseModel): - type: Literal["branch_identify", "condition"] - """condition type""" - - branch_identify: Optional[str] = None - """branch identify, required when type is branch_identify""" - - conditions: Optional[list[Condition]] = None - """conditions to run the node, required when type is condition""" +from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager +from core.workflow.graph_engine.entities.run_condition import RunCondition class GraphNode(BaseModel): @@ -33,9 +23,6 @@ class GraphNode(BaseModel): run_condition: Optional[RunCondition] = None """condition to run the node""" - run_condition_callback: Optional[Callable] = Field(None, exclude=True) - """condition function check if the node can be executed, translated from run_conditions, not serialized""" - node_config: dict """original node config""" @@ -48,9 +35,22 @@ class GraphNode(BaseModel): def add_child(self, node_id: str) -> None: self.descendant_node_ids.append(node_id) + def get_run_condition_handler(self) -> Optional[RunConditionHandler]: + """ + Get run condition handler + + :return: run condition handler + """ + if not self.run_condition: + return None + + return ConditionManager.get_condition_handler( + run_condition=self.run_condition + ) + class Graph(BaseModel): - graph_nodes: dict[str, GraphNode] = {} + graph_nodes: dict[str, GraphNode] = Field(default_factory=dict) """graph nodes""" root_node: GraphNode @@ -65,8 +65,12 @@ class Graph(BaseModel): :param run_condition: run condition when root node parent is iteration/loop :return: graph """ + node_id = root_node_config.get('id') + if not node_id: + raise ValueError("Graph root node id is required") + root_node = GraphNode( - id=root_node_config.get('id'), + id=node_id, parent_id=root_node_config.get('parentId'), node_config=root_node_config, run_condition=run_condition @@ -74,15 +78,14 @@ class Graph(BaseModel): graph = cls(root_node=root_node) - # TODO parse run_condition to run_condition_callback - - graph.add_graph_node(graph.root_node) + graph._add_graph_node(graph.root_node) return graph def add_edge(self, edge_config: dict, source_node_config: dict, target_node_config: dict, - target_node_sub_graph: Optional["Graph"] = None) -> None: + target_node_sub_graph: Optional["Graph"] = None, + run_condition: Optional[RunCondition] = None) -> None: """ Add edge to the graph @@ -90,6 +93,7 @@ class Graph(BaseModel): :param source_node_config: source node config :param target_node_config: target node config :param target_node_sub_graph: sub graph + :param run_condition: run condition """ source_node_id = source_node_config.get('id') if not source_node_id: @@ -105,48 +109,25 @@ class Graph(BaseModel): source_node = self.graph_nodes[source_node_id] source_node.add_child(target_node_id) - # if run_conditions: - # run_condition_callback = lambda: all() - - if target_node_id not in self.graph_nodes: - run_condition = None # todo - run_condition_callback = None # todo - target_graph_node = GraphNode( id=target_node_id, parent_id=source_node_config.get('parentId'), predecessor_node_id=source_node_id, node_config=target_node_config, run_condition=run_condition, - run_condition_callback=run_condition_callback, source_edge_config=edge_config, sub_graph=target_node_sub_graph ) - self.add_graph_node(target_graph_node) + self._add_graph_node(target_graph_node) else: target_node = self.graph_nodes[target_node_id] target_node.predecessor_node_id = source_node_id - target_node.run_conditions = run_conditions - target_node.run_condition_callback = run_condition_callback + target_node.run_condition = run_condition target_node.source_edge_config = edge_config target_node.sub_graph = target_node_sub_graph - def add_graph_node(self, graph_node: GraphNode) -> None: - """ - Add graph node to the graph - - :param graph_node: graph node - """ - if graph_node.id in self.graph_nodes: - return - - if len(self.graph_nodes) == 0: - self.root_node = graph_node - - self.graph_nodes[graph_node.id] = graph_node - def get_root_node(self) -> Optional[GraphNode]: """ Get root node of the graph @@ -169,14 +150,28 @@ class Graph(BaseModel): if not graph_node.descendant_node_ids: return None - descendants_graph = Graph() - descendants_graph.add_graph_node(graph_node) + descendants_graph = Graph(root_node=graph_node) + descendants_graph._add_graph_node(graph_node) for child_node_id in graph_node.descendant_node_ids: self._add_descendants_graph_nodes(descendants_graph, child_node_id) return descendants_graph + def _add_graph_node(self, graph_node: GraphNode) -> None: + """ + Add graph node to the graph + + :param graph_node: graph node + """ + if graph_node.id in self.graph_nodes: + return + + if len(self.graph_nodes) == 0: + self.root_node = graph_node + + self.graph_nodes[graph_node.id] = graph_node + def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None: """ Add descendants graph nodes @@ -188,7 +183,7 @@ class Graph(BaseModel): return graph_node = self.graph_nodes[node_id] - descendants_graph.add_graph_node(graph_node) + descendants_graph._add_graph_node(graph_node) for child_node_id in graph_node.descendant_node_ids: self._add_descendants_graph_nodes(descendants_graph, child_node_id) diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py new file mode 100644 index 0000000000..8eafb8869e --- /dev/null +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -0,0 +1,26 @@ +from typing import Optional + +from pydantic import BaseModel, Field + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph +from core.workflow.nodes.base_node import UserFrom + + +class GraphRuntimeState(BaseModel): + tenant_id: str + app_id: str + user_id: str + user_from: UserFrom + invoke_from: InvokeFrom + call_depth: int + + graph: Graph + variable_pool: VariablePool + + start_at: Optional[float] = None + total_tokens: int = 0 + node_run_steps: int = 0 + runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph) diff --git a/api/core/workflow/graph_engine/entities/run_condition.py b/api/core/workflow/graph_engine/entities/run_condition.py new file mode 100644 index 0000000000..46fca1b032 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/run_condition.py @@ -0,0 +1,16 @@ +from typing import Literal, Optional + +from pydantic import BaseModel + +from core.workflow.utils.condition.entities import Condition + + +class RunCondition(BaseModel): + type: Literal["branch_identify", "condition"] + """condition type""" + + branch_identify: Optional[str] = None + """branch identify like: sourceHandle, required when type is branch_identify""" + + conditions: Optional[list[Condition]] = None + """conditions to run the node, required when type is condition""" diff --git a/api/core/workflow/entities/workflow_runtime_state_entities.py b/api/core/workflow/graph_engine/entities/runtime_graph.py similarity index 50% rename from api/core/workflow/entities/workflow_runtime_state_entities.py rename to api/core/workflow/graph_engine/entities/runtime_graph.py index cee39a1a99..916a5a983f 100644 --- a/api/core/workflow/entities/workflow_runtime_state_entities.py +++ b/api/core/workflow/graph_engine/entities/runtime_graph.py @@ -1,55 +1,11 @@ -import uuid from datetime import datetime, timezone -from enum import Enum from typing import Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel -from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import NodeRunResult -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.graph import Graph, GraphNode -from core.workflow.nodes.base_node import UserFrom -from models.workflow import WorkflowNodeExecutionStatus, WorkflowType - - -class RuntimeNode(BaseModel): - class Status(Enum): - PENDING = "pending" - RUNNING = "running" - SUCCESS = "success" - FAILED = "failed" - PAUSED = "paused" - - id: str = Field(default_factory=lambda: str(uuid.uuid4())) - """random id for current runtime node""" - - graph_node: GraphNode - """graph node""" - - node_run_result: Optional[NodeRunResult] = None - """node run result""" - - status: Status = Status.PENDING - """node status""" - - start_at: Optional[datetime] = None - """start time""" - - paused_at: Optional[datetime] = None - """paused time""" - - finished_at: Optional[datetime] = None - """finished time""" - - failed_reason: Optional[str] = None - """failed reason""" - - paused_by: Optional[str] = None - """paused by""" - - predecessor_runtime_node_id: Optional[str] = None - """predecessor runtime node id""" +from core.workflow.graph_engine.entities.runtime_node import RuntimeNode +from models.workflow import WorkflowNodeExecutionStatus class RuntimeGraph(BaseModel): @@ -80,22 +36,3 @@ class RuntimeGraph(BaseModel): runtime_node.status = RuntimeNode.Status.PAUSED runtime_node.paused_at = datetime.now(timezone.utc).replace(tzinfo=None) runtime_node.paused_by = paused_by - - -class WorkflowRuntimeState(BaseModel): - tenant_id: str - app_id: str - workflow_id: str - workflow_type: WorkflowType - user_id: str - user_from: UserFrom - variable_pool: VariablePool - invoke_from: InvokeFrom - graph: Graph - call_depth: int - start_at: float - - total_tokens: int = 0 - node_run_steps: int = 0 - - runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph) diff --git a/api/core/workflow/graph_engine/entities/runtime_node.py b/api/core/workflow/graph_engine/entities/runtime_node.py new file mode 100644 index 0000000000..2ba894fa13 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/runtime_node.py @@ -0,0 +1,48 @@ +import uuid +from datetime import datetime +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.graph_engine.entities.graph import GraphNode + + +class RuntimeNode(BaseModel): + class Status(Enum): + PENDING = "pending" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + PAUSED = "paused" + + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + """random id for current runtime node""" + + graph_node: GraphNode + """graph node""" + + node_run_result: Optional[NodeRunResult] = None + """node run result""" + + status: Status = Status.PENDING + """node status""" + + start_at: Optional[datetime] = None + """start time""" + + paused_at: Optional[datetime] = None + """paused time""" + + finished_at: Optional[datetime] = None + """finished time""" + + failed_reason: Optional[str] = None + """failed reason""" + + paused_by: Optional[str] = None + """paused by""" + + predecessor_runtime_node_id: Optional[str] = None + """predecessor runtime node id""" diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py new file mode 100644 index 0000000000..0851bec603 --- /dev/null +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -0,0 +1,45 @@ +import time +from collections.abc import Generator +from typing import cast + +from flask import current_app + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.nodes.base_node import UserFrom + + +class GraphEngine: + def __init__(self, tenant_id: str, + app_id: str, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + call_depth: int, + graph: Graph, + variable_pool: VariablePool, + callbacks: list[BaseWorkflowCallback]) -> None: + self.graph_runtime_state = GraphRuntimeState( + tenant_id=tenant_id, + app_id=app_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + call_depth=call_depth, + graph=graph, + variable_pool=variable_pool + ) + + max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS") + self.max_execution_steps = cast(int, max_execution_steps) + max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME") + self.max_execution_time = cast(int, max_execution_time) + + self.callbacks = callbacks + + def run(self) -> Generator: + self.graph_runtime_state.start_at = time.perf_counter() + pass diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index fa7d6424f1..ed2bb70711 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -7,6 +7,7 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.iterable_node import IterableNodeMixin class UserFrom(Enum): @@ -39,7 +40,7 @@ class BaseNode(ABC): user_id: str user_from: UserFrom invoke_from: InvokeFrom - + workflow_call_depth: int node_id: str @@ -149,7 +150,8 @@ class BaseNode(ABC): """ return self._node_type -class BaseIterationNode(BaseNode): + +class BaseIterationNode(BaseNode, IterableNodeMixin): @abstractmethod def _run(self, variable_pool: VariablePool) -> BaseIterationState: """ @@ -174,7 +176,7 @@ class BaseIterationNode(BaseNode): :return: next node id """ return self._get_next_iteration(variable_pool, state) - + @abstractmethod def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: """ diff --git a/api/core/workflow/nodes/iterable_node.py b/api/core/workflow/nodes/iterable_node.py new file mode 100644 index 0000000000..e45ba7df60 --- /dev/null +++ b/api/core/workflow/nodes/iterable_node.py @@ -0,0 +1,14 @@ +from abc import ABC, abstractmethod +from typing import Any + +from core.workflow.utils.condition.entities import Condition + + +class IterableNodeMixin(ABC): + @classmethod + @abstractmethod + def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: + """ + Get conditions. + """ + raise NotImplementedError diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 12d792f297..fb12e07f85 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Any, cast from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.base_node_data_entities import BaseIterationState @@ -6,6 +6,7 @@ from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseIterationNode from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState +from core.workflow.utils.condition.entities import Condition from models.workflow import WorkflowNodeExecutionStatus @@ -116,4 +117,20 @@ class IterationNode(BaseIterationNode): """ return { 'input_selector': node_data.iterator_selector, - } \ No newline at end of file + } + + @classmethod + def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: + """ + Get conditions. + """ + node_id = node_config.get('id') + if not node_id: + return [] + + return [Condition( + variable_selector=[node_id, 'index'], + comparison_operator="≤", + value_type="value_selector", + value_selector=node_config.get('data', {}).get('iterator_selector') + )] diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 7d53c6f5f2..4b1421cae9 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -1,7 +1,10 @@ +from typing import Any + from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseIterationNode from core.workflow.nodes.loop.entities import LoopNodeData, LoopState +from core.workflow.utils.condition.entities import Condition class LoopNode(BaseIterationNode): @@ -18,3 +21,21 @@ class LoopNode(BaseIterationNode): """ Get next iteration start node id based on the graph. """ + pass + + @classmethod + def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]: + """ + Get conditions. + """ + node_id = node_config.get('id') + if not node_id: + return [] + + # TODO waiting for implementation + return [Condition( + variable_selector=[node_id, 'index'], + comparison_operator="≤", + value_type="value_selector", + value_selector=[] + )] diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py index e195730a31..524cce1a43 100644 --- a/api/core/workflow/utils/condition/entities.py +++ b/api/core/workflow/utils/condition/entities.py @@ -14,4 +14,6 @@ class Condition(BaseModel): # for number "=", "≠", ">", "<", "≥", "≤", "null", "not null" ] + value_type: Literal["string", "value_selector"] = "string" value: Optional[str] = None + value_selector: Optional[list[str]] = None diff --git a/api/core/workflow/utils/condition/funcs.py b/api/core/workflow/utils/condition/funcs.py deleted file mode 100644 index fa2f3ca9f9..0000000000 --- a/api/core/workflow/utils/condition/funcs.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Optional - -from core.workflow.entities.node_entities import NodeRunResult -from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState -from core.workflow.graph import GraphNode - - -def source_handle_condition_func(workflow_runtime_state: WorkflowRuntimeState, - graph_node: GraphNode, - # TODO cycle_state optional - predecessor_node_run_result: Optional[NodeRunResult] = None) -> bool: - if not graph_node.source_edge_config: - return False - - if not graph_node.source_edge_config.get('sourceHandle'): - return True - - source_handle = predecessor_node_run_result.edge_source_handle \ - if predecessor_node_run_result else None - - return (source_handle is not None - and graph_node.source_edge_config.get('sourceHandle') == source_handle) diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py index 405b01b9fa..cebd234570 100644 --- a/api/core/workflow/utils/condition/processor.py +++ b/api/core/workflow/utils/condition/processor.py @@ -24,7 +24,12 @@ class ConditionProcessor: variable_selector=condition.variable_selector ) - expected_value = condition.value + if condition.value_type == "value_selector": + expected_value = variable_pool.get_variable_value( + variable_selector=condition.value_selector + ) + else: + expected_value = condition.value input_conditions.append({ "actual_value": actual_value, @@ -208,7 +213,7 @@ class ConditionProcessor: return True return False - def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: """ Assert equal :param actual_value: actual value @@ -230,7 +235,7 @@ class ConditionProcessor: return False return True - def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: """ Assert not equal :param actual_value: actual value @@ -252,7 +257,7 @@ class ConditionProcessor: return False return True - def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: + def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: """ Assert greater than :param actual_value: actual value @@ -274,7 +279,7 @@ class ConditionProcessor: return False return True - def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool: + def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: """ Assert less than :param actual_value: actual value @@ -296,7 +301,7 @@ class ConditionProcessor: return False return True - def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: """ Assert greater than or equal :param actual_value: actual value @@ -318,7 +323,7 @@ class ConditionProcessor: return False return True - def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool: + def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: """ Assert less than or equal :param actual_value: actual value diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_entry.py similarity index 94% rename from api/core/workflow/workflow_engine_manager.py rename to api/core/workflow/workflow_entry.py index 66c081860d..03f20421f1 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_entry.py @@ -1,5 +1,6 @@ import logging import time +from collections.abc import Generator from typing import Any, Optional, cast from flask import current_app @@ -12,15 +13,18 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState -from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState +from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.graph import Graph +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from core.workflow.nodes.if_else.if_else_node import IfElseNode +from core.workflow.nodes.iterable_node import IterableNodeMixin from core.workflow.nodes.iteration.entities import IterationState from core.workflow.nodes.iteration.iteration_node import IterationNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode @@ -36,7 +40,6 @@ from extensions.ext_database import db from models.workflow import ( Workflow, WorkflowNodeExecutionStatus, - WorkflowType, ) node_classes = { @@ -60,7 +63,7 @@ node_classes = { logger = logging.getLogger(__name__) -class WorkflowEngineManager: +class WorkflowEntry: def run_workflow(self, workflow: Workflow, user_id: str, user_from: UserFrom, @@ -69,7 +72,7 @@ class WorkflowEngineManager: user_inputs: dict, system_inputs: dict[SystemVariable, Any], call_depth: int = 0, - variable_pool: Optional[VariablePool] = None) -> None: + variable_pool: Optional[VariablePool] = None) -> Generator: """ :param workflow: Workflow instance :param user_id: user id @@ -102,25 +105,25 @@ class WorkflowEngineManager: user_inputs=user_inputs ) - # fetch max call depth - workflow_call_max_depth = current_app.config.get("WORKFLOW_CALL_MAX_DEPTH") - workflow_call_max_depth = cast(int, workflow_call_max_depth) - if call_depth > workflow_call_max_depth: - raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) + # init graph + graph = self._init_graph( + graph_config=graph_config + ) - # init workflow runtime state - workflow_runtime_state = WorkflowRuntimeState( + if not graph: + raise ValueError('graph not found in workflow') + + # init workflow run state + graph_engine = GraphEngine( tenant_id=workflow.tenant_id, app_id=workflow.app_id, - workflow_id=workflow.id, - workflow_type=WorkflowType.value_of(workflow.type), user_id=user_id, user_from=user_from, - variable_pool=variable_pool, invoke_from=invoke_from, - graph=graph_config, call_depth=call_depth, - start_at=time.perf_counter() + graph=graph, + variable_pool=variable_pool, + callbacks=callbacks ) # init workflow run @@ -130,11 +133,7 @@ class WorkflowEngineManager: try: # run workflow - self._run_workflow( - graph_config=graph_config, - workflow_runtime_state=workflow_runtime_state, - callbacks=callbacks, - ) + rst = graph_engine.run() except WorkflowRunFailedError as e: self._workflow_run_failed( error=e.error, @@ -151,6 +150,8 @@ class WorkflowEngineManager: callbacks=callbacks ) + return rst + def _init_graph(self, graph_config: dict, root_node_id: Optional[str] = None) -> Optional[Graph]: """ Initialize graph @@ -259,30 +260,59 @@ class WorkflowEngineManager: sub_graph: Optional[Graph] = None target_node_type: NodeType = NodeType.value_of(target_node_config.get('data', {}).get('type')) - if target_node_type and target_node_type in [IterationNode.node_type, NodeType.LOOP]: + target_node_cls = None + if target_node_type: + target_node_cls = node_classes.get(target_node_type) + if not target_node_cls: + raise Exception(f'Node class not found for node type: {target_node_type}') + + if target_node_cls and issubclass(target_node_cls, IterableNodeMixin): # find iteration/loop sub nodes that have no predecessor node + sub_graph_root_node_config = None for root_node_config in root_node_configs: if root_node_config.get('parentId') == target_node_id: - # create sub graph - sub_graph = Graph.init( - root_node_config=root_node_config - ) - - self._recursively_add_edges( - graph=sub_graph, - source_node_config=root_node_config, - edges_mapping=edges_mapping, - nodes_mapping=nodes_mapping, - root_node_configs=root_node_configs - ) + sub_graph_root_node_config = root_node_config break + if sub_graph_root_node_config: + # create sub graph run condition + iterable_node_cls: IterableNodeMixin = cast(IterableNodeMixin, target_node_cls) + sub_graph_run_condition = RunCondition( + type='condition', + conditions=iterable_node_cls.get_conditions( + node_config=target_node_config + ) + ) + + # create sub graph + sub_graph = Graph.init( + root_node_config=sub_graph_root_node_config, + run_condition=sub_graph_run_condition + ) + + self._recursively_add_edges( + graph=sub_graph, + source_node_config=sub_graph_root_node_config, + edges_mapping=edges_mapping, + nodes_mapping=nodes_mapping, + root_node_configs=root_node_configs + ) + + # parse run condition + run_condition = None + if edge_config.get('sourceHandle'): + run_condition = RunCondition( + type='branch_identify', + branch_identify=edge_config.get('sourceHandle') + ) + # add edge graph.add_edge( edge_config=edge_config, source_node_config=source_node_config, target_node_config=target_node_config, target_node_sub_graph=sub_graph, + run_condition=run_condition ) # recursively add edges diff --git a/api/tests/unit_tests/core/workflow/test_workflow_engine_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py similarity index 94% rename from api/tests/unit_tests/core/workflow/test_workflow_engine_manager.py rename to api/tests/unit_tests/core/workflow/test_workflow_entry.py index 0d3ba65843..5b9d0af3cd 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_engine_manager.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -1,5 +1,4 @@ -from core.workflow.graph import Graph -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.workflow_entry import WorkflowEntry def test__init_graph(): @@ -217,18 +216,17 @@ def test__init_graph(): ], } - workflow_engine_manager = WorkflowEngineManager() - graph = workflow_engine_manager._init_graph( + workflow_entry = WorkflowEntry() + graph = workflow_entry._init_graph( graph_config=graph_config ) assert graph.root_node.id == "1717222650545" assert graph.root_node.source_edge_config is None - assert graph.root_node.target_edge_config is not None assert graph.root_node.descendant_node_ids == ["1719481290322"] assert graph.graph_nodes.get("1719481290322") is not None assert len(graph.graph_nodes.get("1719481290322").descendant_node_ids) == 2 - assert graph.graph_nodes.get("llm").run_condition_callback is not None - assert graph.graph_nodes.get("1719481315734").run_condition_callback is not None + assert graph.graph_nodes.get("llm").run_condition is not None + assert graph.graph_nodes.get("1719481315734").run_condition is not None