fix iteration start node
This commit is contained in:
parent d6da7b0336 · commit ec4fc784f0
@@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel):
     desc: Optional[str] = None

 class BaseIterationNodeData(BaseNodeData):
-    start_node_id: str
+    start_node_id: Optional[str] = None

 class BaseIterationState(BaseModel):
     iteration_node_id: str
@@ -28,6 +28,7 @@ class NodeType(Enum):
     VARIABLE_ASSIGNER = 'variable-assigner'
     LOOP = 'loop'
     ITERATION = 'iteration'
+    ITERATION_START = 'iteration-start'  # fake start node for iteration
     PARAMETER_EXTRACTOR = 'parameter-extractor'
     CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
@@ -1,6 +1,6 @@
 from typing import Any, Optional

-from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
+from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData


 class IterationNodeData(BaseIterationNodeData):
@@ -11,6 +11,13 @@ class IterationNodeData(BaseIterationNodeData):
     iterator_selector: list[str]  # variable selector
     output_selector: list[str]  # output selector

+class IterationStartNodeData(BaseNodeData):
+    """
+    Iteration Start Node Data.
+    """
+    pass
+
+
 class IterationState(BaseIterationState):
     """
     Iteration State.
@@ -50,9 +50,39 @@ class IterationNode(BaseNode):
             "iterator_selector": iterator_list_value
         }

-        root_node_id = self.node_data.start_node_id
         graph_config = self.graph_config

+        # find nodes in the current iteration that have no incoming source and carry the start_node_in_iteration flag;
+        # these nodes are the start nodes of the iteration (in the version with parallel support)
+        start_node_ids = []
+        for node_config in graph_config['nodes']:
+            if (
+                node_config.get('data', {}).get('iteration_id')
+                and node_config.get('data', {}).get('iteration_id') == self.node_id
+                and not node_config.get('source')
+                and node_config.get('data', {}).get('start_node_in_iteration', False)
+            ):
+                start_node_ids.append(node_config.get('id'))
+
+        if len(start_node_ids) > 1:
+            # add a new fake iteration start node that connects to all start nodes
+            root_node_id = f"{self.node_id}-start"
+            graph_config['nodes'].append({
+                "id": root_node_id,
+                "data": {
+                    "title": "iteration start",
+                    "type": NodeType.ITERATION_START.value,
+                }
+            })
+
+            for start_node_id in start_node_ids:
+                graph_config['edges'].append({
+                    "source": root_node_id,
+                    "target": start_node_id
+                })
+        else:
+            root_node_id = self.node_data.start_node_id
+
         # init graph
         iteration_graph = Graph.init(
             graph_config=graph_config,
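A minimal standalone sketch (plain dicts and illustrative ids, not the dify Graph API) of what the branch above does when an iteration has two flagged start nodes: a synthetic "<iteration-id>-start" node is appended and wired to each of them, so the iteration sub-graph always has a single root.

# Sketch only: mirrors the rewiring above on a toy graph_config with two start nodes.
iteration_id = "iteration-1"
graph_config = {
    "nodes": [
        {"id": "tt", "data": {"iteration_id": iteration_id, "start_node_in_iteration": True}},
        {"id": "tt-2", "data": {"iteration_id": iteration_id, "start_node_in_iteration": True}},
    ],
    "edges": [],
}

start_node_ids = [
    n["id"]
    for n in graph_config["nodes"]
    if n.get("data", {}).get("iteration_id") == iteration_id
    and not n.get("source")
    and n.get("data", {}).get("start_node_in_iteration", False)
]

if len(start_node_ids) > 1:
    # one fake root so the iteration sub-graph has a single entry point
    root_node_id = f"{iteration_id}-start"
    graph_config["nodes"].append({"id": root_node_id, "data": {"title": "iteration start", "type": "iteration-start"}})
    graph_config["edges"].extend({"source": root_node_id, "target": s} for s in start_node_ids)

# graph_config["edges"] is now:
# [{"source": "iteration-1-start", "target": "tt"}, {"source": "iteration-1-start", "target": "tt-2"}]

The later hunk that skips ITERATION_START events keeps this synthetic node out of the events yielded to callers.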
@@ -156,6 +186,9 @@
             if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
                 event.in_iteration_id = self.node_id

+            if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START:
+                continue
+
             if isinstance(event, NodeRunSucceededEvent):
                 if event.route_node_state.node_run_result:
                     metadata = event.route_node_state.node_run_result.metadata
@@ -180,7 +213,11 @@
                 variable_pool.remove_node(node_id)

             # move to next iteration
-            next_index = variable_pool.get_any([self.node_id, 'index']) + 1
+            current_index = variable_pool.get([self.node_id, 'index'])
+            if current_index is None:
+                raise ValueError(f'iteration {self.node_id} current index not found')
+
+            next_index = int(current_index.to_object()) + 1
             variable_pool.add(
                 [self.node_id, 'index'],
                 next_index
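A hedged sketch of why the index bump changed: assuming variable_pool.get returns a wrapped segment (or None) rather than a raw int, the value has to be unwrapped via to_object() and the missing case raised on, instead of the old get_any(...) + 1. IntSegment and FakePool below are stand-ins, not dify classes.

from typing import Optional

class IntSegment:
    """Stand-in for a variable-pool segment wrapping an int."""
    def __init__(self, value: int):
        self._value = value

    def to_object(self) -> int:
        return self._value

class FakePool:
    """Stand-in pool keyed by (node_id, name) tuples."""
    def __init__(self):
        self._data = {("iteration-1", "index"): IntSegment(0)}

    def get(self, selector: list[str]) -> Optional[IntSegment]:
        return self._data.get(tuple(selector))

pool = FakePool()
current_index = pool.get(["iteration-1", "index"])
if current_index is None:
    raise ValueError("iteration iteration-1 current index not found")

next_index = int(current_index.to_object()) + 1  # -> 1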
@@ -229,6 +266,7 @@
                     )
                     break
                 else:
                     event = cast(InNodeEvent, event)
+                    yield event

             yield IterationRunSucceededEvent(
api/core/workflow/nodes/iteration/iteration_start_node.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+import logging
+from collections.abc import Mapping, Sequence
+from typing import Any
+
+from core.workflow.entities.node_entities import NodeRunResult, NodeType
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
+from models.workflow import WorkflowNodeExecutionStatus
+
+
+class IterationStartNode(BaseNode):
+    """
+    Iteration Start Node.
+    """
+    _node_data_cls = IterationStartNodeData
+    _node_type = NodeType.ITERATION_START
+
+    def _run(self) -> NodeRunResult:
+        """
+        Run the node.
+        """
+        return NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED
+        )
+
+    @classmethod
+    def _extract_variable_selector_to_variable_mapping(
+        cls,
+        graph_config: Mapping[str, Any],
+        node_id: str,
+        node_data: IterationNodeData
+    ) -> Mapping[str, Sequence[str]]:
+        """
+        Extract variable selector to variable mapping
+        :param graph_config: graph config
+        :param node_id: node id
+        :param node_data: node data
+        :return:
+        """
+        return {}
@@ -5,6 +5,7 @@ from core.workflow.nodes.end.end_node import EndNode
 from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
 from core.workflow.nodes.if_else.if_else_node import IfElseNode
 from core.workflow.nodes.iteration.iteration_node import IterationNode
+from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode
 from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
 from core.workflow.nodes.llm.llm_node import LLMNode
 from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode

@@ -30,6 +31,7 @@ node_classes = {
     NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
     NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,  # original name of VARIABLE_AGGREGATOR
     NodeType.ITERATION: IterationNode,
+    NodeType.ITERATION_START: IterationStartNode,
     NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
     NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
 }
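For reference, a hedged sketch of how the new mapping entry gets used when the engine meets the synthetic start node. The lookup below is illustrative, not the exact dify call site, and the import path for node_classes is an assumption.

from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes import node_classes  # assumed import path for the mapping above

# the "type" field of the fake node resolves to NodeType.ITERATION_START by enum value
node_config = {"id": "iteration-1-start", "data": {"title": "iteration start", "type": "iteration-start"}}
node_cls = node_classes[NodeType(node_config["data"]["type"])]  # -> IterationStartNode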
@@ -210,3 +210,217 @@ def test_run():
                 assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}

         assert count == 20
+
+
+def test_run_parallel():
+    graph_config = {
+        "edges": [{
+            "id": "start-source-pe-target",
+            "source": "start",
+            "target": "pe",
+        }, {
+            "id": "iteration-1-source-answer-3-target",
+            "source": "iteration-1",
+            "target": "answer-3",
+        }, {
+            "id": "tt-source-if-else-target",
+            "source": "tt",
+            "target": "if-else",
+        }, {
+            "id": "tt-2-source-if-else-target",
+            "source": "tt-2",
+            "target": "if-else",
+        }, {
+            "id": "if-else-true-answer-2-target",
+            "source": "if-else",
+            "sourceHandle": "true",
+            "target": "answer-2",
+        }, {
+            "id": "if-else-false-answer-4-target",
+            "source": "if-else",
+            "sourceHandle": "false",
+            "target": "answer-4",
+        }, {
+            "id": "pe-source-iteration-1-target",
+            "source": "pe",
+            "target": "iteration-1",
+        }],
+        "nodes": [{
+            "data": {
+                "title": "Start",
+                "type": "start",
+                "variables": []
+            },
+            "id": "start"
+        }, {
+            "data": {
+                "iterator_selector": ["pe", "list_output"],
+                "output_selector": ["tt", "output"],
+                "output_type": "array[string]",
+                "startNodeType": "template-transform",
+                "start_node_id": "tt",
+                "title": "iteration",
+                "type": "iteration",
+            },
+            "id": "iteration-1",
+        }, {
+            "data": {
+                "answer": "{{#tt.output#}}",
+                "iteration_id": "iteration-1",
+                "title": "answer 2",
+                "type": "answer"
+            },
+            "id": "answer-2"
+        }, {
+            "data": {
+                "iteration_id": "iteration-1",
+                "start_node_in_iteration": True,
+                "template": "{{ arg1 }} 123",
+                "title": "template transform",
+                "type": "template-transform",
+                "variables": [{
+                    "value_selector": ["sys", "query"],
+                    "variable": "arg1"
+                }]
+            },
+            "id": "tt",
+        }, {
+            "data": {
+                "iteration_id": "iteration-1",
+                "start_node_in_iteration": True,
+                "template": "{{ arg1 }} 321",
+                "title": "template transform",
+                "type": "template-transform",
+                "variables": [{
+                    "value_selector": ["sys", "query"],
+                    "variable": "arg1"
+                }]
+            },
+            "id": "tt-2",
+        }, {
+            "data": {
+                "answer": "{{#iteration-1.output#}}88888",
+                "title": "answer 3",
+                "type": "answer"
+            },
+            "id": "answer-3",
+        }, {
+            "data": {
+                "conditions": [{
+                    "comparison_operator": "is",
+                    "id": "1721916275284",
+                    "value": "hi",
+                    "variable_selector": ["sys", "query"]
+                }],
+                "iteration_id": "iteration-1",
+                "logical_operator": "and",
+                "title": "if",
+                "type": "if-else"
+            },
+            "id": "if-else",
+        }, {
+            "data": {
+                "answer": "no hi",
+                "iteration_id": "iteration-1",
+                "title": "answer 4",
+                "type": "answer"
+            },
+            "id": "answer-4",
+        }, {
+            "data": {
+                "instruction": "test1",
+                "model": {
+                    "completion_params": {
+                        "temperature": 0.7
+                    },
+                    "mode": "chat",
+                    "name": "gpt-4o",
+                    "provider": "openai"
+                },
+                "parameters": [{
+                    "description": "test",
+                    "name": "list_output",
+                    "required": False,
+                    "type": "array[string]"
+                }],
+                "query": ["sys", "query"],
+                "reasoning_mode": "prompt",
+                "title": "pe",
+                "type": "parameter-extractor"
+            },
+            "id": "pe",
+        }]
+    }
+
+    graph = Graph.init(
+        graph_config=graph_config
+    )
+
+    init_params = GraphInitParams(
+        tenant_id='1',
+        app_id='1',
+        workflow_type=WorkflowType.CHAT,
+        workflow_id='1',
+        graph_config=graph_config,
+        user_id='1',
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.DEBUGGER,
+        call_depth=0
+    )
+
+    # construct variable pool
+    pool = VariablePool(system_variables={
+        SystemVariableKey.QUERY: 'dify',
+        SystemVariableKey.FILES: [],
+        SystemVariableKey.CONVERSATION_ID: 'abababa',
+        SystemVariableKey.USER_ID: '1'
+    }, user_inputs={}, environment_variables=[])
+    pool.add(['pe', 'list_output'], ["dify-1", "dify-2"])
+
+    iteration_node = IterationNode(
+        id=str(uuid.uuid4()),
+        graph_init_params=init_params,
+        graph=graph,
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=pool,
+            start_at=time.perf_counter()
+        ),
+        config={
+            "data": {
+                "iterator_selector": ["pe", "list_output"],
+                "output_selector": ["tt", "output"],
+                "output_type": "array[string]",
+                "startNodeType": "template-transform",
+                "title": "iteration",
+                "type": "iteration",
+            },
+            "id": "iteration-1",
+        }
+    )
+
+    def tt_generator(self):
+        return NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            inputs={
+                'iterator_selector': 'dify'
+            },
+            outputs={
+                'output': 'dify 123'
+            }
+        )
+
+    with patch.object(TemplateTransformNode, '_run', new=tt_generator):
+        # execute node
+        result = iteration_node._run()
+
+        count = 0
+        for item in result:
+            count += 1
+            if isinstance(item, RunCompletedEvent):
+                assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+                assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
+
+        assert count == 32