add new graph structure

This commit is contained in:
takatost 2024-06-24 23:34:42 +08:00
parent c5d64baba4
commit 8217c46116
6 changed files with 463 additions and 217 deletions

View File

@ -103,6 +103,7 @@ class AdvancedChatAppRunner(AppRunner):
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER, else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
callbacks=workflow_callbacks,
user_inputs=inputs, user_inputs=inputs,
system_inputs={ system_inputs={
SystemVariable.QUERY: query, SystemVariable.QUERY: query,
@ -110,7 +111,6 @@ class AdvancedChatAppRunner(AppRunner):
SystemVariable.CONVERSATION_ID: conversation.id, SystemVariable.CONVERSATION_ID: conversation.id,
SystemVariable.USER_ID: user_id SystemVariable.USER_ID: user_id
}, },
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth call_depth=application_generate_entity.call_depth
) )

View File

@ -74,12 +74,12 @@ class WorkflowAppRunner:
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER, else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
callbacks=workflow_callbacks,
user_inputs=inputs, user_inputs=inputs,
system_inputs={ system_inputs={
SystemVariable.FILES: files, SystemVariable.FILES: files,
SystemVariable.USER_ID: user_id SystemVariable.USER_ID: user_id
}, },
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth call_depth=application_generate_entity.call_depth
) )

View File

@ -66,8 +66,7 @@ class WorkflowRunState:
self.variable_pool = variable_pool self.variable_pool = variable_pool
self.total_tokens = 0 self.total_tokens = 0
self.workflow_nodes_and_results = []
self.current_iteration_state = None
self.workflow_node_steps = 1 self.workflow_node_steps = 1
self.workflow_node_runs = [] self.workflow_node_runs = []
self.current_iteration_state = None

164
api/core/workflow/graph.py Normal file
View File

@ -0,0 +1,164 @@
from collections.abc import Callable
from typing import Optional
from pydantic import BaseModel
class GraphNode(BaseModel):
id: str
"""node id"""
predecessor_node_id: Optional[str] = None
"""predecessor node id"""
children_node_ids: list[str] = []
"""children node ids"""
source_handle: Optional[str] = None
"""current node source handle from the previous node result"""
is_continue_callback: Optional[Callable] = None
"""condition function check if the node can be executed"""
node_config: dict
"""original node config"""
source_edge_config: Optional[dict] = None
"""original source edge config"""
target_edge_config: Optional[dict] = None
"""original target edge config"""
sub_graph: Optional["Graph"] = None
"""sub graph for iteration or loop node"""
def add_child(self, node_id: str) -> None:
self.children_node_ids.append(node_id)
class Graph(BaseModel):
graph_nodes: dict[str, GraphNode] = {}
"""graph nodes"""
root_node: Optional[GraphNode] = None
"""root node of the graph"""
def add_edge(self, edge_config: dict,
source_node_config: dict,
target_node_config: dict,
source_node_sub_graph: Optional["Graph"] = None,
is_continue_callback: Optional[Callable] = None) -> None:
"""
Add edge to the graph
:param edge_config: edge config
:param source_node_config: source node config
:param target_node_config: target node config
:param source_node_sub_graph: sub graph for iteration or loop node
:param is_continue_callback: condition callback
"""
source_node_id = source_node_config.get('id')
if not source_node_id:
return
target_node_id = target_node_config.get('id')
if not target_node_id:
return
if source_node_id not in self.graph_nodes:
source_graph_node = GraphNode(
id=source_node_id,
node_config=source_node_config,
children_node_ids=[target_node_id],
target_edge_config=edge_config,
sub_graph=source_node_sub_graph
)
self.add_graph_node(source_graph_node)
else:
source_node = self.graph_nodes[source_node_id]
source_node.add_child(target_node_id)
source_node.target_edge_config = edge_config
source_node.sub_graph = source_node_sub_graph
source_handle = None
if edge_config.get('sourceHandle'):
source_handle = edge_config.get('sourceHandle')
if target_node_id not in self.graph_nodes:
target_graph_node = GraphNode(
id=target_node_id,
predecessor_node_id=source_node_id,
node_config=target_node_config,
source_handle=source_handle,
is_continue_callback=is_continue_callback,
source_edge_config=edge_config,
)
self.add_graph_node(target_graph_node)
else:
target_node = self.graph_nodes[target_node_id]
target_node.predecessor_node_id = source_node_id
target_node.source_handle = source_handle
target_node.is_continue_callback = is_continue_callback
target_node.source_edge_config = edge_config
def add_graph_node(self, graph_node: GraphNode) -> None:
"""
Add graph node to the graph
:param graph_node: graph node
"""
if graph_node.id in self.graph_nodes:
return
if len(self.graph_nodes) == 0:
self.root_node = graph_node
self.graph_nodes[graph_node.id] = graph_node
def get_root_node(self) -> Optional[GraphNode]:
"""
Get root node of the graph
:return: root node
"""
return self.root_node
def get_descendants_graph(self, node_id: str) -> Optional["Graph"]:
"""
Get descendants graph of the specific node
:param node_id: node id
:return: descendants graph
"""
if node_id not in self.graph_nodes:
return None
graph_node = self.graph_nodes[node_id]
if not graph_node.children_node_ids:
return None
descendants_graph = Graph()
descendants_graph.add_graph_node(graph_node)
for child_node_id in graph_node.children_node_ids:
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
return descendants_graph
def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None:
"""
Add descendants graph nodes
:param descendants_graph: descendants graph
:param node_id: node id
"""
if node_id not in self.graph_nodes:
return
graph_node = self.graph_nodes[node_id]
descendants_graph.add_graph_node(graph_node)
for child_node_id in graph_node.children_node_ids:
self._add_descendants_graph_nodes(descendants_graph, child_node_id)

