optimize graph

This commit is contained in:
takatost 2024-07-02 21:53:41 +08:00
parent 8375517ccd
commit 0f19b2a986
23 changed files with 454 additions and 193 deletions

View File

@ -0,0 +1,28 @@
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph
from core.workflow.nodes.base_node import UserFrom
from models.workflow import WorkflowType
class WorkflowRuntimeState(BaseModel):
tenant_id: str
app_id: str
workflow_id: str
workflow_type: WorkflowType
user_id: str
user_from: UserFrom
variable_pool: VariablePool
invoke_from: InvokeFrom
graph: Graph
call_depth: int
start_at: float
total_tokens: int = 0
node_run_steps: int = 0
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)

View File

@ -0,0 +1,24 @@
from abc import ABC, abstractmethod
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.entities.run_condition import RunCondition
class RunConditionHandler(ABC):
def __init__(self, condition: RunCondition):
self.condition = condition
@abstractmethod
def check(self,
graph_node: "GraphNode",
graph_runtime_state: "GraphRuntimeState",
predecessor_node_result: NodeRunResult) -> bool:
"""
Check if the condition can be executed
:param graph_node: graph node
:param graph_runtime_state: graph runtime state
:param predecessor_node_result: predecessor node result
:return: bool
"""
raise NotImplementedError

View File

@ -0,0 +1,25 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
class BranchIdentifyRunConditionHandler(RunConditionHandler):
def check(self,
graph_node: "GraphNode",
graph_runtime_state: "GraphRuntimeState",
predecessor_node_result: NodeRunResult) -> bool:
"""
Check if the condition can be executed
:param graph_node: graph node
:param graph_runtime_state: graph runtime state
:param predecessor_node_result: predecessor node result
:return: bool
"""
if not self.condition.branch_identify:
raise Exception("Branch identify is required")
if not predecessor_node_result.edge_source_handle:
return False
return self.condition.branch_identify == predecessor_node_result.edge_source_handle

View File

@ -0,0 +1,31 @@
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.utils.condition.processor import ConditionProcessor
class ConditionRunConditionHandlerHandler(RunConditionHandler):
def check(self,
graph_node: "GraphNode",
graph_runtime_state: "GraphRuntimeState",
predecessor_node_result: NodeRunResult) -> bool:
"""
Check if the condition can be executed
:param graph_node: graph node
:param graph_runtime_state: graph runtime state
:param predecessor_node_result: predecessor node result
:return: bool
"""
if not self.condition.conditions:
return True
# process condition
condition_processor = ConditionProcessor()
compare_result, _ = condition_processor.process(
variable_pool=graph_runtime_state.variable_pool,
logical_operator="and",
conditions=self.condition.conditions
)
return compare_result

View File

@ -0,0 +1,19 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
from core.workflow.graph_engine.entities.run_condition import RunCondition
class ConditionManager:
@staticmethod
def get_condition_handler(run_condition: RunCondition) -> RunConditionHandler:
"""
Get condition handler
:param run_condition: run condition
:return: condition handler
"""
if run_condition.type == "branch_identify":
return BranchIdentifyRunConditionHandler(run_condition)
else:
return ConditionRunConditionHandlerHandler(run_condition)

View File

