diff --git a/api/core/workflow/entities/base_node_data_entities.py b/api/core/workflow/entities/base_node_data_entities.py
index 6bf0c11c7d..e7e6710cbd 100644
--- a/api/core/workflow/entities/base_node_data_entities.py
+++ b/api/core/workflow/entities/base_node_data_entities.py
@@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel):
     desc: Optional[str] = None
 
 class BaseIterationNodeData(BaseNodeData):
-    start_node_id: str
+    start_node_id: Optional[str] = None
 
 class BaseIterationState(BaseModel):
     iteration_node_id: str
diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py
index f4e90864fc..5e2a5cb466 100644
--- a/api/core/workflow/entities/node_entities.py
+++ b/api/core/workflow/entities/node_entities.py
@@ -28,6 +28,7 @@ class NodeType(Enum):
     VARIABLE_ASSIGNER = 'variable-assigner'
     LOOP = 'loop'
     ITERATION = 'iteration'
+    ITERATION_START = 'iteration-start'  # fake start node for iteration
     PARAMETER_EXTRACTOR = 'parameter-extractor'
     CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
 
diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py
index 177b47b951..5fc5a827ae 100644
--- a/api/core/workflow/nodes/iteration/entities.py
+++ b/api/core/workflow/nodes/iteration/entities.py
@@ -1,6 +1,6 @@
 from typing import Any, Optional
 
-from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
+from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
 
 
 class IterationNodeData(BaseIterationNodeData):
@@ -11,6 +11,13 @@ class IterationNodeData(BaseIterationNodeData):
     iterator_selector: list[str]  # variable selector
     output_selector: list[str]  # output selector
 
+
+class IterationStartNodeData(BaseNodeData):
+    """
+    Iteration Start Node Data.
+    """
+    pass
+
 class IterationState(BaseIterationState):
     """
     Iteration State.
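The entity changes above are small but load-bearing: start_node_id becomes optional because an iteration may no longer have a single configured entry node, and IterationStartNodeData is an empty marker model for the synthetic start node. A minimal, self-contained sketch of the pattern (simplified stand-ins for the Dify base classes, not the real modules):

```python
from abc import ABC
from typing import Optional

from pydantic import BaseModel


class BaseNodeData(ABC, BaseModel):
    # simplified stand-in for core.workflow.entities.base_node_data_entities.BaseNodeData
    title: str
    desc: Optional[str] = None


class BaseIterationNodeData(BaseNodeData):
    # optional now: the entry node may be a synthetic one injected at runtime
    start_node_id: Optional[str] = None


class IterationStartNodeData(BaseNodeData):
    # marker data class for the fake iteration start node; it carries no extra fields
    pass


# the kind of config the iteration node appends for the fake start node validates cleanly
print(IterationStartNodeData(title="iteration start"))
```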
diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py
index e12ae1c71a..a988bcf06d 100644
--- a/api/core/workflow/nodes/iteration/iteration_node.py
+++ b/api/core/workflow/nodes/iteration/iteration_node.py
@@ -50,9 +50,39 @@ class IterationNode(BaseNode):
             "iterator_selector": iterator_list_value
         }
 
-        root_node_id = self.node_data.start_node_id
         graph_config = self.graph_config
 
+        # find nodes in the current iteration that have no incoming source and carry the start_node_in_iteration flag
+        # these nodes are the start nodes of the iteration (for the parallel-support version)
+        start_node_ids = []
+        for node_config in graph_config['nodes']:
+            if (
+                node_config.get('data', {}).get('iteration_id')
+                and node_config.get('data', {}).get('iteration_id') == self.node_id
+                and not node_config.get('source')
+                and node_config.get('data', {}).get('start_node_in_iteration', False)
+            ):
+                start_node_ids.append(node_config.get('id'))
+
+        if len(start_node_ids) > 1:
+            # add a fake iteration start node that connects to all start nodes
+            root_node_id = f"{self.node_id}-start"
+            graph_config['nodes'].append({
+                "id": root_node_id,
+                "data": {
+                    "title": "iteration start",
+                    "type": NodeType.ITERATION_START.value,
+                }
+            })
+
+            for start_node_id in start_node_ids:
+                graph_config['edges'].append({
+                    "source": root_node_id,
+                    "target": start_node_id
+                })
+        else:
+            root_node_id = self.node_data.start_node_id
+
         # init graph
         iteration_graph = Graph.init(
             graph_config=graph_config,
@@ -156,6 +186,9 @@ class IterationNode(BaseNode):
                     if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
                         event.in_iteration_id = self.node_id
 
+                    if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START:
+                        continue
+
                     if isinstance(event, NodeRunSucceededEvent):
                         if event.route_node_state.node_run_result:
                             metadata = event.route_node_state.node_run_result.metadata
@@ -180,7 +213,11 @@ class IterationNode(BaseNode):
                         variable_pool.remove_node(node_id)
 
                     # move to next iteration
-                    next_index = variable_pool.get_any([self.node_id, 'index']) + 1
+                    current_index = variable_pool.get([self.node_id, 'index'])
+                    if current_index is None:
+                        raise ValueError(f'iteration {self.node_id} current index not found')
+
+                    next_index = int(current_index.to_object()) + 1
                     variable_pool.add(
                         [self.node_id, 'index'],
                         next_index
@@ -229,6 +266,7 @@ class IterationNode(BaseNode):
                             )
                             break
                     else:
+                        event = cast(InNodeEvent, event)
                         yield event
 
             yield IterationRunSucceededEvent(
diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py
new file mode 100644
index 0000000000..2aee241bee
--- /dev/null
+++ b/api/core/workflow/nodes/iteration/iteration_start_node.py
@@ -0,0 +1,40 @@
+import logging
+from collections.abc import Mapping, Sequence
+from typing import Any
+
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
+from models.workflow import WorkflowNodeExecutionStatus
+
+
+class IterationStartNode(BaseNode):
+    """
+    Iteration Start Node.
+    """
+    _node_data_cls = IterationStartNodeData
+    _node_type = NodeType.ITERATION_START
+
+    def _run(self) -> NodeRunResult:
+        """
+        Run the node.
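The core of the runtime change is the graph rewrite at the top of IterationNode._run: when more than one node inside the iteration carries the start_node_in_iteration flag and has no incoming source, a fake iteration-start node is appended to the graph config and wired to each of them, and that fake node becomes the iteration's root; otherwise the configured start_node_id is used as before. A standalone sketch of that rewrite under those assumptions (inject_fake_start_node is a hypothetical helper name; the real logic lives inline in _run):

```python
def inject_fake_start_node(graph_config: dict, iteration_node_id: str) -> str:
    """Return the root node id for an iteration, adding a synthetic start node
    and fan-out edges when the iteration has more than one start node."""
    start_node_ids = [
        node["id"]
        for node in graph_config["nodes"]
        if node.get("data", {}).get("iteration_id") == iteration_node_id
        and not node.get("source")
        and node.get("data", {}).get("start_node_in_iteration", False)
    ]

    if len(start_node_ids) <= 1:
        # the real code falls back to the iteration node's configured start_node_id here
        return start_node_ids[0] if start_node_ids else ""

    root_node_id = f"{iteration_node_id}-start"
    graph_config["nodes"].append({
        "id": root_node_id,
        "data": {"title": "iteration start", "type": "iteration-start"},
    })
    graph_config["edges"].extend(
        {"source": root_node_id, "target": node_id} for node_id in start_node_ids
    )
    return root_node_id


config = {
    "nodes": [
        {"id": "it-1", "data": {"type": "iteration"}},
        {"id": "a", "data": {"iteration_id": "it-1", "start_node_in_iteration": True}},
        {"id": "b", "data": {"iteration_id": "it-1", "start_node_in_iteration": True}},
    ],
    "edges": [],
}
print(inject_fake_start_node(config, "it-1"))  # it-1-start, plus two new edges in config
```

Events emitted by the fake node are filtered out in the event loop (the ITERATION_START check above), so it never surfaces in iteration results.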
+ """ + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping( + cls, + graph_config: Mapping[str, Any], + node_id: str, + node_data: IterationNodeData + ) -> Mapping[str, Sequence[str]]: + """ + Extract variable selector to variable mapping + :param graph_config: graph config + :param node_id: node id + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index 87e4d4be18..b98525e86e 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -5,6 +5,7 @@ from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from core.workflow.nodes.if_else.if_else_node import IfElseNode from core.workflow.nodes.iteration.iteration_node import IterationNode +from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.llm.llm_node import LLMNode from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode @@ -30,6 +31,7 @@ node_classes = { NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR NodeType.ITERATION: IterationNode, + NodeType.ITERATION_START: IterationStartNode, NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, } diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index ff46e62d1f..344553f344 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -210,3 +210,217 @@ def test_run(): assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} assert count == 20 + + +def test_run_parallel(): + graph_config = { + "edges": [{ + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, { + "id": "tt-source-if-else-target", + "source": "tt", + "target": "if-else", + }, { + "id": "tt-2-source-if-else-target", + "source": "tt-2", + "target": "if-else", + }, { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "answer-2", + }, { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "answer-4", + }, { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }], + "nodes": [{ + "data": { + "title": "Start", + "type": "start", + "variables": [] + }, + "id": "start" + }, { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "tt", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, { + "data": { + "answer": "{{#tt.output#}}", + "iteration_id": "iteration-1", + "title": "answer 2", + "type": "answer" + }, + "id": "answer-2" + }, { + "data": { + "iteration_id": "iteration-1", + "start_node_in_iteration": True, + "template": "{{ arg1 }} 123", + 
"title": "template transform", + "type": "template-transform", + "variables": [{ + "value_selector": ["sys", "query"], + "variable": "arg1" + }] + }, + "id": "tt", + },{ + "data": { + "iteration_id": "iteration-1", + "start_node_in_iteration": True, + "template": "{{ arg1 }} 321", + "title": "template transform", + "type": "template-transform", + "variables": [{ + "value_selector": ["sys", "query"], + "variable": "arg1" + }] + }, + "id": "tt-2", + }, { + "data": { + "answer": "{{#iteration-1.output#}}88888", + "title": "answer 3", + "type": "answer" + }, + "id": "answer-3", + }, { + "data": { + "conditions": [{ + "comparison_operator": "is", + "id": "1721916275284", + "value": "hi", + "variable_selector": ["sys", "query"] + }], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else" + }, + "id": "if-else", + }, { + "data": { + "answer": "no hi", + "iteration_id": "iteration-1", + "title": "answer 4", + "type": "answer" + }, + "id": "answer-4", + }, { + "data": { + "instruction": "test1", + "model": { + "completion_params": { + "temperature": 0.7 + }, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai" + }, + "parameters": [{ + "description": "test", + "name": "list_output", + "required": False, + "type": "array[string]" + }], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor" + }, + "id": "pe", + }] + } + + graph = Graph.init( + graph_config=graph_config + ) + + init_params = GraphInitParams( + tenant_id='1', + app_id='1', + workflow_type=WorkflowType.CHAT, + workflow_id='1', + graph_config=graph_config, + user_id='1', + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0 + ) + + # construct variable pool + pool = VariablePool(system_variables={ + SystemVariableKey.QUERY: 'dify', + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: 'abababa', + SystemVariableKey.USER_ID: '1' + }, user_inputs={}, environment_variables=[]) + pool.add(['pe', 'list_output'], ["dify-1", "dify-2"]) + + iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState( + variable_pool=pool, + start_at=time.perf_counter() + ), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "title": "迭代", + "type": "iteration", + }, + "id": "iteration-1", + } + ) + + def tt_generator(self): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={ + 'iterator_selector': 'dify' + }, + outputs={ + 'output': 'dify 123' + } + ) + + # print("") + + with patch.object(TemplateTransformNode, '_run', new=tt_generator): + # execute node + result = iteration_node._run() + + count = 0 + for item in result: + # print(type(item), item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + + assert count == 32