View File

@ -1,6 +1,7 @@
import logging import logging
import threading
import time import time
from typing import Optional, cast from typing import Any, Optional, cast
from flask import current_app from flask import current_app
@ -9,7 +10,7 @@ from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
@ -49,7 +50,7 @@ node_classes = {
NodeType.HTTP_REQUEST: HttpRequestNode, NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode, NodeType.TOOL: ToolNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode, NodeType.ITERATION: IterationNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
} }
@ -58,67 +59,40 @@ logger = logging.getLogger(__name__)
class WorkflowEngineManager: class WorkflowEngineManager:
def get_default_configs(self) -> list[dict]:
"""
Get default block configs
"""
default_block_configs = []
for node_type, node_class in node_classes.items():
default_config = node_class.get_default_config()
if default_config:
default_block_configs.append(default_config)
return default_block_configs
def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None) -> Optional[dict]:
"""
Get default config of node.
:param node_type: node type
:param filters: filter by node config parameters.
:return:
"""
node_class = node_classes.get(node_type)
if not node_class:
return None
default_config = node_class.get_default_config(filters=filters)
if not default_config:
return None
return default_config
def run_workflow(self, workflow: Workflow, def run_workflow(self, workflow: Workflow,
user_id: str, user_id: str,
user_from: UserFrom, user_from: UserFrom,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
callbacks: list[BaseWorkflowCallback],
user_inputs: dict, user_inputs: dict,
system_inputs: Optional[dict] = None, system_inputs: dict[SystemVariable, Any],
callbacks: list[BaseWorkflowCallback] = None, call_depth: int = 0,
call_depth: Optional[int] = 0,
variable_pool: Optional[VariablePool] = None) -> None: variable_pool: Optional[VariablePool] = None) -> None:
""" """
:param workflow: Workflow instance :param workflow: Workflow instance
:param user_id: user id :param user_id: user id
:param user_from: user from :param user_from: user from
:param invoke_from: invoke from service-api, web-app, debugger, explore
:param callbacks: workflow callbacks
:param user_inputs: user variables inputs :param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files :param system_inputs: system inputs, like: query, files
:param callbacks: workflow callbacks
:param call_depth: call depth :param call_depth: call depth
:param variable_pool: variable pool
""" """
# fetch workflow graph # fetch workflow graph
graph = workflow.graph_dict graph_dict = workflow.graph_dict
if not graph: if not graph_dict:
raise ValueError('workflow graph not found') raise ValueError('workflow graph not found')
if 'nodes' not in graph or 'edges' not in graph: if 'nodes' not in graph_dict or 'edges' not in graph_dict:
raise ValueError('nodes or edges not found in workflow graph') raise ValueError('nodes or edges not found in workflow graph')
if not isinstance(graph.get('nodes'), list): if not isinstance(graph_dict.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list') raise ValueError('nodes in workflow graph must be a list')
if not isinstance(graph.get('edges'), list): if not isinstance(graph_dict.get('edges'), list):
raise ValueError('edges in workflow graph must be a list') raise ValueError('edges in workflow graph must be a list')
# init variable pool # init variable pool
if not variable_pool: if not variable_pool:
variable_pool = VariablePool( variable_pool = VariablePool(
@ -126,7 +100,9 @@ class WorkflowEngineManager:
user_inputs=user_inputs user_inputs=user_inputs
) )
# fetch max call depth
workflow_call_max_depth = current_app.config.get("WORKFLOW_CALL_MAX_DEPTH") workflow_call_max_depth = current_app.config.get("WORKFLOW_CALL_MAX_DEPTH")
workflow_call_max_depth = cast(int, workflow_call_max_depth)
if call_depth > workflow_call_max_depth: if call_depth > workflow_call_max_depth:
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
@ -142,55 +118,55 @@ class WorkflowEngineManager:
) )
# init workflow run # init workflow run
if callbacks: self._workflow_run_started(
for callback in callbacks: callbacks=callbacks
callback.on_workflow_run_started() )
# run workflow # run workflow
self._run_workflow( self._run_workflow(
workflow=workflow, graph=graph_dict,
workflow_run_state=workflow_run_state, workflow_run_state=workflow_run_state,
callbacks=callbacks, callbacks=callbacks,
) )
def _run_workflow(self, workflow: Workflow, # workflow run success
workflow_run_state: WorkflowRunState, self._workflow_run_success(
callbacks: list[BaseWorkflowCallback] = None, callbacks=callbacks
start_at: Optional[str] = None, )
end_at: Optional[str] = None) -> None:
def _run_workflow(self, graph: dict,
workflow_run_state: WorkflowRunState,
callbacks: list[BaseWorkflowCallback],
start_node: Optional[str] = None,
end_node: Optional[str] = None) -> None:
""" """
Run workflow Run workflow
:param workflow: Workflow instance :param graph: workflow graph
:param user_id: user id :param workflow_run_state: workflow run state
:param user_from: user from
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:param callbacks: workflow callbacks :param callbacks: workflow callbacks
:param call_depth: call depth :param start_node: force specific start node (gte)
:param start_at: force specific start node :param end_node: force specific end node (le)
:param end_at: force specific end node
:return: :return:
""" """
graph = workflow.graph_dict
try: try:
predecessor_node: BaseNode = None predecessor_node: Optional[BaseNode] = None
current_iteration_node: BaseIterationNode = None current_iteration_node: Optional[BaseIterationNode] = None
has_entry_node = False
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS") max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
max_execution_steps = cast(int, max_execution_steps)
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME") max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
max_execution_time = cast(int, max_execution_time)
while True: while True:
# get next node, multiple target nodes in the future # get next nodes
next_node = self._get_next_overall_node( next_nodes = self._get_next_overall_nodes(
workflow_run_state=workflow_run_state, workflow_run_state=workflow_run_state,
graph=graph, graph=graph,
predecessor_node=predecessor_node, predecessor_node=predecessor_node,
callbacks=callbacks, callbacks=callbacks,
start_at=start_at, node_start_at=start_node,
end_at=end_at node_end_at=end_node
) )
if not next_node: if not next_nodes:
# reached loop/iteration end or overall end # reached loop/iteration end or overall end
if current_iteration_node and workflow_run_state.current_iteration_state: if current_iteration_node and workflow_run_state.current_iteration_state:
# reached loop/iteration end # reached loop/iteration end
@ -221,13 +197,13 @@ class WorkflowEngineManager:
callbacks=callbacks callbacks=callbacks
) )
# iteration has ended # iteration has ended
next_node = self._get_next_overall_node( next_nodes = self._get_next_overall_nodes(
workflow_run_state=workflow_run_state, workflow_run_state=workflow_run_state,
graph=graph, graph=graph,
predecessor_node=current_iteration_node, predecessor_node=current_iteration_node,
callbacks=callbacks, callbacks=callbacks,
start_at=start_at, node_start_at=start_node,
end_at=end_at node_end_at=end_node
) )
current_iteration_node = None current_iteration_node = None
workflow_run_state.current_iteration_state = None workflow_run_state.current_iteration_state = None
@ -236,18 +212,11 @@ class WorkflowEngineManager:
# move to next iteration # move to next iteration
next_node_id = next_iteration next_node_id = next_iteration
# get next id # get next id
next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks) next_nodes = [self._get_node(workflow_run_state, graph, next_node_id, callbacks)]
if not next_node: if not next_nodes:
break break
# check is already ran
if self._check_node_has_ran(workflow_run_state, next_node.node_id):
predecessor_node = next_node
continue
has_entry_node = True
# max steps reached # max steps reached
if workflow_run_state.workflow_node_steps > max_execution_steps: if workflow_run_state.workflow_node_steps > max_execution_steps:
raise ValueError('Max steps {} reached.'.format(max_execution_steps)) raise ValueError('Max steps {} reached.'.format(max_execution_steps))
@ -256,62 +225,40 @@ class WorkflowEngineManager:
if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time): if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time):
raise ValueError('Max execution time {}s reached.'.format(max_execution_time)) raise ValueError('Max execution time {}s reached.'.format(max_execution_time))
# handle iteration nodes if len(next_nodes) == 1:
if isinstance(next_node, BaseIterationNode): next_node = next_nodes[0]
current_iteration_node = next_node
workflow_run_state.current_iteration_state = next_node.run( # run node
variable_pool=workflow_run_state.variable_pool is_continue = self._run_node(
)
self._workflow_iteration_started(
graph=graph, graph=graph,
current_iteration_node=current_iteration_node,
workflow_run_state=workflow_run_state, workflow_run_state=workflow_run_state,
predecessor_node_id=predecessor_node.node_id if predecessor_node else None, predecessor_node=predecessor_node,
current_node=next_node,
callbacks=callbacks callbacks=callbacks
) )
if not is_continue:
break
predecessor_node = next_node predecessor_node = next_node
# move to start node of iteration else:
next_node_id = next_node.get_next_iteration( result_dict = {}
variable_pool=workflow_run_state.variable_pool,
state=workflow_run_state.current_iteration_state
)
self._workflow_iteration_next(
graph=graph,
current_iteration_node=current_iteration_node,
workflow_run_state=workflow_run_state,
callbacks=callbacks
)
if isinstance(next_node_id, NodeRunResult):
# iteration has ended
current_iteration_node.set_output(
variable_pool=workflow_run_state.variable_pool,
state=workflow_run_state.current_iteration_state
)
self._workflow_iteration_completed(
current_iteration_node=current_iteration_node,
workflow_run_state=workflow_run_state,
callbacks=callbacks
)
current_iteration_node = None
workflow_run_state.current_iteration_state = None
continue
else:
next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
# run workflow, run multiple target nodes in the future # new thread
self._run_workflow_node( worker_thread = threading.Thread(target=self._async_run_nodes, kwargs={
workflow_run_state=workflow_run_state, 'flask_app': current_app._get_current_object(),
node=next_node, 'graph': graph,
predecessor_node=predecessor_node, 'workflow_run_state': workflow_run_state,
callbacks=callbacks 'predecessor_node': predecessor_node,
) 'next_nodes': next_nodes,
'callbacks': callbacks,
'result': result_dict
})
if next_node.node_type in [NodeType.END]: worker_thread.start()
break worker_thread.join()
predecessor_node = next_node if not workflow_run_state.workflow_node_runs:
if not has_entry_node:
self._workflow_run_failed( self._workflow_run_failed(
error='Start node not found in workflow graph.', error='Start node not found in workflow graph.',
callbacks=callbacks callbacks=callbacks
@ -326,11 +273,109 @@ class WorkflowEngineManager:
) )
return return
# workflow run success def _async_run_nodes(self, flask_app: Flask,
self._workflow_run_success( graph: dict,
workflow_run_state: WorkflowRunState,
predecessor_node: Optional[BaseNode],
next_nodes: list[BaseNode],
callbacks: list[BaseWorkflowCallback],
result: dict):
with flask_app.app_context():
try:
for next_node in next_nodes:
# TODO run sub workflows
# run node
is_continue = self._run_node(
graph=graph,
workflow_run_state=workflow_run_state,
predecessor_node=predecessor_node,
current_node=next_node,
callbacks=callbacks
)
if not is_continue:
break
predecessor_node = next_node
except Exception as e:
logger.exception("Unknown Error when generating")
finally:
db.session.remove()
def _run_node(self, graph: dict,
workflow_run_state: WorkflowRunState,
predecessor_node: Optional[BaseNode],
current_node: BaseNode,
callbacks: list[BaseWorkflowCallback]) -> bool:
"""
Run node
:param graph: workflow graph
:param workflow_run_state: current workflow run state
:param predecessor_node: predecessor node
:param current_node: current node for run
:param callbacks: workflow callbacks
:return: continue?
"""
# check is already ran
if self._check_node_has_ran(workflow_run_state, current_node.node_id):
return True
# handle iteration nodes
if isinstance(current_node, BaseIterationNode):
current_iteration_node = current_node
workflow_run_state.current_iteration_state = current_node.run(
variable_pool=workflow_run_state.variable_pool
)
self._workflow_iteration_started(
graph=graph,
current_iteration_node=current_iteration_node,
workflow_run_state=workflow_run_state,
predecessor_node_id=predecessor_node.node_id if predecessor_node else None,
callbacks=callbacks
)
predecessor_node = current_node
# move to start node of iteration
current_node_id = current_node.get_next_iteration(
variable_pool=workflow_run_state.variable_pool,
state=workflow_run_state.current_iteration_state
)
self._workflow_iteration_next(
graph=graph,
current_iteration_node=current_iteration_node,
workflow_run_state=workflow_run_state,
callbacks=callbacks
)
if isinstance(current_node_id, NodeRunResult):
# iteration has ended
current_iteration_node.set_output(
variable_pool=workflow_run_state.variable_pool,
state=workflow_run_state.current_iteration_state
)
self._workflow_iteration_completed(
current_iteration_node=current_iteration_node,
workflow_run_state=workflow_run_state,
callbacks=callbacks
)
current_iteration_node = None
workflow_run_state.current_iteration_state = None
return True
else:
# fetch next node in iteration
current_node = self._get_node(workflow_run_state, graph, current_node_id, callbacks)
# run workflow, run multiple target nodes in the future
self._run_workflow_node(
workflow_run_state=workflow_run_state,
node=current_node,
predecessor_node=predecessor_node,
callbacks=callbacks callbacks=callbacks
) )
if current_node.node_type in [NodeType.END]:
return False
return True
def single_step_run_workflow_node(self, workflow: Workflow, def single_step_run_workflow_node(self, workflow: Workflow,
node_id: str, node_id: str,
user_id: str, user_id: str,
@ -398,7 +443,7 @@ class WorkflowEngineManager:
tenant_id=workflow.tenant_id, tenant_id=workflow.tenant_id,
node_instance=node_instance node_instance=node_instance
) )
# run node # run node
node_run_result = node_instance.run( node_run_result = node_instance.run(
variable_pool=variable_pool variable_pool=variable_pool
@ -417,11 +462,11 @@ class WorkflowEngineManager:
return node_instance, node_run_result return node_instance, node_run_result
def single_step_run_iteration_workflow_node(self, workflow: Workflow, def single_step_run_iteration_workflow_node(self, workflow: Workflow,
node_id: str, node_id: str,
user_id: str, user_id: str,
user_inputs: dict, user_inputs: dict,
callbacks: list[BaseWorkflowCallback] = None, callbacks: list[BaseWorkflowCallback] = None,
) -> None: ) -> None:
""" """
Single iteration run workflow node Single iteration run workflow node
""" """
@ -443,7 +488,7 @@ class WorkflowEngineManager:
node_config = node node_config = node
else: else:
raise ValueError('node id is not an iteration node') raise ValueError('node id is not an iteration node')
# init variable pool # init variable pool
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables={}, system_variables={},
@ -452,7 +497,7 @@ class WorkflowEngineManager:
# variable selector to variable mapping # variable selector to variable mapping
iteration_nested_nodes = [ iteration_nested_nodes = [
node for node in nodes node for node in nodes
if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id
] ]
iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes] iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes]
@ -475,13 +520,13 @@ class WorkflowEngineManager:
# remove iteration variables # remove iteration variables
variable_mapping = { variable_mapping = {
f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items() f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items()
if value[0] != node_id if value[0] != node_id
} }
# remove variable out from iteration # remove variable out from iteration
variable_mapping = { variable_mapping = {
key: value for key, value in variable_mapping.items() key: value for key, value in variable_mapping.items()
if value[0] not in iteration_nested_node_ids if value[0] not in iteration_nested_node_ids
} }
@ -529,13 +574,29 @@ class WorkflowEngineManager:
# run workflow # run workflow
self._run_workflow( self._run_workflow(
workflow=workflow, graph=workflow.graph,
workflow_run_state=workflow_run_state, workflow_run_state=workflow_run_state,
callbacks=callbacks, callbacks=callbacks,
start_at=node_id, start_node=node_id,
end_at=end_node_id end_node=end_node_id
) )
# workflow run success
self._workflow_run_success(
callbacks=callbacks
)
def _workflow_run_started(self, callbacks: list[BaseWorkflowCallback] = None) -> None:
"""
Workflow run started
:param callbacks: workflow callbacks
:return:
"""
# init workflow run
if callbacks:
for callback in callbacks:
callback.on_workflow_run_started()
def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None: def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None:
""" """
Workflow run success Workflow run success
@ -561,7 +622,7 @@ class WorkflowEngineManager:
error=error error=error
) )
def _workflow_iteration_started(self, graph: dict, def _workflow_iteration_started(self, graph: dict,
current_iteration_node: BaseIterationNode, current_iteration_node: BaseIterationNode,
workflow_run_state: WorkflowRunState, workflow_run_state: WorkflowRunState,
predecessor_node_id: Optional[str] = None, predecessor_node_id: Optional[str] = None,
@ -598,9 +659,9 @@ class WorkflowEngineManager:
# add steps # add steps
workflow_run_state.workflow_node_steps += 1 workflow_run_state.workflow_node_steps += 1
def _workflow_iteration_next(self, graph: dict, def _workflow_iteration_next(self, graph: dict,
current_iteration_node: BaseIterationNode, current_iteration_node: BaseIterationNode,
workflow_run_state: WorkflowRunState, workflow_run_state: WorkflowRunState,
callbacks: list[BaseWorkflowCallback] = None) -> None: callbacks: list[BaseWorkflowCallback] = None) -> None:
""" """
Workflow iteration next Workflow iteration next
@ -629,10 +690,10 @@ class WorkflowEngineManager:
for node in nodes: for node in nodes:
workflow_run_state.variable_pool.clear_node_variables(node_id=node.get('id')) workflow_run_state.variable_pool.clear_node_variables(node_id=node.get('id'))
def _workflow_iteration_completed(self, current_iteration_node: BaseIterationNode, def _workflow_iteration_completed(self, current_iteration_node: BaseIterationNode,
workflow_run_state: WorkflowRunState, workflow_run_state: WorkflowRunState,
callbacks: list[BaseWorkflowCallback] = None) -> None: callbacks: list[BaseWorkflowCallback] = None) -> None:
if callbacks: if callbacks:
if isinstance(workflow_run_state.current_iteration_state, IterationState): if isinstance(workflow_run_state.current_iteration_state, IterationState):
for callback in callbacks: for callback in callbacks:
@ -645,35 +706,39 @@ class WorkflowEngineManager:
} }
) )
def _get_next_overall_node(self, workflow_run_state: WorkflowRunState, def _get_next_overall_nodes(self, workflow_run_state: WorkflowRunState,
graph: dict, graph: dict,
predecessor_node: Optional[BaseNode] = None, callbacks: list[BaseWorkflowCallback],
callbacks: list[BaseWorkflowCallback] = None, predecessor_node: Optional[BaseNode] = None,
start_at: Optional[str] = None, node_start_at: Optional[str] = None,
end_at: Optional[str] = None) -> Optional[BaseNode]: node_end_at: Optional[str] = None) -> list[BaseNode]:
""" """
Get next node Get next nodes
multiple target nodes in the future. multiple target nodes in the future.
:param graph: workflow graph :param graph: workflow graph
:param predecessor_node: predecessor node
:param callbacks: workflow callbacks :param callbacks: workflow callbacks
:return: :param predecessor_node: predecessor node
:param node_start_at: force specific start node
:param node_end_at: force specific end node
:return: target node list
""" """
nodes = graph.get('nodes') nodes = graph.get('nodes')
if not nodes: if not nodes:
return None return []
if not predecessor_node: if not predecessor_node:
# fetch start node
for node_config in nodes: for node_config in nodes:
node_cls = None node_cls = None
if start_at: if node_start_at:
if node_config.get('id') == start_at: if node_config.get('id') == node_start_at:
node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type')))
else: else:
if node_config.get('data', {}).get('type', '') == NodeType.START.value: if node_config.get('data', {}).get('type', '') == NodeType.START.value:
node_cls = StartNode node_cls = StartNode
if node_cls: if node_cls:
return node_cls( return [node_cls(
tenant_id=workflow_run_state.tenant_id, tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id, app_id=workflow_run_state.app_id,
workflow_id=workflow_run_state.workflow_id, workflow_id=workflow_run_state.workflow_id,
@ -683,64 +748,73 @@ class WorkflowEngineManager:
config=node_config, config=node_config,
callbacks=callbacks, callbacks=callbacks,
workflow_call_depth=workflow_run_state.workflow_call_depth workflow_call_depth=workflow_run_state.workflow_call_depth
) )]
return []
else: else:
edges = graph.get('edges') edges = graph.get('edges')
edges = cast(list, edges)
source_node_id = predecessor_node.node_id source_node_id = predecessor_node.node_id
# fetch all outgoing edges from source node # fetch all outgoing edges from source node
outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id] outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id]
if not outgoing_edges: if not outgoing_edges:
return None return []
# fetch target node id from outgoing edges # fetch target node ids from outgoing edges
outgoing_edge = None target_edges = []
source_handle = predecessor_node.node_run_result.edge_source_handle \ source_handle = predecessor_node.node_run_result.edge_source_handle \
if predecessor_node.node_run_result else None if predecessor_node.node_run_result else None
if source_handle: if source_handle:
for edge in outgoing_edges: for edge in outgoing_edges:
if edge.get('sourceHandle') and edge.get('sourceHandle') == source_handle: if edge.get('sourceHandle') and edge.get('sourceHandle') == source_handle:
outgoing_edge = edge target_edges.append(edge)
break
else: else:
outgoing_edge = outgoing_edges[0] target_edges = outgoing_edges
if not outgoing_edge: if not target_edges:
return None return []
target_node_id = outgoing_edge.get('target') target_nodes = []
for target_edge in target_edges:
target_node_id = target_edge.get('target')
if end_at and target_node_id == end_at: if node_end_at and target_node_id == node_end_at:
return None continue
# fetch target node from target node id # fetch target node from target node id
target_node_config = None target_node_config = None
for node in nodes: for node in nodes:
if node.get('id') == target_node_id: if node.get('id') == target_node_id:
target_node_config = node target_node_config = node
break break
if not target_node_config: if not target_node_config:
return None continue
# get next node # get next node
target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) target_node_cls = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type')))
if not target_node_cls:
continue
return target_node( target_node = target_node_cls(
tenant_id=workflow_run_state.tenant_id, tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id, app_id=workflow_run_state.app_id,
workflow_id=workflow_run_state.workflow_id, workflow_id=workflow_run_state.workflow_id,
user_id=workflow_run_state.user_id, user_id=workflow_run_state.user_id,
user_from=workflow_run_state.user_from, user_from=workflow_run_state.user_from,
invoke_from=workflow_run_state.invoke_from, invoke_from=workflow_run_state.invoke_from,
config=target_node_config, config=target_node_config,
callbacks=callbacks, callbacks=callbacks,
workflow_call_depth=workflow_run_state.workflow_call_depth workflow_call_depth=workflow_run_state.workflow_call_depth
) )
def _get_node(self, workflow_run_state: WorkflowRunState, target_nodes.append(target_node)
graph: dict,
return target_nodes
def _get_node(self, workflow_run_state: WorkflowRunState,
graph: dict,
node_id: str, node_id: str,
callbacks: list[BaseWorkflowCallback]) -> Optional[BaseNode]: callbacks: list[BaseWorkflowCallback]) -> Optional[BaseNode]:
""" """
@ -807,9 +881,6 @@ class WorkflowEngineManager:
result=None result=None
) )
# add to workflow_nodes_and_results
workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result)
# add steps # add steps
workflow_run_state.workflow_node_steps += 1 workflow_run_state.workflow_node_steps += 1
@ -940,7 +1011,7 @@ class WorkflowEngineManager:
return new_value return new_value
def _mapping_user_inputs_to_variable_pool(self, def _mapping_user_inputs_to_variable_pool(self,
variable_mapping: dict, variable_mapping: dict,
user_inputs: dict, user_inputs: dict,
variable_pool: VariablePool, variable_pool: VariablePool,
@ -988,4 +1059,4 @@ class WorkflowEngineManager:
node_id=variable_node_id, node_id=variable_node_id,
variable_key_list=variable_key_list, variable_key_list=variable_key_list,
value=value value=value
) )

