chore(iteration): keep start_node_id using in parallel start nodes

This commit is contained in:
takatost 2024-08-27 16:38:10 +08:00
parent 4e3dc36e37
commit 4256e9d47f
2 changed files with 24 additions and 36 deletions

View File

@ -51,40 +51,11 @@ class IterationNode(BaseNode):
}
graph_config = self.graph_config
if not self.node_data.start_node_id:
raise ValueError(f'field start_node_id in iteration {self.node_id} not found')
# find nodes in current iteration and donot have source and have have start_node_in_iteration flag
# these nodes are the start nodes of the iteration (in version of parallel support)
start_node_ids = []
for node_config in graph_config['nodes']:
if (
node_config.get('data', {}).get('iteration_id')
and node_config.get('data', {}).get('iteration_id') == self.node_id
and not node_config.get('source')
and node_config.get('data', {}).get('start_node_in_iteration', False)
):
start_node_ids.append(node_config.get('id'))
if len(start_node_ids) > 0:
# add new fake iteration start node that connect to all start nodes
root_node_id = f"{self.node_id}-start"
graph_config['nodes'].append({
"id": root_node_id,
"data": {
"title": "iteration start",
"type": NodeType.ITERATION_START.value,
}
})
for start_node_id in start_node_ids:
graph_config['edges'].append({
"source": root_node_id,
"target": start_node_id
})
else:
if not self.node_data.start_node_id:
raise ValueError(f'field start_node_id in iteration {self.node_id} not found')
root_node_id = self.node_data.start_node_id
root_node_id = self.node_data.start_node_id
# init graph
iteration_graph = Graph.init(

View File

@ -214,6 +214,16 @@ def test_run_parallel():
"source": "iteration-1",
"target": "answer-3",
},
{
"id": "iteration-start-source-tt-target",
"source": "iteration-start",
"target": "tt",
},
{
"id": "iteration-start-source-tt-2-target",
"source": "iteration-start",
"target": "tt-2",
},
{
"id": "tt-source-if-else-target",
"source": "tt",
@ -250,7 +260,7 @@ def test_run_parallel():
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "tt",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
},
@ -268,7 +278,14 @@ def test_run_parallel():
{
"data": {
"iteration_id": "iteration-1",
"start_node_in_iteration": True,
"title": "iteration-start",
"type": "iteration-start",
},
"id": "iteration-start",
},
{
"data": {
"iteration_id": "iteration-1",
"template": "{{ arg1 }} 123",
"title": "template transform",
"type": "template-transform",
@ -279,7 +296,6 @@ def test_run_parallel():
{
"data": {
"iteration_id": "iteration-1",
"start_node_in_iteration": True,
"template": "{{ arg1 }} 321",
"title": "template transform",
"type": "template-transform",
@ -372,6 +388,7 @@ def test_run_parallel():
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
},