From 4256e9d47f088137d795f879a20555035fad6152 Mon Sep 17 00:00:00 2001 From: takatost Date: Tue, 27 Aug 2024 16:38:10 +0800 Subject: [PATCH] chore(iteration): keep start_node_id using in parallel start nodes --- .../nodes/iteration/iteration_node.py | 37 ++----------------- .../nodes/iteration/test_iteration.py | 23 ++++++++++-- 2 files changed, 24 insertions(+), 36 deletions(-) diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 377fd27664..6cac6af338 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -51,40 +51,11 @@ class IterationNode(BaseNode): } graph_config = self.graph_config + + if not self.node_data.start_node_id: + raise ValueError(f'field start_node_id in iteration {self.node_id} not found') - # find nodes in current iteration and donot have source and have have start_node_in_iteration flag - # these nodes are the start nodes of the iteration (in version of parallel support) - start_node_ids = [] - for node_config in graph_config['nodes']: - if ( - node_config.get('data', {}).get('iteration_id') - and node_config.get('data', {}).get('iteration_id') == self.node_id - and not node_config.get('source') - and node_config.get('data', {}).get('start_node_in_iteration', False) - ): - start_node_ids.append(node_config.get('id')) - - if len(start_node_ids) > 0: - # add new fake iteration start node that connect to all start nodes - root_node_id = f"{self.node_id}-start" - graph_config['nodes'].append({ - "id": root_node_id, - "data": { - "title": "iteration start", - "type": NodeType.ITERATION_START.value, - } - }) - - for start_node_id in start_node_ids: - graph_config['edges'].append({ - "source": root_node_id, - "target": start_node_id - }) - else: - if not self.node_data.start_node_id: - raise ValueError(f'field start_node_id in iteration {self.node_id} not found') - - root_node_id = self.node_data.start_node_id + root_node_id = self.node_data.start_node_id # init graph iteration_graph = Graph.init( diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index 89855df6a7..b3a89061b2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -214,6 +214,16 @@ def test_run_parallel(): "source": "iteration-1", "target": "answer-3", }, + { + "id": "iteration-start-source-tt-target", + "source": "iteration-start", + "target": "tt", + }, + { + "id": "iteration-start-source-tt-2-target", + "source": "iteration-start", + "target": "tt-2", + }, { "id": "tt-source-if-else-target", "source": "tt", @@ -250,7 +260,7 @@ def test_run_parallel(): "output_selector": ["tt", "output"], "output_type": "array[string]", "startNodeType": "template-transform", - "start_node_id": "tt", + "start_node_id": "iteration-start", "title": "iteration", "type": "iteration", }, @@ -268,7 +278,14 @@ def test_run_parallel(): { "data": { "iteration_id": "iteration-1", - "start_node_in_iteration": True, + "title": "iteration-start", + "type": "iteration-start", + }, + "id": "iteration-start", + }, + { + "data": { + "iteration_id": "iteration-1", "template": "{{ arg1 }} 123", "title": "template transform", "type": "template-transform", @@ -279,7 +296,6 @@ def test_run_parallel(): { "data": { "iteration_id": "iteration-1", - "start_node_in_iteration": True, "template": "{{ arg1 }} 321", "title": "template transform", "type": "template-transform", @@ -372,6 +388,7 @@ def test_run_parallel(): "output_selector": ["tt", "output"], "output_type": "array[string]", "startNodeType": "template-transform", + "start_node_id": "iteration-start", "title": "迭代", "type": "iteration", },