@ -1,20 +1,10 @@
from collections.abc import Callable
from typing import Literal, Optional
from typing import Optional
from pydantic import BaseModel, Field
from core.workflow.utils.condition.entities import Condition
class RunCondition(BaseModel):
type: Literal["branch_identify", "condition"]
"""condition type"""
branch_identify: Optional[str] = None
"""branch identify, required when type is branch_identify"""
conditions: Optional[list[Condition]] = None
"""conditions to run the node, required when type is condition"""
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.run_condition import RunCondition
class GraphNode(BaseModel):
@ -33,9 +23,6 @@ class GraphNode(BaseModel):
run_condition: Optional[RunCondition] = None
"""condition to run the node"""
run_condition_callback: Optional[Callable] = Field(None, exclude=True)
"""condition function check if the node can be executed, translated from run_conditions, not serialized"""
node_config: dict
"""original node config"""
@ -48,9 +35,22 @@ class GraphNode(BaseModel):
def add_child(self, node_id: str) -> None:
self.descendant_node_ids.append(node_id)
def get_run_condition_handler(self) -> Optional[RunConditionHandler]:
"""
Get run condition handler
:return: run condition handler
"""
if not self.run_condition:
return None
return ConditionManager.get_condition_handler(
run_condition=self.run_condition
)
class Graph(BaseModel):
graph_nodes: dict[str, GraphNode] = {}
graph_nodes: dict[str, GraphNode] = Field(default_factory=dict)
"""graph nodes"""
root_node: GraphNode
@ -65,8 +65,12 @@ class Graph(BaseModel):
:param run_condition: run condition when root node parent is iteration/loop
:return: graph
"""
node_id = root_node_config.get('id')
if not node_id:
raise ValueError("Graph root node id is required")
root_node = GraphNode(
id=root_node_config.get('id'),
id=node_id,
parent_id=root_node_config.get('parentId'),
node_config=root_node_config,
run_condition=run_condition
@ -74,15 +78,14 @@ class Graph(BaseModel):
graph = cls(root_node=root_node)
# TODO parse run_condition to run_condition_callback
graph.add_graph_node(graph.root_node)
graph._add_graph_node(graph.root_node)
return graph
def add_edge(self, edge_config: dict,
source_node_config: dict,
target_node_config: dict,
target_node_sub_graph: Optional["Graph"] = None) -> None:
target_node_sub_graph: Optional["Graph"] = None,
run_condition: Optional[RunCondition] = None) -> None:
"""
Add edge to the graph
@ -90,6 +93,7 @@ class Graph(BaseModel):
:param source_node_config: source node config
:param target_node_config: target node config
:param target_node_sub_graph: sub graph
:param run_condition: run condition
"""
source_node_id = source_node_config.get('id')
if not source_node_id:
@ -105,48 +109,25 @@ class Graph(BaseModel):
source_node = self.graph_nodes[source_node_id]
source_node.add_child(target_node_id)
# if run_conditions:
# run_condition_callback = lambda: all()
if target_node_id not in self.graph_nodes:
run_condition = None # todo
run_condition_callback = None # todo
target_graph_node = GraphNode(
id=target_node_id,
parent_id=source_node_config.get('parentId'),
predecessor_node_id=source_node_id,
node_config=target_node_config,
run_condition=run_condition,
run_condition_callback=run_condition_callback,
source_edge_config=edge_config,
sub_graph=target_node_sub_graph
)
self.add_graph_node(target_graph_node)
self._add_graph_node(target_graph_node)
else:
target_node = self.graph_nodes[target_node_id]
target_node.predecessor_node_id = source_node_id
target_node.run_conditions = run_conditions
target_node.run_condition_callback = run_condition_callback
target_node.run_condition = run_condition
target_node.source_edge_config = edge_config
target_node.sub_graph = target_node_sub_graph
def add_graph_node(self, graph_node: GraphNode) -> None:
"""
Add graph node to the graph
:param graph_node: graph node
"""
if graph_node.id in self.graph_nodes:
return
if len(self.graph_nodes) == 0:
self.root_node = graph_node
self.graph_nodes[graph_node.id] = graph_node
def get_root_node(self) -> Optional[GraphNode]:
"""
Get root node of the graph
@ -169,14 +150,28 @@ class Graph(BaseModel):
if not graph_node.descendant_node_ids:
return None
descendants_graph = Graph()
descendants_graph.add_graph_node(graph_node)
descendants_graph = Graph(root_node=graph_node)
descendants_graph._add_graph_node(graph_node)
for child_node_id in graph_node.descendant_node_ids:
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
return descendants_graph
def _add_graph_node(self, graph_node: GraphNode) -> None:
"""
Add graph node to the graph
:param graph_node: graph node
"""
if graph_node.id in self.graph_nodes:
return
if len(self.graph_nodes) == 0:
self.root_node = graph_node
self.graph_nodes[graph_node.id] = graph_node
def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None:
"""
Add descendants graph nodes
@ -188,7 +183,7 @@ class Graph(BaseModel):
return
graph_node = self.graph_nodes[node_id]
descendants_graph.add_graph_node(graph_node)
descendants_graph._add_graph_node(graph_node)
for child_node_id in graph_node.descendant_node_ids:
self._add_descendants_graph_nodes(descendants_graph, child_node_id)

View File

@ -0,0 +1,26 @@
from typing import Optional
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph
from core.workflow.nodes.base_node import UserFrom
class GraphRuntimeState(BaseModel):
tenant_id: str
app_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
call_depth: int
graph: Graph
variable_pool: VariablePool
start_at: Optional[float] = None
total_tokens: int = 0
node_run_steps: int = 0
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)

View File

@ -0,0 +1,16 @@
from typing import Literal, Optional
from pydantic import BaseModel
from core.workflow.utils.condition.entities import Condition
class RunCondition(BaseModel):
type: Literal["branch_identify", "condition"]
"""condition type"""
branch_identify: Optional[str] = None
"""branch identify like: sourceHandle, required when type is branch_identify"""
conditions: Optional[list[Condition]] = None
"""conditions to run the node, required when type is condition"""

View File

@ -1,55 +1,11 @@
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph import Graph, GraphNode
from core.workflow.nodes.base_node import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
class RuntimeNode(BaseModel):
class Status(Enum):
PENDING = "pending"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""random id for current runtime node"""
graph_node: GraphNode
"""graph node"""
node_run_result: Optional[NodeRunResult] = None
"""node run result"""
status: Status = Status.PENDING
"""node status"""
start_at: Optional[datetime] = None
"""start time"""
paused_at: Optional[datetime] = None
"""paused time"""
finished_at: Optional[datetime] = None
"""finished time"""
failed_reason: Optional[str] = None
"""failed reason"""
paused_by: Optional[str] = None
"""paused by"""
predecessor_runtime_node_id: Optional[str] = None
"""predecessor runtime node id"""
from core.workflow.graph_engine.entities.runtime_node import RuntimeNode
from models.workflow import WorkflowNodeExecutionStatus
class RuntimeGraph(BaseModel):
@ -80,22 +36,3 @@ class RuntimeGraph(BaseModel):
runtime_node.status = RuntimeNode.Status.PAUSED
runtime_node.paused_at = datetime.now(timezone.utc).replace(tzinfo=None)
runtime_node.paused_by = paused_by
class WorkflowRuntimeState(BaseModel):
tenant_id: str
app_id: str
workflow_id: str
workflow_type: WorkflowType
user_id: str
user_from: UserFrom
variable_pool: VariablePool
invoke_from: InvokeFrom
graph: Graph
call_depth: int
start_at: float
total_tokens: int = 0
node_run_steps: int = 0
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)

