mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 04:25:54 +08:00
add new graph structure
This commit is contained in:
parent
c5d64baba4
commit
8217c46116
@ -103,6 +103,7 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
else UserFrom.END_USER,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
callbacks=workflow_callbacks,
|
||||
user_inputs=inputs,
|
||||
system_inputs={
|
||||
SystemVariable.QUERY: query,
|
||||
@ -110,7 +111,6 @@ class AdvancedChatAppRunner(AppRunner):
|
||||
SystemVariable.CONVERSATION_ID: conversation.id,
|
||||
SystemVariable.USER_ID: user_id
|
||||
},
|
||||
callbacks=workflow_callbacks,
|
||||
call_depth=application_generate_entity.call_depth
|
||||
)
|
||||
|
||||
|
@ -74,12 +74,12 @@ class WorkflowAppRunner:
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
else UserFrom.END_USER,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
callbacks=workflow_callbacks,
|
||||
user_inputs=inputs,
|
||||
system_inputs={
|
||||
SystemVariable.FILES: files,
|
||||
SystemVariable.USER_ID: user_id
|
||||
},
|
||||
callbacks=workflow_callbacks,
|
||||
call_depth=application_generate_entity.call_depth
|
||||
)
|
||||
|
||||
|
@ -66,8 +66,7 @@ class WorkflowRunState:
|
||||
self.variable_pool = variable_pool
|
||||
|
||||
self.total_tokens = 0
|
||||
self.workflow_nodes_and_results = []
|
||||
|
||||
self.current_iteration_state = None
|
||||
self.workflow_node_steps = 1
|
||||
self.workflow_node_runs = []
|
||||
self.current_iteration_state = None
|
||||
|
164
api/core/workflow/graph.py
Normal file
164
api/core/workflow/graph.py
Normal file
@ -0,0 +1,164 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GraphNode(BaseModel):
|
||||
id: str
|
||||
"""node id"""
|
||||
|
||||
predecessor_node_id: Optional[str] = None
|
||||
"""predecessor node id"""
|
||||
|
||||
children_node_ids: list[str] = []
|
||||
"""children node ids"""
|
||||
|
||||
source_handle: Optional[str] = None
|
||||
"""current node source handle from the previous node result"""
|
||||
|
||||
is_continue_callback: Optional[Callable] = None
|
||||
"""condition function check if the node can be executed"""
|
||||
|
||||
node_config: dict
|
||||
"""original node config"""
|
||||
|
||||
source_edge_config: Optional[dict] = None
|
||||
"""original source edge config"""
|
||||
|
||||
target_edge_config: Optional[dict] = None
|
||||
"""original target edge config"""
|
||||
|
||||
sub_graph: Optional["Graph"] = None
|
||||
"""sub graph for iteration or loop node"""
|
||||
|
||||
def add_child(self, node_id: str) -> None:
|
||||
self.children_node_ids.append(node_id)
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
graph_nodes: dict[str, GraphNode] = {}
|
||||
"""graph nodes"""
|
||||
|
||||
root_node: Optional[GraphNode] = None
|
||||
"""root node of the graph"""
|
||||
|
||||
def add_edge(self, edge_config: dict,
|
||||
source_node_config: dict,
|
||||
target_node_config: dict,
|
||||
source_node_sub_graph: Optional["Graph"] = None,
|
||||
is_continue_callback: Optional[Callable] = None) -> None:
|
||||
"""
|
||||
Add edge to the graph
|
||||
|
||||
:param edge_config: edge config
|
||||
:param source_node_config: source node config
|
||||
:param target_node_config: target node config
|
||||
:param source_node_sub_graph: sub graph for iteration or loop node
|
||||
:param is_continue_callback: condition callback
|
||||
"""
|
||||
source_node_id = source_node_config.get('id')
|
||||
if not source_node_id:
|
||||
return
|
||||
|
||||
target_node_id = target_node_config.get('id')
|
||||
if not target_node_id:
|
||||
return
|
||||
|
||||
if source_node_id not in self.graph_nodes:
|
||||
source_graph_node = GraphNode(
|
||||
id=source_node_id,
|
||||
node_config=source_node_config,
|
||||
children_node_ids=[target_node_id],
|
||||
target_edge_config=edge_config,
|
||||
sub_graph=source_node_sub_graph
|
||||
)
|
||||
|
||||
self.add_graph_node(source_graph_node)
|
||||
else:
|
||||
source_node = self.graph_nodes[source_node_id]
|
||||
source_node.add_child(target_node_id)
|
||||
source_node.target_edge_config = edge_config
|
||||
source_node.sub_graph = source_node_sub_graph
|
||||
|
||||
source_handle = None
|
||||
if edge_config.get('sourceHandle'):
|
||||
source_handle = edge_config.get('sourceHandle')
|
||||
|
||||
if target_node_id not in self.graph_nodes:
|
||||
target_graph_node = GraphNode(
|
||||
id=target_node_id,
|
||||
predecessor_node_id=source_node_id,
|
||||
node_config=target_node_config,
|
||||
source_handle=source_handle,
|
||||
is_continue_callback=is_continue_callback,
|
||||
source_edge_config=edge_config,
|
||||
)
|
||||
|
||||
self.add_graph_node(target_graph_node)
|
||||
else:
|
||||
target_node = self.graph_nodes[target_node_id]
|
||||
target_node.predecessor_node_id = source_node_id
|
||||
target_node.source_handle = source_handle
|
||||
target_node.is_continue_callback = is_continue_callback
|
||||
target_node.source_edge_config = edge_config
|
||||
|
||||
def add_graph_node(self, graph_node: GraphNode) -> None:
|
||||
"""
|
||||
Add graph node to the graph
|
||||
|
||||
:param graph_node: graph node
|
||||
"""
|
||||
if graph_node.id in self.graph_nodes:
|
||||
return
|
||||
|
||||
if len(self.graph_nodes) == 0:
|
||||
self.root_node = graph_node
|
||||
|
||||
self.graph_nodes[graph_node.id] = graph_node
|
||||
|
||||
def get_root_node(self) -> Optional[GraphNode]:
|
||||
"""
|
||||
Get root node of the graph
|
||||
|
||||
:return: root node
|
||||
"""
|
||||
return self.root_node
|
||||
|
||||
def get_descendants_graph(self, node_id: str) -> Optional["Graph"]:
|
||||
"""
|
||||
Get descendants graph of the specific node
|
||||
|
||||
:param node_id: node id
|
||||
:return: descendants graph
|
||||
"""
|
||||
if node_id not in self.graph_nodes:
|
||||
return None
|
||||
|
||||
graph_node = self.graph_nodes[node_id]
|
||||
if not graph_node.children_node_ids:
|
||||
return None
|
||||
|
||||
descendants_graph = Graph()
|
||||
descendants_graph.add_graph_node(graph_node)
|
||||
|
||||
for child_node_id in graph_node.children_node_ids:
|
||||
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
|
||||
|
||||
return descendants_graph
|
||||
|
||||
def _add_descendants_graph_nodes(self, descendants_graph: "Graph", node_id: str) -> None:
|
||||
"""
|
||||
Add descendants graph nodes
|
||||
|
||||
:param descendants_graph: descendants graph
|
||||
:param node_id: node id
|
||||
"""
|
||||
if node_id not in self.graph_nodes:
|
||||
return
|
||||
|
||||
graph_node = self.graph_nodes[node_id]
|
||||
descendants_graph.add_graph_node(graph_node)
|
||||
|
||||
for child_node_id in graph_node.children_node_ids:
|
||||
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
|
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from flask import current_app
|
||||
|
||||
@ -9,7 +10,7 @@ from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
@ -49,7 +50,7 @@ node_classes = {
|
||||
NodeType.HTTP_REQUEST: HttpRequestNode,
|
||||
NodeType.TOOL: ToolNode,
|
||||
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
||||
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,
|
||||
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
|
||||
NodeType.ITERATION: IterationNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
|
||||
}
|
||||
@ -58,65 +59,38 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowEngineManager:
|
||||
def get_default_configs(self) -> list[dict]:
|
||||
"""
|
||||
Get default block configs
|
||||
"""
|
||||
default_block_configs = []
|
||||
for node_type, node_class in node_classes.items():
|
||||
default_config = node_class.get_default_config()
|
||||
if default_config:
|
||||
default_block_configs.append(default_config)
|
||||
|
||||
return default_block_configs
|
||||
|
||||
def get_default_config(self, node_type: NodeType, filters: Optional[dict] = None) -> Optional[dict]:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param node_type: node type
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
node_class = node_classes.get(node_type)
|
||||
if not node_class:
|
||||
return None
|
||||
|
||||
default_config = node_class.get_default_config(filters=filters)
|
||||
if not default_config:
|
||||
return None
|
||||
|
||||
return default_config
|
||||
|
||||
def run_workflow(self, workflow: Workflow,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
callbacks: list[BaseWorkflowCallback],
|
||||
user_inputs: dict,
|
||||
system_inputs: Optional[dict] = None,
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
call_depth: Optional[int] = 0,
|
||||
system_inputs: dict[SystemVariable, Any],
|
||||
call_depth: int = 0,
|
||||
variable_pool: Optional[VariablePool] = None) -> None:
|
||||
"""
|
||||
:param workflow: Workflow instance
|
||||
:param user_id: user id
|
||||
:param user_from: user from
|
||||
:param invoke_from: invoke from service-api, web-app, debugger, explore
|
||||
:param callbacks: workflow callbacks
|
||||
:param user_inputs: user variables inputs
|
||||
:param system_inputs: system inputs, like: query, files
|
||||
:param callbacks: workflow callbacks
|
||||
:param call_depth: call depth
|
||||
:param variable_pool: variable pool
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph = workflow.graph_dict
|
||||
if not graph:
|
||||
graph_dict = workflow.graph_dict
|
||||
if not graph_dict:
|
||||
raise ValueError('workflow graph not found')
|
||||
|
||||
if 'nodes' not in graph or 'edges' not in graph:
|
||||
if 'nodes' not in graph_dict or 'edges' not in graph_dict:
|
||||
raise ValueError('nodes or edges not found in workflow graph')
|
||||
|
||||
if not isinstance(graph.get('nodes'), list):
|
||||
if not isinstance(graph_dict.get('nodes'), list):
|
||||
raise ValueError('nodes in workflow graph must be a list')
|
||||
|
||||
if not isinstance(graph.get('edges'), list):
|
||||
if not isinstance(graph_dict.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
|
||||
# init variable pool
|
||||
@ -126,7 +100,9 @@ class WorkflowEngineManager:
|
||||
user_inputs=user_inputs
|
||||
)
|
||||
|
||||
# fetch max call depth
|
||||
workflow_call_max_depth = current_app.config.get("WORKFLOW_CALL_MAX_DEPTH")
|
||||
workflow_call_max_depth = cast(int, workflow_call_max_depth)
|
||||
if call_depth > workflow_call_max_depth:
|
||||
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
|
||||
|
||||
@ -142,55 +118,55 @@ class WorkflowEngineManager:
|
||||
)
|
||||
|
||||
# init workflow run
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_workflow_run_started()
|
||||
self._workflow_run_started(
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
# run workflow
|
||||
self._run_workflow(
|
||||
workflow=workflow,
|
||||
graph=graph_dict,
|
||||
workflow_run_state=workflow_run_state,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
def _run_workflow(self, workflow: Workflow,
|
||||
# workflow run success
|
||||
self._workflow_run_success(
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
def _run_workflow(self, graph: dict,
|
||||
workflow_run_state: WorkflowRunState,
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
start_at: Optional[str] = None,
|
||||
end_at: Optional[str] = None) -> None:
|
||||
callbacks: list[BaseWorkflowCallback],
|
||||
start_node: Optional[str] = None,
|
||||
end_node: Optional[str] = None) -> None:
|
||||
"""
|
||||
Run workflow
|
||||
:param workflow: Workflow instance
|
||||
:param user_id: user id
|
||||
:param user_from: user from
|
||||
:param user_inputs: user variables inputs
|
||||
:param system_inputs: system inputs, like: query, files
|
||||
:param graph: workflow graph
|
||||
:param workflow_run_state: workflow run state
|
||||
:param callbacks: workflow callbacks
|
||||
:param call_depth: call depth
|
||||
:param start_at: force specific start node
|
||||
:param end_at: force specific end node
|
||||
:param start_node: force specific start node (gte)
|
||||
:param end_node: force specific end node (le)
|
||||
:return:
|
||||
"""
|
||||
graph = workflow.graph_dict
|
||||
|
||||
try:
|
||||
predecessor_node: BaseNode = None
|
||||
current_iteration_node: BaseIterationNode = None
|
||||
has_entry_node = False
|
||||
predecessor_node: Optional[BaseNode] = None
|
||||
current_iteration_node: Optional[BaseIterationNode] = None
|
||||
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
|
||||
max_execution_steps = cast(int, max_execution_steps)
|
||||
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
|
||||
max_execution_time = cast(int, max_execution_time)
|
||||
while True:
|
||||
# get next node, multiple target nodes in the future
|
||||
next_node = self._get_next_overall_node(
|
||||
# get next nodes
|
||||
next_nodes = self._get_next_overall_nodes(
|
||||
workflow_run_state=workflow_run_state,
|
||||
graph=graph,
|
||||
predecessor_node=predecessor_node,
|
||||
callbacks=callbacks,
|
||||
start_at=start_at,
|
||||
end_at=end_at
|
||||
node_start_at=start_node,
|
||||
node_end_at=end_node
|
||||
)
|
||||
|
||||
if not next_node:
|
||||
if not next_nodes:
|
||||
# reached loop/iteration end or overall end
|
||||
if current_iteration_node and workflow_run_state.current_iteration_state:
|
||||
# reached loop/iteration end
|
||||
@ -221,13 +197,13 @@ class WorkflowEngineManager:
|
||||
callbacks=callbacks
|
||||
)
|
||||
# iteration has ended
|
||||
next_node = self._get_next_overall_node(
|
||||
next_nodes = self._get_next_overall_nodes(
|
||||
workflow_run_state=workflow_run_state,
|
||||
graph=graph,
|
||||
predecessor_node=current_iteration_node,
|
||||
callbacks=callbacks,
|
||||
start_at=start_at,
|
||||
end_at=end_at
|
||||
node_start_at=start_node,
|
||||
node_end_at=end_node
|
||||
)
|
||||
current_iteration_node = None
|
||||
workflow_run_state.current_iteration_state = None
|
||||
@ -236,18 +212,11 @@ class WorkflowEngineManager:
|
||||
# move to next iteration
|
||||
next_node_id = next_iteration
|
||||
# get next id
|
||||
next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
|
||||
next_nodes = [self._get_node(workflow_run_state, graph, next_node_id, callbacks)]
|
||||
|
||||
if not next_node:
|
||||
if not next_nodes:
|
||||
break
|
||||
|
||||
# check is already ran
|
||||
if self._check_node_has_ran(workflow_run_state, next_node.node_id):
|
||||
predecessor_node = next_node
|
||||
continue
|
||||
|
||||
has_entry_node = True
|
||||
|
||||
# max steps reached
|
||||
if workflow_run_state.workflow_node_steps > max_execution_steps:
|
||||
raise ValueError('Max steps {} reached.'.format(max_execution_steps))
|
||||
@ -256,62 +225,40 @@ class WorkflowEngineManager:
|
||||
if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time):
|
||||
raise ValueError('Max execution time {}s reached.'.format(max_execution_time))
|
||||
|
||||
# handle iteration nodes
|
||||
if isinstance(next_node, BaseIterationNode):
|
||||
current_iteration_node = next_node
|
||||
workflow_run_state.current_iteration_state = next_node.run(
|
||||
variable_pool=workflow_run_state.variable_pool
|
||||
)
|
||||
self._workflow_iteration_started(
|
||||
graph=graph,
|
||||
current_iteration_node=current_iteration_node,
|
||||
workflow_run_state=workflow_run_state,
|
||||
predecessor_node_id=predecessor_node.node_id if predecessor_node else None,
|
||||
callbacks=callbacks
|
||||
)
|
||||
predecessor_node = next_node
|
||||
# move to start node of iteration
|
||||
next_node_id = next_node.get_next_iteration(
|
||||
variable_pool=workflow_run_state.variable_pool,
|
||||
state=workflow_run_state.current_iteration_state
|
||||
)
|
||||
self._workflow_iteration_next(
|
||||
graph=graph,
|
||||
current_iteration_node=current_iteration_node,
|
||||
workflow_run_state=workflow_run_state,
|
||||
callbacks=callbacks
|
||||
)
|
||||
if isinstance(next_node_id, NodeRunResult):
|
||||
# iteration has ended
|
||||
current_iteration_node.set_output(
|
||||
variable_pool=workflow_run_state.variable_pool,
|
||||
state=workflow_run_state.current_iteration_state
|
||||
)
|
||||
self._workflow_iteration_completed(
|
||||
current_iteration_node=current_iteration_node,
|
||||
workflow_run_state=workflow_run_state,
|
||||
callbacks=callbacks
|
||||
)
|
||||
current_iteration_node = None
|
||||
workflow_run_state.current_iteration_state = None
|
||||
continue
|
||||
else:
|
||||
next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
|
||||
if len(next_nodes) == 1:
|
||||
next_node = next_nodes[0]
|
||||
|
||||
# run workflow, run multiple target nodes in the future
|
||||
self._run_workflow_node(
|
||||
# run node
|
||||
is_continue = self._run_node(
|
||||
graph=graph,
|
||||
workflow_run_state=workflow_run_state,
|
||||
node=next_node,
|
||||
predecessor_node=predecessor_node,
|
||||
current_node=next_node,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
if next_node.node_type in [NodeType.END]:
|
||||
if not is_continue:
|
||||
break
|
||||
|
||||
predecessor_node = next_node
|
||||
else:
|
||||
result_dict = {}
|
||||
|
||||
if not has_entry_node:
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._async_run_nodes, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'graph': graph,
|
||||
'workflow_run_state': workflow_run_state,
|
||||
'predecessor_node': predecessor_node,
|
||||
'next_nodes': next_nodes,
|
||||
'callbacks': callbacks,
|
||||
'result': result_dict
|
||||
})
|
||||
|
||||
worker_thread.start()
|
||||
worker_thread.join()
|
||||
|
||||
if not workflow_run_state.workflow_node_runs:
|
||||
self._workflow_run_failed(
|
||||
error='Start node not found in workflow graph.',
|
||||
callbacks=callbacks
|
||||
@ -326,11 +273,109 @@ class WorkflowEngineManager:
|
||||
)
|
||||
return
|
||||
|
||||
# workflow run success
|
||||
self._workflow_run_success(
|
||||
def _async_run_nodes(self, flask_app: Flask,
|
||||
graph: dict,
|
||||
workflow_run_state: WorkflowRunState,
|
||||
predecessor_node: Optional[BaseNode],
|
||||
next_nodes: list[BaseNode],
|
||||
callbacks: list[BaseWorkflowCallback],
|
||||
result: dict):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
for next_node in next_nodes:
|
||||
# TODO run sub workflows
|
||||
# run node
|
||||
is_continue = self._run_node(
|
||||
graph=graph,
|
||||
workflow_run_state=workflow_run_state,
|
||||
predecessor_node=predecessor_node,
|
||||
current_node=next_node,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
if not is_continue:
|
||||
break
|
||||
|
||||
predecessor_node = next_node
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when generating")
|
||||
finally:
|
||||
db.session.remove()
|
||||
|
||||
def _run_node(self, graph: dict,
|
||||
workflow_run_state: WorkflowRunState,
|
||||
predecessor_node: Optional[BaseNode],
|
||||
current_node: BaseNode,
|
||||
callbacks: list[BaseWorkflowCallback]) -> bool:
|
||||
"""
|
||||
Run node
|
||||
:param graph: workflow graph
|
||||
:param workflow_run_state: current workflow run state
|
||||
:param predecessor_node: predecessor node
|
||||
:param current_node: current node for run
|
||||
:param callbacks: workflow callbacks
|
||||
:return: continue?
|
||||
"""
|
||||
# check is already ran
|
||||
if self._check_node_has_ran(workflow_run_state, current_node.node_id):
|
||||
return True
|
||||
|
||||
# handle iteration nodes
|
||||
if isinstance(current_node, BaseIterationNode):
|
||||
current_iteration_node = current_node
|
||||
workflow_run_state.current_iteration_state = current_node.run(
|
||||
variable_pool=workflow_run_state.variable_pool
|
||||
)
|
||||
self._workflow_iteration_started(
|
||||
graph=graph,
|
||||
current_iteration_node=current_iteration_node,
|
||||
workflow_run_state=workflow_run_state,
|
||||
predecessor_node_id=predecessor_node.node_id if predecessor_node else None,
|
||||
callbacks=callbacks
|
||||
)
|
||||
predecessor_node = current_node
|
||||
# move to start node of iteration
|
||||
current_node_id = current_node.get_next_iteration(
|
||||
variable_pool=workflow_run_state.variable_pool,
|
||||
state=workflow_run_state.current_iteration_state
|
||||
)
|
||||
self._workflow_iteration_next(
|
||||
graph=graph,
|
||||
current_iteration_node=current_iteration_node,
|
||||
workflow_run_state=workflow_run_state,
|
||||
callbacks=callbacks
|
||||
)
|
||||
if isinstance(current_node_id, NodeRunResult):
|
||||
# iteration has ended
|
||||
current_iteration_node.set_output(
|
||||
variable_pool=workflow_run_state.variable_pool,
|
||||
state=workflow_run_state.current_iteration_state
|
||||
)
|
||||
self._workflow_iteration_completed(
|
||||
current_iteration_node=current_iteration_node,
|
||||
workflow_run_state=workflow_run_state,
|
||||
callbacks=callbacks
|
||||
)
|
||||
current_iteration_node = None
|
||||
workflow_run_state.current_iteration_state = None
|
||||
return True
|
||||
else:
|
||||
# fetch next node in iteration
|
||||
current_node = self._get_node(workflow_run_state, graph, current_node_id, callbacks)
|
||||
|
||||
# run workflow, run multiple target nodes in the future
|
||||
self._run_workflow_node(
|
||||
workflow_run_state=workflow_run_state,
|
||||
node=current_node,
|
||||
predecessor_node=predecessor_node,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
if current_node.node_type in [NodeType.END]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def single_step_run_workflow_node(self, workflow: Workflow,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
@ -529,13 +574,29 @@ class WorkflowEngineManager:
|
||||
|
||||
# run workflow
|
||||
self._run_workflow(
|
||||
workflow=workflow,
|
||||
graph=workflow.graph,
|
||||
workflow_run_state=workflow_run_state,
|
||||
callbacks=callbacks,
|
||||
start_at=node_id,
|
||||
end_at=end_node_id
|
||||
start_node=node_id,
|
||||
end_node=end_node_id
|
||||
)
|
||||
|
||||
# workflow run success
|
||||
self._workflow_run_success(
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
def _workflow_run_started(self, callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
"""
|
||||
Workflow run started
|
||||
:param callbacks: workflow callbacks
|
||||
:return:
|
||||
"""
|
||||
# init workflow run
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_workflow_run_started()
|
||||
|
||||
def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
"""
|
||||
Workflow run success
|
||||
@ -645,35 +706,39 @@ class WorkflowEngineManager:
|
||||
}
|
||||
)
|
||||
|
||||
def _get_next_overall_node(self, workflow_run_state: WorkflowRunState,
|
||||
def _get_next_overall_nodes(self, workflow_run_state: WorkflowRunState,
|
||||
graph: dict,
|
||||
callbacks: list[BaseWorkflowCallback],
|
||||
predecessor_node: Optional[BaseNode] = None,
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
start_at: Optional[str] = None,
|
||||
end_at: Optional[str] = None) -> Optional[BaseNode]:
|
||||
node_start_at: Optional[str] = None,
|
||||
node_end_at: Optional[str] = None) -> list[BaseNode]:
|
||||
"""
|
||||
Get next node
|
||||
Get next nodes
|
||||
multiple target nodes in the future.
|
||||
:param graph: workflow graph
|
||||
:param predecessor_node: predecessor node
|
||||
:param callbacks: workflow callbacks
|
||||
:return:
|
||||
:param predecessor_node: predecessor node
|
||||
:param node_start_at: force specific start node
|
||||
:param node_end_at: force specific end node
|
||||
:return: target node list
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
if not nodes:
|
||||
return None
|
||||
return []
|
||||
|
||||
if not predecessor_node:
|
||||
# fetch start node
|
||||
for node_config in nodes:
|
||||
node_cls = None
|
||||
if start_at:
|
||||
if node_config.get('id') == start_at:
|
||||
if node_start_at:
|
||||
if node_config.get('id') == node_start_at:
|
||||
node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type')))
|
||||
else:
|
||||
if node_config.get('data', {}).get('type', '') == NodeType.START.value:
|
||||
node_cls = StartNode
|
||||
|
||||
if node_cls:
|
||||
return node_cls(
|
||||
return [node_cls(
|
||||
tenant_id=workflow_run_state.tenant_id,
|
||||
app_id=workflow_run_state.app_id,
|
||||
workflow_id=workflow_run_state.workflow_id,
|
||||
@ -683,36 +748,39 @@ class WorkflowEngineManager:
|
||||
config=node_config,
|
||||
callbacks=callbacks,
|
||||
workflow_call_depth=workflow_run_state.workflow_call_depth
|
||||
)
|
||||
)]
|
||||
|
||||
return []
|
||||
else:
|
||||
edges = graph.get('edges')
|
||||
edges = cast(list, edges)
|
||||
source_node_id = predecessor_node.node_id
|
||||
|
||||
# fetch all outgoing edges from source node
|
||||
outgoing_edges = [edge for edge in edges if edge.get('source') == source_node_id]
|
||||
if not outgoing_edges:
|
||||
return None
|
||||
return []
|
||||
|
||||
# fetch target node id from outgoing edges
|
||||
outgoing_edge = None
|
||||
# fetch target node ids from outgoing edges
|
||||
target_edges = []
|
||||
source_handle = predecessor_node.node_run_result.edge_source_handle \
|
||||
if predecessor_node.node_run_result else None
|
||||
if source_handle:
|
||||
for edge in outgoing_edges:
|
||||
if edge.get('sourceHandle') and edge.get('sourceHandle') == source_handle:
|
||||
outgoing_edge = edge
|
||||
break
|
||||
target_edges.append(edge)
|
||||
else:
|
||||
outgoing_edge = outgoing_edges[0]
|
||||
target_edges = outgoing_edges
|
||||
|
||||
if not outgoing_edge:
|
||||
return None
|
||||
if not target_edges:
|
||||
return []
|
||||
|
||||
target_node_id = outgoing_edge.get('target')
|
||||
target_nodes = []
|
||||
for target_edge in target_edges:
|
||||
target_node_id = target_edge.get('target')
|
||||
|
||||
if end_at and target_node_id == end_at:
|
||||
return None
|
||||
if node_end_at and target_node_id == node_end_at:
|
||||
continue
|
||||
|
||||
# fetch target node from target node id
|
||||
target_node_config = None
|
||||
@ -722,12 +790,14 @@ class WorkflowEngineManager:
|
||||
break
|
||||
|
||||
if not target_node_config:
|
||||
return None
|
||||
continue
|
||||
|
||||
# get next node
|
||||
target_node = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type')))
|
||||
target_node_cls = node_classes.get(NodeType.value_of(target_node_config.get('data', {}).get('type')))
|
||||
if not target_node_cls:
|
||||
continue
|
||||
|
||||
return target_node(
|
||||
target_node = target_node_cls(
|
||||
tenant_id=workflow_run_state.tenant_id,
|
||||
app_id=workflow_run_state.app_id,
|
||||
workflow_id=workflow_run_state.workflow_id,
|
||||
@ -739,6 +809,10 @@ class WorkflowEngineManager:
|
||||
workflow_call_depth=workflow_run_state.workflow_call_depth
|
||||
)
|
||||
|
||||
target_nodes.append(target_node)
|
||||
|
||||
return target_nodes
|
||||
|
||||
def _get_node(self, workflow_run_state: WorkflowRunState,
|
||||
graph: dict,
|
||||
node_id: str,
|
||||
@ -807,9 +881,6 @@ class WorkflowEngineManager:
|
||||
result=None
|
||||
)
|
||||
|
||||
# add to workflow_nodes_and_results
|
||||
workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result)
|
||||
|
||||
# add steps
|
||||
workflow_run_state.workflow_node_steps += 1
|
||||
|
||||
|
@ -8,7 +8,7 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager, node_classes
|
||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
@ -159,8 +159,13 @@ class WorkflowService:
|
||||
Get default block configs
|
||||
"""
|
||||
# return default block config
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
return workflow_engine_manager.get_default_configs()
|
||||
default_block_configs = []
|
||||
for node_type, node_class in node_classes.items():
|
||||
default_config = node_class.get_default_config()
|
||||
if default_config:
|
||||
default_block_configs.append(default_config)
|
||||
|
||||
return default_block_configs
|
||||
|
||||
def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
|
||||
"""
|
||||
@ -169,11 +174,18 @@ class WorkflowService:
|
||||
:param filters: filter by node config parameters.
|
||||
:return:
|
||||
"""
|
||||
node_type = NodeType.value_of(node_type)
|
||||
node_type_enum: NodeType = NodeType.value_of(node_type)
|
||||
|
||||
# return default block config
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
return workflow_engine_manager.get_default_config(node_type, filters)
|
||||
node_class = node_classes.get(node_type_enum)
|
||||
if not node_class:
|
||||
return None
|
||||
|
||||
default_config = node_class.get_default_config(filters=filters)
|
||||
if not default_config:
|
||||
return None
|
||||
|
||||
return default_config
|
||||
|
||||
def run_draft_workflow_node(self, app_model: App,
|
||||
node_id: str,
|
||||
|
Loading…
x
Reference in New Issue
Block a user