View File

@ -8,7 +8,7 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeType from core.workflow.entities.node_entities import NodeType
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.workflow_engine_manager import WorkflowEngineManager from core.workflow.workflow_engine_manager import WorkflowEngineManager, node_classes
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
@ -159,8 +159,13 @@ class WorkflowService:
Get default block configs Get default block configs
""" """
# return default block config # return default block config
workflow_engine_manager = WorkflowEngineManager() default_block_configs = []
return workflow_engine_manager.get_default_configs() for node_type, node_class in node_classes.items():
default_config = node_class.get_default_config()
if default_config:
default_block_configs.append(default_config)
return default_block_configs
def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
""" """
@ -169,11 +174,18 @@ class WorkflowService:
:param filters: filter by node config parameters. :param filters: filter by node config parameters.
:return: :return:
""" """
node_type = NodeType.value_of(node_type) node_type_enum: NodeType = NodeType.value_of(node_type)
# return default block config # return default block config
workflow_engine_manager = WorkflowEngineManager() node_class = node_classes.get(node_type_enum)
return workflow_engine_manager.get_default_config(node_type, filters) if not node_class:
return None
default_config = node_class.get_default_config(filters=filters)
if not default_config:
return None
return default_config
def run_draft_workflow_node(self, app_model: App, def run_draft_workflow_node(self, app_model: App,
node_id: str, node_id: str,