fix iteration start node

This commit is contained in:
takatost 2024-08-22 23:53:44 +08:00
parent d6da7b0336
commit ec4fc784f0
7 changed files with 306 additions and 4 deletions

View File

@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel):
desc: Optional[str] = None desc: Optional[str] = None
class BaseIterationNodeData(BaseNodeData): class BaseIterationNodeData(BaseNodeData):
start_node_id: str start_node_id: Optional[str] = None
class BaseIterationState(BaseModel): class BaseIterationState(BaseModel):
iteration_node_id: str iteration_node_id: str

View File

@ -28,6 +28,7 @@ class NodeType(Enum):
VARIABLE_ASSIGNER = 'variable-assigner' VARIABLE_ASSIGNER = 'variable-assigner'
LOOP = 'loop' LOOP = 'loop'
ITERATION = 'iteration' ITERATION = 'iteration'
ITERATION_START = 'iteration-start' # fake start node for iteration
PARAMETER_EXTRACTOR = 'parameter-extractor' PARAMETER_EXTRACTOR = 'parameter-extractor'
CONVERSATION_VARIABLE_ASSIGNER = 'assigner' CONVERSATION_VARIABLE_ASSIGNER = 'assigner'

View File

@ -1,6 +1,6 @@
from typing import Any, Optional from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
class IterationNodeData(BaseIterationNodeData): class IterationNodeData(BaseIterationNodeData):
@ -11,6 +11,13 @@ class IterationNodeData(BaseIterationNodeData):
iterator_selector: list[str] # variable selector iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector output_selector: list[str] # output selector
class IterationStartNodeData(BaseNodeData):
    """Node data for the synthetic iteration-start node; adds no fields beyond the base."""
class IterationState(BaseIterationState): class IterationState(BaseIterationState):
""" """
Iteration State. Iteration State.

View File

