diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index de3632894d..a4ee47cc1f 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -103,6 +103,7 @@ class AdvancedChatAppRunner(AppRunner): if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else UserFrom.END_USER, invoke_from=application_generate_entity.invoke_from, + callbacks=workflow_callbacks, user_inputs=inputs, system_inputs={ SystemVariable.QUERY: query, @@ -110,7 +111,6 @@ class AdvancedChatAppRunner(AppRunner): SystemVariable.CONVERSATION_ID: conversation.id, SystemVariable.USER_ID: user_id }, - callbacks=workflow_callbacks, call_depth=application_generate_entity.call_depth ) diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 050319e552..36e2deb42d 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -74,12 +74,12 @@ class WorkflowAppRunner: if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else UserFrom.END_USER, invoke_from=application_generate_entity.invoke_from, + callbacks=workflow_callbacks, user_inputs=inputs, system_inputs={ SystemVariable.FILES: files, SystemVariable.USER_ID: user_id }, - callbacks=workflow_callbacks, call_depth=application_generate_entity.call_depth ) diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py index 9b35b8df8a..4bf4e454bb 100644 --- a/api/core/workflow/entities/workflow_entities.py +++ b/api/core/workflow/entities/workflow_entities.py @@ -66,8 +66,7 @@ class WorkflowRunState: self.variable_pool = variable_pool self.total_tokens = 0 - self.workflow_nodes_and_results = [] - self.current_iteration_state = None self.workflow_node_steps = 1 - self.workflow_node_runs = [] \ No newline at end of file + self.workflow_node_runs = [] + self.current_iteration_state = None diff --git a/api/core/workflow/graph.py b/api/core/workflow/graph.py new file mode 100644 index 0000000000..10931c0be4 --- /dev/null +++ b/api/core/workflow/graph.py @@ -0,0 +1,164 @@ +from collections.abc import Callable +from typing import Optional + +from pydantic import BaseModel + + +class GraphNode(BaseModel): + id: str + """node id""" + + predecessor_node_id: Optional[str] = None + """predecessor node id""" + + children_node_ids: list[str] = [] + """children node ids""" + + source_handle: Optional[str] = None + """current node source handle from the previous node result""" + + is_continue_callback: Optional[Callable] = None + """condition function check if the node can be executed""" + + node_config: dict + """original node config""" + + source_edge_config: Optional[dict] = None + """original source edge config""" + + target_edge_config: Optional[dict] = None + """original target edge config""" + + sub_graph: Optional["Graph"] = None + """sub graph for iteration or loop node""" + + def add_child(self, node_id: str) -> None: + self.children_node_ids.append(node_id) + + +class Graph(BaseModel): + graph_nodes: dict[str, GraphNode] = {} + """graph nodes""" + + root_node: Optional[GraphNode] = None + """root node of the graph""" + + def add_edge(self, edge_config: dict, + source_node_config: dict, + target_node_config: dict, + source_node_sub_graph: Optional["Graph"] = None, + is_continue_callback: Optional[Callable] = None) -> None: + """ + Add edge to the graph + + :param edge_config: edge config + :param source_node_config: source node config + :param target_node_config: target node config + :param source_node_sub_graph: sub graph for iteration or loop node + :param is_continue_callback: condition callback + """ + source_node_id = source_node_config.get('id') + if not source_node_id: + return + + target_node_id = target_node_config.get('id') + if not target_node_id: + return + + if source_node_id not in self.graph_nodes: + source_graph_node = GraphNode( + id=source_node_id, + node_config=source_node_config, + children_node_ids=[target_node_id], + target_edge_config=edge_config, + sub_graph=source_node_sub_graph + ) + + self.add_graph_node(source_graph_node) + else: + source_node = self.graph_nodes[source_node_id] + source_node.add_child(target_node_id) + source_node.target_edge_config = edge_config + source_node.sub_graph = source_node_sub_graph + + source_handle = None + if edge_config.get('sourceHandle'): + source_handle = edge_config.get('sourceHandle') + + if target_node_id not in self.graph_nodes: + target_graph_node = GraphNode( + id=target_node_id, + predecessor_node_id=source_node_id, + node_config=target_node_config, + source_handle=source_handle, + is_continue_callback=is_continue_callback, + source_edge_config=edge_config, + ) + + self.add_graph_node(target_graph_node) + else: + target_node = self.graph_nodes[target_node_id] + target_node.predecessor_node_id = source_node_id + target_node.source_handle = source_handle + target_node.is_continue_callback = is_continue_callback + target_node.source_edge_config = edge_config + + def add_graph_node(self, graph_node: GraphNode) -> None: + """ + Add graph node to the graph + + :param graph_node: graph node + """ + if graph_node.id in self.graph_nodes: + return + + if len(self.graph_nodes) == 0: + self.root_node = graph_node + + self.graph_nodes[graph_node.id] = graph_node + + def get_root_node(self) -> Optional[GraphNode]: + """ + Get root node of the graph + + :return: root node + """ + return self.root_node + + def get_descendants_graph(self, node_id: str) -> Optional["Graph"]: + """ + Get descendants graph of the specific node + + :param node_id: node id + :return: descendants graph + """ + if node_id not in self.graph_nodes: + return None + + graph_node = self.graph_nodes[node_id] + if not graph_node.children_node_ids: + return None + + descendants_graph = Graph() + descendants_graph.add_graph_node(graph_node) + + for child_node_id in graph_node.children_node_ids: + self._add_descendants_graph_nodes(descendants_graph, child_node_id) + + return descendants_graph + + def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None: + """ + Add descendants graph nodes + + :param descendants_graph: descendants graph + :param node_id: node id + """ + if node_id not in self.graph_nodes: + return + + graph_node = self.graph_nodes[node_id] + descendants_graph.add_graph_node(graph_node) + + for child_node_id in graph_node.children_node_ids: + self._add_descendants_graph_nodes(descendants_graph, child_node_id) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 22deafb8a3..4007af85a1 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -1,6 +1,7 @@ import logging +import threading import time -from typing import Optional, cast +from typing import Any, Optional, cast from flask import current_app @@ -9,7 +10,7 @@ from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.entities.app_invoke_entities import InvokeFrom from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState from core.workflow.errors import WorkflowNodeRunFailedError @@ -49,7 +50,7 @@ node_classes = { NodeType.HTTP_REQUEST: HttpRequestNode, NodeType.TOOL: ToolNode, NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, - NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, + NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR NodeType.ITERATION: IterationNode, NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode } @@ -58,67 +59,40 @@ logger = logging.getLogger(__name__) class WorkflowEngineManager: - def get_default_configs(self) -> list[dict]: - """ - Get default block configs - """ - default_block_configs = [] - for node_type, node_class in node_classes.items(): - default_config = node_class.get_default_config() - if default_config: - default_block_configs.append(default_config) - - return default_block_configs - - def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None) -> Optional[dict]: - """ - Get default config of node. - :param node_type: node type - :param filters: filter by node config parameters. - :return: - """ - node_class = node_classes.get(node_type) - if not node_class: - return None - - default_config = node_class.get_default_config(filters=filters) - if not default_config: - return None - - return default_config - def run_workflow(self, workflow: Workflow, user_id: str, user_from: UserFrom, invoke_from: InvokeFrom, + callbacks: list[BaseWorkflowCallback], user_inputs: dict, - system_inputs: Optional[dict] = None, - callbacks: list[BaseWorkflowCallback] = None, - call_depth: Optional[int] = 0, + system_inputs: dict[SystemVariable, Any], + call_depth: int = 0, variable_pool: Optional[VariablePool] = None) -> None: """ :param workflow: Workflow instance :param user_id: user id :param user_from: user from + :param invoke_from: invoke from service-api, web-app, debugger, explore + :param callbacks: workflow callbacks :param user_inputs: user variables inputs :param system_inputs: system inputs, like: query, files - :param callbacks: workflow callbacks :param call_depth: call depth + :param variable_pool: variable pool """ # fetch workflow graph - graph = workflow.graph_dict - if not graph: + graph_dict = workflow.graph_dict + if not graph_dict: raise ValueError('workflow graph not found') - if 'nodes' not in graph or 'edges' not in graph: + if 'nodes' not in graph_dict or 'edges' not in graph_dict: raise ValueError('nodes or edges not found in workflow graph') - if not isinstance(graph.get('nodes'), list): + if not isinstance(graph_dict.get('nodes'), list): raise ValueError('nodes in workflow graph must be a list') - if not isinstance(graph.get('edges'), list): + if not isinstance(graph_dict.get('edges'), list): raise ValueError('edges in workflow graph must be a list') - + # init variable pool if not variable_pool: variable_pool = VariablePool( @@ -126,7 +100,9 @@ class WorkflowEngineManager: user_inputs=user_inputs ) + # fetch max call depth workflow_call_max_depth = current_app.config.get("WORKFLOW_CALL_MAX_DEPTH") + workflow_call_max_depth = cast(int, workflow_call_max_depth) if call_depth > workflow_call_max_depth: raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) @@ -142,55 +118,55 @@ class WorkflowEngineManager: ) # init workflow run - if callbacks: - for callback in callbacks: - callback.on_workflow_run_started() + self._workflow_run_started( + callbacks=callbacks + ) # run workflow self._run_workflow( - workflow=workflow, + graph=graph_dict, workflow_run_state=workflow_run_state, callbacks=callbacks, ) - def _run_workflow(self, workflow: Workflow, - workflow_run_state: WorkflowRunState, - callbacks: list[BaseWorkflowCallback] = None, - start_at: Optional[str] = None, - end_at: Optional[str] = None) -> None: + # workflow run success + self._workflow_run_success( + callbacks=callbacks + ) + + def _run_workflow(self, graph: dict, + workflow_run_state: WorkflowRunState, + callbacks: list[BaseWorkflowCallback], + start_node: Optional[str] = None, + end_node: Optional[str] = None) -> None: """ Run workflow - :param workflow: Workflow instance - :param user_id: user id - :param user_from: user from - :param user_inputs: user variables inputs - :param system_inputs: system inputs, like: query, files + :param graph: workflow graph + :param workflow_run_state: workflow run state :param callbacks: workflow callbacks - :param call_depth: call depth - :param start_at: force specific start node - :param end_at: force specific end node + :param start_node: force specific start node (gte) + :param end_node: force specific end node (le) :return: """ - graph = workflow.graph_dict - try: - predecessor_node: BaseNode = None - current_iteration_node: BaseIterationNode = None - has_entry_node = False + predecessor_node: Optional[BaseNode] = None + current_iteration_node: Optional[BaseIterationNode] = None max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS") + max_execution_steps = cast(int, max_execution_steps) max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME") + max_execution_time = cast(int, max_execution_time) while True: - # get next node, multiple target nodes in the future - next_node = self._get_next_overall_node( + # get next nodes + next_nodes = self._get_next_overall_nodes( workflow_run_state=workflow_run_state, graph=graph, predecessor_node=predecessor_node, callbacks=callbacks, - start_at=start_at, - end_at=end_at + node_start_at=start_node, + node_end_at=end_node ) - if not next_node: + if not next_nodes: # reached loop/iteration end or overall end if current_iteration_node and workflow_run_state.current_iteration_state: # reached loop/iteration end @@ -221,13 +197,13 @@ class WorkflowEngineManager: callbacks=callbacks ) # iteration has ended - next_node = self._get_next_overall_node( + next_nodes = self._get_next_overall_nodes( workflow_run_state=workflow_run_state, graph=graph, predecessor_node=current_iteration_node, callbacks=callbacks, - start_at=start_at, - end_at=end_at + node_start_at=start_node, + node_end_at=end_node ) current_iteration_node = None workflow_run_state.current_iteration_state = None @@ -236,18 +212,11 @@ class WorkflowEngineManager: # move to next iteration next_node_id = next_iteration # get next id - next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks) - - if not next_node: + next_nodes = [self._get_node(workflow_run_state, graph, next_node_id, callbacks)] + + if not next_nodes: break - # check is already ran - if self._check_node_has_ran(workflow_run_state, next_node.node_id): - predecessor_node = next_node - continue - - has_entry_node = True - # max steps reached if workflow_run_state.workflow_node_steps > max_execution_steps: raise ValueError('Max steps {} reached.'.format(max_execution_steps)) @@ -256,62 +225,40 @@ class WorkflowEngineManager: if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time): raise ValueError('Max execution time {}s reached.'.format(max_execution_time)) - # handle iteration nodes - if isinstance(next_node, BaseIterationNode): - current_iteration_node = next_node - workflow_run_state.current_iteration_state = next_node.run( - variable_pool=workflow_run_state.variable_pool - ) - self._workflow_iteration_started( + if len(next_nodes) == 1: + next_node = next_nodes[0] + + # run node + is_continue = self._run_node( graph=graph, - current_iteration_node=current_iteration_node, workflow_run_state=workflow_run_state, - predecessor_node_id=predecessor_node.node_id if predecessor_node else None, + predecessor_node=predecessor_node, + current_node=next_node, callbacks=callbacks ) + + if not is_continue: + break + predecessor_node = next_node - # move to start node of iteration - next_node_id = next_node.get_next_iteration( - variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_next( - graph=graph, - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - if isinstance(next_node_id, NodeRunResult): - # iteration has ended - current_iteration_node.set_output( - variable_pool=workflow_run_state.variable_pool, - state=workflow_run_state.current_iteration_state - ) - self._workflow_iteration_completed( - current_iteration_node=current_iteration_node, - workflow_run_state=workflow_run_state, - callbacks=callbacks - ) - current_iteration_node = None - workflow_run_state.current_iteration_state = None - continue - else: - next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks) + else: + result_dict = {} - # run workflow, run multiple target nodes in the future - self._run_workflow_node( - workflow_run_state=workflow_run_state, - node=next_node, - predecessor_node=predecessor_node, - callbacks=callbacks - ) + # new thread + worker_thread = threading.Thread(target=self._async_run_nodes, kwargs={ + 'flask_app': current_app._get_current_object(), + 'graph': graph, + 'workflow_run_state': workflow_run_state, + 'predecessor_node': predecessor_node, + 'next_nodes': next_nodes, + 'callbacks': callbacks, + 'result': result_dict + }) - if next_node.node_type in [NodeType.END]: - break + worker_thread.start() + worker_thread.join() - predecessor_node = next_node - - if not has_entry_node: + if not workflow_run_state.workflow_node_runs: self._workflow_run_failed( error='Start node not found in workflow graph.', callbacks=callbacks @@ -326,11 +273,109 @@ class WorkflowEngineManager: ) return - # workflow run success - self._workflow_run_success( + def _async_run_nodes(self, flask_app: Flask, + graph: dict, + workflow_run_state: WorkflowRunState, + predecessor_node: Optional[BaseNode], + next_nodes: list[BaseNode], + callbacks: list[BaseWorkflowCallback], + result: dict): + with flask_app.app_context(): + try: + for next_node in next_nodes: + # TODO run sub workflows + # run node + is_continue = self._run_node( + graph=graph, + workflow_run_state=workflow_run_state, + predecessor_node=predecessor_node, + current_node=next_node, + callbacks=callbacks + ) + + if not is_continue: + break + + predecessor_node = next_node + except Exception as e: + logger.exception("Unknown Error when generating") + finally: + db.session.remove() + + def _run_node(self, graph: dict, + workflow_run_state: WorkflowRunState, + predecessor_node: Optional[BaseNode], + current_node: BaseNode, + callbacks: list[BaseWorkflowCallback]) -> bool: + """ + Run node + :param graph: workflow graph + :param workflow_run_state: current workflow run state + :param predecessor_node: predecessor node + :param current_node: current node for run + :param callbacks: workflow callbacks + :return: continue? + """ + # check is already ran + if self._check_node_has_ran(workflow_run_state, current_node.node_id): + return True + + # handle iteration nodes + if isinstance(current_node, BaseIterationNode): + current_iteration_node = current_node + workflow_run_state.current_iteration_state = current_node.run( + variable_pool=workflow_run_state.variable_pool + ) + self._workflow_iteration_started( + graph=graph, + current_iteration_node=current_iteration_node, + workflow_run_state=workflow_run_state, + predecessor_node_id=predecessor_node.node_id if predecessor_node else None, + callbacks=callbacks + ) + predecessor_node = current_node + # move to start node of iteration + current_node_id = current_node.get_next_iteration( + variable_pool=workflow_run_state.variable_pool, + state=workflow_run_state.current_iteration_state + ) + self._workflow_iteration_next( + graph=graph, + current_iteration_node=current_iteration_node, + workflow_run_state=workflow_run_state, + callbacks=callbacks + ) + if isinstance(current_node_id, NodeRunResult): + # iteration has ended + current_iteration_node.set_output( + variable_pool=workflow_run_state.variable_pool, + state=workflow_run_state.current_iteration_state + ) + self._workflow_iteration_completed( + current_iteration_node=current_iteration_node, + workflow_run_state=workflow_run_state, + callbacks=callbacks + ) + current_iteration_node = None + workflow_run_state.current_iteration_state = None + return True + else: + # fetch next node in iteration + current_node = self._get_node(workflow_run_state, graph, current_node_id, callbacks) + + # run workflow, run multiple target nodes in the future + self._run_workflow_node( + workflow_run_state=workflow_run_state, + node=current_node, + predecessor_node=predecessor_node, callbacks=callbacks ) + if current_node.node_type in [NodeType.END]: + return False + + return True + def single_step_run_workflow_node(self, workflow: Workflow, node_id: str, user_id: str, @@ -398,7 +443,7 @@ class WorkflowEngineManager: tenant_id=workflow.tenant_id, node_instance=node_instance ) - + # run node node_run_result = node_instance.run( variable_pool=variable_pool @@ -417,11 +462,11 @@ class WorkflowEngineManager: return node_instance, node_run_result def single_step_run_iteration_workflow_node(self, workflow: Workflow, - node_id: str, - user_id: str, - user_inputs: dict, - callbacks: list[BaseWorkflowCallback] = None, - ) -> None: + node_id: str, + user_id: str, + user_inputs: dict, + callbacks: list[BaseWorkflowCallback] = None, + ) -> None: """ Single iteration run workflow node """ @@ -443,7 +488,7 @@ class WorkflowEngineManager: node_config = node else: raise ValueError('node id is not an iteration node') - + # init variable pool variable_pool = VariablePool( system_variables={}, @@ -452,7 +497,7 @@ class WorkflowEngineManager: # variable selector to variable mapping iteration_nested_nodes = [ - node for node in nodes + node for node in nodes if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id ] iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes] @@ -475,13 +520,13 @@ class WorkflowEngineManager: # remove iteration variables variable_mapping = { - f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items() + f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items() if value[0] != node_id } # remove variable out from iteration variable_mapping = { - key: value for key, value in variable_mapping.items() + key: value for key, value in variable_mapping.items() if value[0] not in iteration_nested_node_ids } @@ -529,13 +574,29 @@ class WorkflowEngineManager: # run workflow self._run_workflow( - workflow=workflow, + graph=workflow.graph, workflow_run_state=workflow_run_state, callbacks=callbacks, - start_at=node_id, - end_at=end_node_id + start_node=node_id, + end_node=end_node_id ) + # workflow run success + self._workflow_run_success( + callbacks=callbacks + ) + + def _workflow_run_started(self, callbacks: list[BaseWorkflowCallback] = None) -> None: + """ + Workflow run started + :param callbacks: workflow callbacks + :return: + """ + # init workflow run + if callbacks: + for callback in callbacks: + callback.on_workflow_run_started() + def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None: """ Workflow run success @@ -561,7 +622,7 @@ class WorkflowEngineManager: error=error ) - def _workflow_iteration_started(self, graph: dict, + def _workflow_iteration_started(self, graph: dict, current_iteration_node: BaseIterationNode, workflow_run_state: WorkflowRunState, predecessor_node_id: Optional[str] = None, @@ -598,9 +659,9 @@ class WorkflowEngineManager: # add steps workflow_run_state.workflow_node_steps += 1 - def _workflow_iteration_next(self, graph: dict, + def _workflow_iteration_next(self, graph: dict, current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, + workflow_run_state: WorkflowRunState, callbacks: list[BaseWorkflowCallback] = None) -> None: """ Workflow iteration next @@ -629,10 +690,10 @@ class WorkflowEngineManager: for node in nodes: workflow_run_state.variable_pool.clear_node_variables(node_id=node.get('id')) - + def _workflow_iteration_completed(self, current_iteration_node: BaseIterationNode, - workflow_run_state: WorkflowRunState, - callbacks: list[BaseWorkflowCallback] = None) -> None: + workflow_run_state: WorkflowRunState, + callbacks: list[BaseWorkflowCallback] = None) -> None: if callbacks: if isinstance(workflow_run_state.current_iteration_state, IterationState): for callback in callbacks: @@ -645,35 +706,39 @@ class WorkflowEngineManager: } ) - def _get_next_overall_node(self, workflow_run_state: WorkflowRunState, - graph: dict, - predecessor_node: Optional[BaseNode] = None, - callbacks: list[BaseWorkflowCallback] = None, - start_at: Optional[str] = None, - end_at: Optional[str] = None) -> Optional[BaseNode]: + def _get_next_overall_nodes(self, workflow_run_state: WorkflowRunState, + graph: dict, + callbacks: list[BaseWorkflowCallback], + predecessor_node: Optional[BaseNode] = None, + node_start_at: Optional[str] = None, + node_end_at: Optional[str] = None) -> list[BaseNode]: """ - Get next node + Get next nodes multiple target nodes in the future. :param graph: workflow graph - :param predecessor_node: predecessor node :param callbacks: workflow callbacks - :return: + :param predecessor_node: predecessor node + :param node_start_at: force specific start node + :param node_end_at: force specific end node + :return: target node list """ nodes = graph.get('nodes') if not nodes: - return None + return [] if not predecessor_node: + # fetch start node for node_config in nodes: node_cls = None - if start_at: - if node_config.get('id') == start_at: + if node_start_at: + if node_config.get('id') == node_start_at: node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type'))) else: if node_config.get('data', {}).get('type', '') == NodeType.START.value: node_cls = StartNode + if node_cls: - return node_cls( + return [node_cls( tenant_id=workflow_run_state.tenant_id, app_id=workflow_run_state.app_id, workflow_id=workflow_run_state.workflow_id, @@ -683,64 +748,73 @@ class WorkflowEngineManager: config=node_config, callbacks=callbacks, workflow_call_depth=workflow_run_state.workflow_call_depth - ) - + )] + + return [] else: edges = graph.get('edges') + edges = cast(list, edges) source_node_id = predecessor_node.node_id # fetch all outgoing edges from source node outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id] if not outgoing_edges: - return None + return [] - # fetch target node id from outgoing edges - outgoing_edge = None + # fetch target node ids from outgoing edges + target_edges = [] source_handle = predecessor_node.node_run_result.edge_source_handle \ if predecessor_node.node_run_result else None if source_handle: for edge in outgoing_edges: if edge.get('sourceHandle') and edge.get('sourceHandle') == source_handle: - outgoing_edge = edge - break + target_edges.append(edge) else: - outgoing_edge = outgoing_edges[0] + target_edges = outgoing_edges - if not outgoing_edge: - return None + if not target_edges: + return [] - target_node_id = outgoing_edge.get('target') + target_nodes = [] + for target_edge in target_edges: + target_node_id = target_edge.get('target') - if end_at and target_node_id == end_at: - return None + if node_end_at and target_node_id == node_end_at: + continue - # fetch target node from target node id - target_node_config = None - for node in nodes: - if node.get('id') == target_node_id: - target_node_config = node - break + # fetch target node from target node id + target_node_config = None + for node in nodes: + if node.get('id') == target_node_id: + target_node_config = node + break - if not target_node_config: - return None + if not target_node_config: + continue - # get next node - target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) + # get next node + target_node_cls = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type'))) + if not target_node_cls: + continue - return target_node( - tenant_id=workflow_run_state.tenant_id, - app_id=workflow_run_state.app_id, - workflow_id=workflow_run_state.workflow_id, - user_id=workflow_run_state.user_id, - user_from=workflow_run_state.user_from, - invoke_from=workflow_run_state.invoke_from, - config=target_node_config, - callbacks=callbacks, - workflow_call_depth=workflow_run_state.workflow_call_depth - ) - - def _get_node(self, workflow_run_state: WorkflowRunState, - graph: dict, + target_node = target_node_cls( + tenant_id=workflow_run_state.tenant_id, + app_id=workflow_run_state.app_id, + workflow_id=workflow_run_state.workflow_id, + user_id=workflow_run_state.user_id, + user_from=workflow_run_state.user_from, + invoke_from=workflow_run_state.invoke_from, + config=target_node_config, + callbacks=callbacks, + workflow_call_depth=workflow_run_state.workflow_call_depth + ) + + target_nodes.append(target_node) + + return target_nodes + + def _get_node(self, workflow_run_state: WorkflowRunState, + graph: dict, node_id: str, callbacks: list[BaseWorkflowCallback]) -> Optional[BaseNode]: """ @@ -807,9 +881,6 @@ class WorkflowEngineManager: result=None ) - # add to workflow_nodes_and_results - workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result) - # add steps workflow_run_state.workflow_node_steps += 1 @@ -940,7 +1011,7 @@ class WorkflowEngineManager: return new_value - def _mapping_user_inputs_to_variable_pool(self, + def _mapping_user_inputs_to_variable_pool(self, variable_mapping: dict, user_inputs: dict, variable_pool: VariablePool, @@ -988,4 +1059,4 @@ class WorkflowEngineManager: node_id=variable_node_id, variable_key_list=variable_key_list, value=value - ) \ No newline at end of file + ) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 6235ecf0a3..ce0dba0885 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -8,7 +8,7 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.node_entities import NodeType from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.workflow_engine_manager import WorkflowEngineManager +from core.workflow.workflow_engine_manager import WorkflowEngineManager, node_classes from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from models.account import Account @@ -159,8 +159,13 @@ class WorkflowService: Get default block configs """ # return default block config - workflow_engine_manager = WorkflowEngineManager() - return workflow_engine_manager.get_default_configs() + default_block_configs = [] + for node_type, node_class in node_classes.items(): + default_config = node_class.get_default_config() + if default_config: + default_block_configs.append(default_config) + + return default_block_configs def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]: """ @@ -169,11 +174,18 @@ class WorkflowService: :param filters: filter by node config parameters. :return: """ - node_type = NodeType.value_of(node_type) + node_type_enum: NodeType = NodeType.value_of(node_type) # return default block config - workflow_engine_manager = WorkflowEngineManager() - return workflow_engine_manager.get_default_config(node_type, filters) + node_class = node_classes.get(node_type_enum) + if not node_class: + return None + + default_config = node_class.get_default_config(filters=filters) + if not default_config: + return None + + return default_config def run_draft_workflow_node(self, app_model: App, node_id: str,