View File

@ -0,0 +1,48 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.entities.graph import GraphNode
class RuntimeNode(BaseModel):
class Status(Enum):
PENDING = "pending"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""random id for current runtime node"""
graph_node: GraphNode
"""graph node"""
node_run_result: Optional[NodeRunResult] = None
"""node run result"""
status: Status = Status.PENDING
"""node status"""
start_at: Optional[datetime] = None
"""start time"""
paused_at: Optional[datetime] = None
"""paused time"""
finished_at: Optional[datetime] = None
"""finished time"""
failed_reason: Optional[str] = None
"""failed reason"""
paused_by: Optional[str] = None
"""paused by"""
predecessor_runtime_node_id: Optional[str] = None
"""predecessor runtime node id"""

View File

@ -0,0 +1,45 @@
import time
from collections.abc import Generator
from typing import cast
from flask import current_app
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.base_node import UserFrom
class GraphEngine:
def __init__(self, tenant_id: str,
app_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int,
graph: Graph,
variable_pool: VariablePool,
callbacks: list[BaseWorkflowCallback]) -> None:
self.graph_runtime_state = GraphRuntimeState(
tenant_id=tenant_id,
app_id=app_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth,
graph=graph,
variable_pool=variable_pool
)
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
self.max_execution_steps = cast(int, max_execution_steps)
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
self.max_execution_time = cast(int, max_execution_time)
self.callbacks = callbacks
def run(self) -> Generator:
self.graph_runtime_state.start_at = time.perf_counter()
pass

View File

@ -7,6 +7,7 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.iterable_node import IterableNodeMixin
class UserFrom(Enum):
@ -149,7 +150,8 @@ class BaseNode(ABC):
"""
return self._node_type
class BaseIterationNode(BaseNode):
class BaseIterationNode(BaseNode, IterableNodeMixin):
@abstractmethod
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
"""

View File

@ -0,0 +1,14 @@
from abc import ABC, abstractmethod
from typing import Any
from core.workflow.utils.condition.entities import Condition
class IterableNodeMixin(ABC):
@classmethod
@abstractmethod
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
"""
Get conditions.
"""
raise NotImplementedError

View File

@ -1,4 +1,4 @@
from typing import cast
from typing import Any, cast
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseIterationState
@ -6,6 +6,7 @@ from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
from core.workflow.utils.condition.entities import Condition
from models.workflow import WorkflowNodeExecutionStatus
@ -117,3 +118,19 @@ class IterationNode(BaseIterationNode):
return {
'input_selector': node_data.iterator_selector,
}
@classmethod
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
"""
Get conditions.
"""
node_id = node_config.get('id')
if not node_id:
return []
return [Condition(
variable_selector=[node_id, 'index'],
comparison_operator="",
value_type="value_selector",
value_selector=node_config.get('data', {}).get('iterator_selector')
)]

View File

@ -1,7 +1,10 @@
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
from core.workflow.utils.condition.entities import Condition
class LoopNode(BaseIterationNode):
@ -18,3 +21,21 @@ class LoopNode(BaseIterationNode):
"""
Get next iteration start node id based on the graph.
"""
pass
@classmethod
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
"""
Get conditions.
"""
node_id = node_config.get('id')
if not node_id:
return []
# TODO waiting for implementation
return [Condition(
variable_selector=[node_id, 'index'],
comparison_operator="",
value_type="value_selector",
value_selector=[]
)]

View File

@ -14,4 +14,6 @@ class Condition(BaseModel):
# for number
"=", "", ">", "<", "", "", "null", "not null"
]
value_type: Literal["string", "value_selector"] = "string"
value: Optional[str] = None
value_selector: Optional[list[str]] = None

View File

@ -1,22 +0,0 @@
from typing import Optional
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState
from core.workflow.graph import GraphNode
def source_handle_condition_func(workflow_runtime_state: WorkflowRuntimeState,
graph_node: GraphNode,
# TODO cycle_state optional
predecessor_node_run_result: Optional[NodeRunResult] = None) -> bool:
if not graph_node.source_edge_config:
return False
if not graph_node.source_edge_config.get('sourceHandle'):
return True
source_handle = predecessor_node_run_result.edge_source_handle \
if predecessor_node_run_result else None
return (source_handle is not None
and graph_node.source_edge_config.get('sourceHandle') == source_handle)

View File

@ -24,6 +24,11 @@ class ConditionProcessor:
variable_selector=condition.variable_selector
)
if condition.value_type == "value_selector":
expected_value = variable_pool.get_variable_value(
variable_selector=condition.value_selector
)
else:
expected_value = condition.value
input_conditions.append({
@ -208,7 +213,7 @@ class ConditionProcessor:
return True
return False
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
"""
Assert equal
:param actual_value: actual value
@ -230,7 +235,7 @@ class ConditionProcessor:
return False
return True
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
"""
Assert not equal
:param actual_value: actual value
@ -252,7 +257,7 @@ class ConditionProcessor:
return False
return True
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
"""
Assert greater than
:param actual_value: actual value
@ -274,7 +279,7 @@ class ConditionProcessor:
return False
return True
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
"""
Assert less than
:param actual_value: actual value
@ -296,7 +301,7 @@ class ConditionProcessor:
return False
return True
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
"""
Assert greater than or equal
:param actual_value: actual value
@ -318,7 +323,7 @@ class ConditionProcessor:
return False
return True
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
"""
Assert less than or equal
:param actual_value: actual value

View File

@ -1,5 +1,6 @@
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, cast
from flask import current_app
@ -12,15 +13,18 @@ from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState
from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iterable_node import IterableNodeMixin
from core.workflow.nodes.iteration.entities import IterationState
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
@ -36,7 +40,6 @@ from extensions.ext_database import db
from models.workflow import (
Workflow,
WorkflowNodeExecutionStatus,
WorkflowType,
)
node_classes = {
@ -60,7 +63,7 @@ node_classes = {
logger = logging.getLogger(__name__)
class WorkflowEngineManager:
class WorkflowEntry:
def run_workflow(self, workflow: Workflow,
user_id: str,
user_from: UserFrom,
@ -69,7 +72,7 @@ class WorkflowEngineManager:
user_inputs: dict,
system_inputs: dict[SystemVariable, Any],
call_depth: int = 0,
variable_pool: Optional[VariablePool] = None) -> None:
variable_pool: Optional[VariablePool] = None) -> Generator:
"""
:param workflow: Workflow instance
:param user_id: user id
@ -102,25 +105,25 @@ class WorkflowEngineManager:
user_inputs=user_inputs
)
# fetch max call depth
workflow_call_max_depth = current_app.config.get("WORKFLOW_CALL_MAX_DEPTH")
workflow_call_max_depth = cast(int, workflow_call_max_depth)
if call_depth > workflow_call_max_depth:
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
# init graph
graph = self._init_graph(
graph_config=graph_config
)
# init workflow runtime state
workflow_runtime_state = WorkflowRuntimeState(
if not graph:
raise ValueError('graph not found in workflow')
# init workflow run state
graph_engine = GraphEngine(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
user_id=user_id,
user_from=user_from,
variable_pool=variable_pool,
invoke_from=invoke_from,
graph=graph_config,
call_depth=call_depth,
start_at=time.perf_counter()
graph=graph,
variable_pool=variable_pool,
callbacks=callbacks
)
# init workflow run
@ -130,11 +133,7 @@ class WorkflowEngineManager:
try:
# run workflow
self._run_workflow(
graph_config=graph_config,
workflow_runtime_state=workflow_runtime_state,
callbacks=callbacks,
)
rst = graph_engine.run()
except WorkflowRunFailedError as e:
self._workflow_run_failed(
error=e.error,
@ -151,6 +150,8 @@ class WorkflowEngineManager:
callbacks=callbacks
)
return rst
def _init_graph(self, graph_config: dict, root_node_id: Optional[str] = None) -> Optional[Graph]:
"""
Initialize graph
@ -259,23 +260,51 @@ class WorkflowEngineManager:
sub_graph: Optional[Graph] = None
target_node_type: NodeType = NodeType.value_of(target_node_config.get('data', {}).get('type'))
if target_node_type and target_node_type in [IterationNode.node_type, NodeType.LOOP]:
target_node_cls = None
if target_node_type:
target_node_cls = node_classes.get(target_node_type)
if not target_node_cls:
raise Exception(f'Node class not found for node type: {target_node_type}')
if target_node_cls and issubclass(target_node_cls, IterableNodeMixin):
# find iteration/loop sub nodes that have no predecessor node
sub_graph_root_node_config = None
for root_node_config in root_node_configs:
if root_node_config.get('parentId') == target_node_id:
sub_graph_root_node_config = root_node_config
break
if sub_graph_root_node_config:
# create sub graph run condition
iterable_node_cls: IterableNodeMixin = cast(IterableNodeMixin, target_node_cls)
sub_graph_run_condition = RunCondition(
type='condition',
conditions=iterable_node_cls.get_conditions(
node_config=target_node_config
)
)
# create sub graph
sub_graph = Graph.init(
root_node_config=root_node_config
root_node_config=sub_graph_root_node_config,
run_condition=sub_graph_run_condition
)
self._recursively_add_edges(
graph=sub_graph,
source_node_config=root_node_config,
source_node_config=sub_graph_root_node_config,
edges_mapping=edges_mapping,
nodes_mapping=nodes_mapping,
root_node_configs=root_node_configs
)
break
# parse run condition
run_condition = None
if edge_config.get('sourceHandle'):
run_condition = RunCondition(
type='branch_identify',
branch_identify=edge_config.get('sourceHandle')
)
# add edge
graph.add_edge(
@ -283,6 +312,7 @@ class WorkflowEngineManager:
source_node_config=source_node_config,
target_node_config=target_node_config,
target_node_sub_graph=sub_graph,
run_condition=run_condition
)
# recursively add edges

View File

@ -1,5 +1,4 @@
from core.workflow.graph import Graph
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.workflow_entry import WorkflowEntry
def test__init_graph():
@ -217,18 +216,17 @@ def test__init_graph():
],
}
workflow_engine_manager = WorkflowEngineManager()
graph = workflow_engine_manager._init_graph(
workflow_entry = WorkflowEntry()
graph = workflow_entry._init_graph(
graph_config=graph_config
)
assert graph.root_node.id == "1717222650545"
assert graph.root_node.source_edge_config is None
assert graph.root_node.target_edge_config is not None
assert graph.root_node.descendant_node_ids == ["1719481290322"]
assert graph.graph_nodes.get("1719481290322") is not None
assert len(graph.graph_nodes.get("1719481290322").descendant_node_ids) == 2
assert graph.graph_nodes.get("llm").run_condition_callback is not None
assert graph.graph_nodes.get("1719481315734").run_condition_callback is not None
assert graph.graph_nodes.get("llm").run_condition is not None
assert graph.graph_nodes.get("1719481315734").run_condition is not None