@ -50,9 +50,39 @@ class IterationNode(BaseNode):
"iterator_selector": iterator_list_value "iterator_selector": iterator_list_value
} }
root_node_id = self.node_data.start_node_id
graph_config = self.graph_config graph_config = self.graph_config
# find nodes that belong to the current iteration, have no incoming source and carry the start_node_in_iteration flag
# these nodes are the start nodes of the iteration (in versions with parallel support)
start_node_ids = []
for node_config in graph_config['nodes']:
if (
node_config.get('data', {}).get('iteration_id')
and node_config.get('data', {}).get('iteration_id') == self.node_id
and not node_config.get('source')
and node_config.get('data', {}).get('start_node_in_iteration', False)
):
start_node_ids.append(node_config.get('id'))
if len(start_node_ids) > 1:
# add a new fake iteration-start node that connects to all start nodes
root_node_id = f"{self.node_id}-start"
graph_config['nodes'].append({
"id": root_node_id,
"data": {
"title": "iteration start",
"type": NodeType.ITERATION_START.value,
}
})
for start_node_id in start_node_ids:
graph_config['edges'].append({
"source": root_node_id,
"target": start_node_id
})
else:
root_node_id = self.node_data.start_node_id
# init graph # init graph
iteration_graph = Graph.init( iteration_graph = Graph.init(
graph_config=graph_config, graph_config=graph_config,
@ -156,6 +186,9 @@ class IterationNode(BaseNode):
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id event.in_iteration_id = self.node_id
if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START:
continue
if isinstance(event, NodeRunSucceededEvent): if isinstance(event, NodeRunSucceededEvent):
if event.route_node_state.node_run_result: if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata metadata = event.route_node_state.node_run_result.metadata
@ -180,7 +213,11 @@ class IterationNode(BaseNode):
variable_pool.remove_node(node_id) variable_pool.remove_node(node_id)
# move to next iteration # move to next iteration
next_index = variable_pool.get_any([self.node_id, 'index']) + 1 current_index = variable_pool.get([self.node_id, 'index'])
if current_index is None:
raise ValueError(f'iteration {self.node_id} current index not found')
next_index = int(current_index.to_object()) + 1
variable_pool.add( variable_pool.add(
[self.node_id, 'index'], [self.node_id, 'index'],
next_index next_index
@ -229,6 +266,7 @@ class IterationNode(BaseNode):
) )
break break
else: else:
event = cast(InNodeEvent, event)
yield event yield event
yield IterationRunSucceededEvent( yield IterationRunSucceededEvent(

View File

@ -0,0 +1,40 @@
import logging
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
from models.workflow import WorkflowNodeExecutionStatus
class IterationStartNode(BaseNode):
    """
    Fake start node injected at the head of an iteration sub-graph.

    It performs no work of its own: it exists only so that an iteration whose
    body has several parallel start branches gets a single entry point.
    """
    _node_data_cls = IterationStartNodeData
    _node_type = NodeType.ITERATION_START

    def _run(self) -> NodeRunResult:
        """
        Run the node: succeed immediately with no inputs, outputs or side effects.
        """
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
        graph_config: Mapping[str, Any],
        node_id: str,
        node_data: IterationStartNodeData  # was IterationNodeData — inconsistent with _node_data_cls
    ) -> Mapping[str, Sequence[str]]:
        """
        Extract variable selector to variable mapping.

        :param graph_config: graph config
        :param node_id: node id
        :param node_data: node data
        :return: empty mapping — the start node reads no variables
        """
        return {}

View File

@ -5,6 +5,7 @@ from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iteration.iteration_node import IterationNode from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.llm_node import LLMNode from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
@ -30,6 +31,7 @@ node_classes = {
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode, NodeType.ITERATION: IterationNode,
NodeType.ITERATION_START: IterationStartNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
} }

View File

@ -210,3 +210,217 @@ def test_run():
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
assert count == 20 assert count == 20
def test_run_parallel():
    """
    Run an iteration node whose sub-graph has TWO start branches (tt and tt-2),
    exercising the fake 'iteration-start' node that IterationNode injects when
    more than one in-iteration node has no incoming edge.
    """
    # Workflow graph: start -> pe -> iteration-1 -> answer-3.
    # Inside iteration-1, both tt and tt-2 feed if-else, which branches to
    # answer-2 (true handle) or answer-4 (false handle).
    graph_config = {
        "edges": [{
            "id": "start-source-pe-target",
            "source": "start",
            "target": "pe",
        }, {
            "id": "iteration-1-source-answer-3-target",
            "source": "iteration-1",
            "target": "answer-3",
        }, {
            "id": "tt-source-if-else-target",
            "source": "tt",
            "target": "if-else",
        }, {
            "id": "tt-2-source-if-else-target",
            "source": "tt-2",
            "target": "if-else",
        }, {
            "id": "if-else-true-answer-2-target",
            "source": "if-else",
            "sourceHandle": "true",
            "target": "answer-2",
        }, {
            "id": "if-else-false-answer-4-target",
            "source": "if-else",
            "sourceHandle": "false",
            "target": "answer-4",
        }, {
            "id": "pe-source-iteration-1-target",
            "source": "pe",
            "target": "iteration-1",
        }],
        "nodes": [{
            "data": {
                "title": "Start",
                "type": "start",
                "variables": []
            },
            "id": "start"
        }, {
            "data": {
                "iterator_selector": ["pe", "list_output"],
                "output_selector": ["tt", "output"],
                "output_type": "array[string]",
                "startNodeType": "template-transform",
                "start_node_id": "tt",
                "title": "iteration",
                "type": "iteration",
            },
            "id": "iteration-1",
        }, {
            "data": {
                "answer": "{{#tt.output#}}",
                "iteration_id": "iteration-1",
                "title": "answer 2",
                "type": "answer"
            },
            "id": "answer-2"
        }, {
            # first start branch of the iteration: no incoming edge and
            # start_node_in_iteration=True
            "data": {
                "iteration_id": "iteration-1",
                "start_node_in_iteration": True,
                "template": "{{ arg1 }} 123",
                "title": "template transform",
                "type": "template-transform",
                "variables": [{
                    "value_selector": ["sys", "query"],
                    "variable": "arg1"
                }]
            },
            "id": "tt",
        }, {
            # second start branch — its presence is what forces the fake
            # iteration-start node to be injected
            "data": {
                "iteration_id": "iteration-1",
                "start_node_in_iteration": True,
                "template": "{{ arg1 }} 321",
                "title": "template transform",
                "type": "template-transform",
                "variables": [{
                    "value_selector": ["sys", "query"],
                    "variable": "arg1"
                }]
            },
            "id": "tt-2",
        }, {
            "data": {
                "answer": "{{#iteration-1.output#}}88888",
                "title": "answer 3",
                "type": "answer"
            },
            "id": "answer-3",
        }, {
            "data": {
                "conditions": [{
                    "comparison_operator": "is",
                    "id": "1721916275284",
                    "value": "hi",
                    "variable_selector": ["sys", "query"]
                }],
                "iteration_id": "iteration-1",
                "logical_operator": "and",
                "title": "if",
                "type": "if-else"
            },
            "id": "if-else",
        }, {
            "data": {
                "answer": "no hi",
                "iteration_id": "iteration-1",
                "title": "answer 4",
                "type": "answer"
            },
            "id": "answer-4",
        }, {
            "data": {
                "instruction": "test1",
                "model": {
                    "completion_params": {
                        "temperature": 0.7
                    },
                    "mode": "chat",
                    "name": "gpt-4o",
                    "provider": "openai"
                },
                "parameters": [{
                    "description": "test",
                    "name": "list_output",
                    "required": False,
                    "type": "array[string]"
                }],
                "query": ["sys", "query"],
                "reasoning_mode": "prompt",
                "title": "pe",
                "type": "parameter-extractor"
            },
            "id": "pe",
        }]
    }

    graph = Graph.init(
        graph_config=graph_config
    )

    init_params = GraphInitParams(
        tenant_id='1',
        app_id='1',
        workflow_type=WorkflowType.CHAT,
        workflow_id='1',
        graph_config=graph_config,
        user_id='1',
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0
    )

    # construct variable pool
    pool = VariablePool(system_variables={
        SystemVariableKey.QUERY: 'dify',
        SystemVariableKey.FILES: [],
        SystemVariableKey.CONVERSATION_ID: 'abababa',
        SystemVariableKey.USER_ID: '1'
    }, user_inputs={}, environment_variables=[])
    # two items in the iterator -> the iteration body runs twice
    pool.add(['pe', 'list_output'], ["dify-1", "dify-2"])

    # NOTE: the node config deliberately omits start_node_id — it must be
    # resolved from the two start_node_in_iteration nodes (tt, tt-2)
    iteration_node = IterationNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(
            variable_pool=pool,
            start_at=time.perf_counter()
        ),
        config={
            "data": {
                "iterator_selector": ["pe", "list_output"],
                "output_selector": ["tt", "output"],
                "output_type": "array[string]",
                "startNodeType": "template-transform",
                "title": "迭代",
                "type": "iteration",
            },
            "id": "iteration-1",
        }
    )

    def tt_generator(self):
        # stub for TemplateTransformNode._run so no real template rendering runs;
        # every template-transform node yields the same fixed output
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs={
                'iterator_selector': 'dify'
            },
            outputs={
                'output': 'dify 123'
            }
        )

    # print("")
    with patch.object(TemplateTransformNode, '_run', new=tt_generator):
        # execute node
        result = iteration_node._run()

        count = 0
        for item in result:
            # print(type(item), item)
            count += 1
            if isinstance(item, RunCompletedEvent):
                assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
                assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}

        # 32 events expected for two iterations over the parallel-branch graph
        # (vs 20 in the single-branch test above) — TODO confirm if event
        # emission in the engine changes
        assert count == 32