mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-16 17:25:56 +08:00
fix iteration start node
This commit is contained in:
parent
d6da7b0336
commit
ec4fc784f0
@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel):
|
|||||||
desc: Optional[str] = None
|
desc: Optional[str] = None
|
||||||
|
|
||||||
class BaseIterationNodeData(BaseNodeData):
|
class BaseIterationNodeData(BaseNodeData):
|
||||||
start_node_id: str
|
start_node_id: Optional[str] = None
|
||||||
|
|
||||||
class BaseIterationState(BaseModel):
|
class BaseIterationState(BaseModel):
|
||||||
iteration_node_id: str
|
iteration_node_id: str
|
||||||
|
@ -28,6 +28,7 @@ class NodeType(Enum):
|
|||||||
VARIABLE_ASSIGNER = 'variable-assigner'
|
VARIABLE_ASSIGNER = 'variable-assigner'
|
||||||
LOOP = 'loop'
|
LOOP = 'loop'
|
||||||
ITERATION = 'iteration'
|
ITERATION = 'iteration'
|
||||||
|
ITERATION_START = 'iteration-start' # fake start node for iteration
|
||||||
PARAMETER_EXTRACTOR = 'parameter-extractor'
|
PARAMETER_EXTRACTOR = 'parameter-extractor'
|
||||||
CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
|
CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
|
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||||
|
|
||||||
|
|
||||||
class IterationNodeData(BaseIterationNodeData):
|
class IterationNodeData(BaseIterationNodeData):
|
||||||
@ -11,6 +11,13 @@ class IterationNodeData(BaseIterationNodeData):
|
|||||||
iterator_selector: list[str] # variable selector
|
iterator_selector: list[str] # variable selector
|
||||||
output_selector: list[str] # output selector
|
output_selector: list[str] # output selector
|
||||||
|
|
||||||
|
|
||||||
|
class IterationStartNodeData(BaseNodeData):
|
||||||
|
"""
|
||||||
|
Iteration Start Node Data.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
class IterationState(BaseIterationState):
|
class IterationState(BaseIterationState):
|
||||||
"""
|
"""
|
||||||
Iteration State.
|
Iteration State.
|
||||||
|
@ -50,9 +50,39 @@ class IterationNode(BaseNode):
|
|||||||
"iterator_selector": iterator_list_value
|
"iterator_selector": iterator_list_value
|
||||||
}
|
}
|
||||||
|
|
||||||
root_node_id = self.node_data.start_node_id
|
|
||||||
graph_config = self.graph_config
|
graph_config = self.graph_config
|
||||||
|
|
||||||
|
# find nodes in current iteration and donot have source and have have start_node_in_iteration flag
|
||||||
|
# these nodes are the start nodes of the iteration (in version of parallel support)
|
||||||
|
start_node_ids = []
|
||||||
|
for node_config in graph_config['nodes']:
|
||||||
|
if (
|
||||||
|
node_config.get('data', {}).get('iteration_id')
|
||||||
|
and node_config.get('data', {}).get('iteration_id') == self.node_id
|
||||||
|
and not node_config.get('source')
|
||||||
|
and node_config.get('data', {}).get('start_node_in_iteration', False)
|
||||||
|
):
|
||||||
|
start_node_ids.append(node_config.get('id'))
|
||||||
|
|
||||||
|
if len(start_node_ids) > 1:
|
||||||
|
# add new fake iteration start node that connect to all start nodes
|
||||||
|
root_node_id = f"{self.node_id}-start"
|
||||||
|
graph_config['nodes'].append({
|
||||||
|
"id": root_node_id,
|
||||||
|
"data": {
|
||||||
|
"title": "iteration start",
|
||||||
|
"type": NodeType.ITERATION_START.value,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
for start_node_id in start_node_ids:
|
||||||
|
graph_config['edges'].append({
|
||||||
|
"source": root_node_id,
|
||||||
|
"target": start_node_id
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
root_node_id = self.node_data.start_node_id
|
||||||
|
|
||||||
# init graph
|
# init graph
|
||||||
iteration_graph = Graph.init(
|
iteration_graph = Graph.init(
|
||||||
graph_config=graph_config,
|
graph_config=graph_config,
|
||||||
@ -156,6 +186,9 @@ class IterationNode(BaseNode):
|
|||||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
||||||
event.in_iteration_id = self.node_id
|
event.in_iteration_id = self.node_id
|
||||||
|
|
||||||
|
if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START:
|
||||||
|
continue
|
||||||
|
|
||||||
if isinstance(event, NodeRunSucceededEvent):
|
if isinstance(event, NodeRunSucceededEvent):
|
||||||
if event.route_node_state.node_run_result:
|
if event.route_node_state.node_run_result:
|
||||||
metadata = event.route_node_state.node_run_result.metadata
|
metadata = event.route_node_state.node_run_result.metadata
|
||||||
@ -180,7 +213,11 @@ class IterationNode(BaseNode):
|
|||||||
variable_pool.remove_node(node_id)
|
variable_pool.remove_node(node_id)
|
||||||
|
|
||||||
# move to next iteration
|
# move to next iteration
|
||||||
next_index = variable_pool.get_any([self.node_id, 'index']) + 1
|
current_index = variable_pool.get([self.node_id, 'index'])
|
||||||
|
if current_index is None:
|
||||||
|
raise ValueError(f'iteration {self.node_id} current index not found')
|
||||||
|
|
||||||
|
next_index = int(current_index.to_object()) + 1
|
||||||
variable_pool.add(
|
variable_pool.add(
|
||||||
[self.node_id, 'index'],
|
[self.node_id, 'index'],
|
||||||
next_index
|
next_index
|
||||||
@ -229,6 +266,7 @@ class IterationNode(BaseNode):
|
|||||||
)
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
event = cast(InNodeEvent, event)
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
yield IterationRunSucceededEvent(
|
yield IterationRunSucceededEvent(
|
||||||
|
40
api/core/workflow/nodes/iteration/iteration_start_node.py
Normal file
40
api/core/workflow/nodes/iteration/iteration_start_node.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
import logging
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
|
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
|
||||||
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
|
class IterationStartNode(BaseNode):
|
||||||
|
"""
|
||||||
|
Iteration Start Node.
|
||||||
|
"""
|
||||||
|
_node_data_cls = IterationStartNodeData
|
||||||
|
_node_type = NodeType.ITERATION_START
|
||||||
|
|
||||||
|
def _run(self) -> NodeRunResult:
|
||||||
|
"""
|
||||||
|
Run the node.
|
||||||
|
"""
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
|
cls,
|
||||||
|
graph_config: Mapping[str, Any],
|
||||||
|
node_id: str,
|
||||||
|
node_data: IterationNodeData
|
||||||
|
) -> Mapping[str, Sequence[str]]:
|
||||||
|
"""
|
||||||
|
Extract variable selector to variable mapping
|
||||||
|
:param graph_config: graph config
|
||||||
|
:param node_id: node id
|
||||||
|
:param node_data: node data
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return {}
|
@ -5,6 +5,7 @@ from core.workflow.nodes.end.end_node import EndNode
|
|||||||
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
||||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||||
|
from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode
|
||||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||||
@ -30,6 +31,7 @@ node_classes = {
|
|||||||
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
||||||
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
|
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
|
||||||
NodeType.ITERATION: IterationNode,
|
NodeType.ITERATION: IterationNode,
|
||||||
|
NodeType.ITERATION_START: IterationStartNode,
|
||||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
||||||
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
|
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
|
||||||
}
|
}
|
||||||
|
@ -210,3 +210,217 @@ def test_run():
|
|||||||
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
|
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
|
||||||
|
|
||||||
assert count == 20
|
assert count == 20
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_parallel():
|
||||||
|
graph_config = {
|
||||||
|
"edges": [{
|
||||||
|
"id": "start-source-pe-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "pe",
|
||||||
|
}, {
|
||||||
|
"id": "iteration-1-source-answer-3-target",
|
||||||
|
"source": "iteration-1",
|
||||||
|
"target": "answer-3",
|
||||||
|
}, {
|
||||||
|
"id": "tt-source-if-else-target",
|
||||||
|
"source": "tt",
|
||||||
|
"target": "if-else",
|
||||||
|
}, {
|
||||||
|
"id": "tt-2-source-if-else-target",
|
||||||
|
"source": "tt-2",
|
||||||
|
"target": "if-else",
|
||||||
|
}, {
|
||||||
|
"id": "if-else-true-answer-2-target",
|
||||||
|
"source": "if-else",
|
||||||
|
"sourceHandle": "true",
|
||||||
|
"target": "answer-2",
|
||||||
|
}, {
|
||||||
|
"id": "if-else-false-answer-4-target",
|
||||||
|
"source": "if-else",
|
||||||
|
"sourceHandle": "false",
|
||||||
|
"target": "answer-4",
|
||||||
|
}, {
|
||||||
|
"id": "pe-source-iteration-1-target",
|
||||||
|
"source": "pe",
|
||||||
|
"target": "iteration-1",
|
||||||
|
}],
|
||||||
|
"nodes": [{
|
||||||
|
"data": {
|
||||||
|
"title": "Start",
|
||||||
|
"type": "start",
|
||||||
|
"variables": []
|
||||||
|
},
|
||||||
|
"id": "start"
|
||||||
|
}, {
|
||||||
|
"data": {
|
||||||
|
"iterator_selector": ["pe", "list_output"],
|
||||||
|
"output_selector": ["tt", "output"],
|
||||||
|
"output_type": "array[string]",
|
||||||
|
"startNodeType": "template-transform",
|
||||||
|
"start_node_id": "tt",
|
||||||
|
"title": "iteration",
|
||||||
|
"type": "iteration",
|
||||||
|
},
|
||||||
|
"id": "iteration-1",
|
||||||
|
}, {
|
||||||
|
"data": {
|
||||||
|
"answer": "{{#tt.output#}}",
|
||||||
|
"iteration_id": "iteration-1",
|
||||||
|
"title": "answer 2",
|
||||||
|
"type": "answer"
|
||||||
|
},
|
||||||
|
"id": "answer-2"
|
||||||
|
}, {
|
||||||
|
"data": {
|
||||||
|
"iteration_id": "iteration-1",
|
||||||
|
"start_node_in_iteration": True,
|
||||||
|
"template": "{{ arg1 }} 123",
|
||||||
|
"title": "template transform",
|
||||||
|
"type": "template-transform",
|
||||||
|
"variables": [{
|
||||||
|
"value_selector": ["sys", "query"],
|
||||||
|
"variable": "arg1"
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
"id": "tt",
|
||||||
|
},{
|
||||||
|
"data": {
|
||||||
|
"iteration_id": "iteration-1",
|
||||||
|
"start_node_in_iteration": True,
|
||||||
|
"template": "{{ arg1 }} 321",
|
||||||
|
"title": "template transform",
|
||||||
|
"type": "template-transform",
|
||||||
|
"variables": [{
|
||||||
|
"value_selector": ["sys", "query"],
|
||||||
|
"variable": "arg1"
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
"id": "tt-2",
|
||||||
|
}, {
|
||||||
|
"data": {
|
||||||
|
"answer": "{{#iteration-1.output#}}88888",
|
||||||
|
"title": "answer 3",
|
||||||
|
"type": "answer"
|
||||||
|
},
|
||||||
|
"id": "answer-3",
|
||||||
|
}, {
|
||||||
|
"data": {
|
||||||
|
"conditions": [{
|
||||||
|
"comparison_operator": "is",
|
||||||
|
"id": "1721916275284",
|
||||||
|
"value": "hi",
|
||||||
|
"variable_selector": ["sys", "query"]
|
||||||
|
}],
|
||||||
|
"iteration_id": "iteration-1",
|
||||||
|
"logical_operator": "and",
|
||||||
|
"title": "if",
|
||||||
|
"type": "if-else"
|
||||||
|
},
|
||||||
|
"id": "if-else",
|
||||||
|
}, {
|
||||||
|
"data": {
|
||||||
|
"answer": "no hi",
|
||||||
|
"iteration_id": "iteration-1",
|
||||||
|
"title": "answer 4",
|
||||||
|
"type": "answer"
|
||||||
|
},
|
||||||
|
"id": "answer-4",
|
||||||
|
}, {
|
||||||
|
"data": {
|
||||||
|
"instruction": "test1",
|
||||||
|
"model": {
|
||||||
|
"completion_params": {
|
||||||
|
"temperature": 0.7
|
||||||
|
},
|
||||||
|
"mode": "chat",
|
||||||
|
"name": "gpt-4o",
|
||||||
|
"provider": "openai"
|
||||||
|
},
|
||||||
|
"parameters": [{
|
||||||
|
"description": "test",
|
||||||
|
"name": "list_output",
|
||||||
|
"required": False,
|
||||||
|
"type": "array[string]"
|
||||||
|
}],
|
||||||
|
"query": ["sys", "query"],
|
||||||
|
"reasoning_mode": "prompt",
|
||||||
|
"title": "pe",
|
||||||
|
"type": "parameter-extractor"
|
||||||
|
},
|
||||||
|
"id": "pe",
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(
|
||||||
|
graph_config=graph_config
|
||||||
|
)
|
||||||
|
|
||||||
|
init_params = GraphInitParams(
|
||||||
|
tenant_id='1',
|
||||||
|
app_id='1',
|
||||||
|
workflow_type=WorkflowType.CHAT,
|
||||||
|
workflow_id='1',
|
||||||
|
graph_config=graph_config,
|
||||||
|
user_id='1',
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
call_depth=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# construct variable pool
|
||||||
|
pool = VariablePool(system_variables={
|
||||||
|
SystemVariableKey.QUERY: 'dify',
|
||||||
|
SystemVariableKey.FILES: [],
|
||||||
|
SystemVariableKey.CONVERSATION_ID: 'abababa',
|
||||||
|
SystemVariableKey.USER_ID: '1'
|
||||||
|
}, user_inputs={}, environment_variables=[])
|
||||||
|
pool.add(['pe', 'list_output'], ["dify-1", "dify-2"])
|
||||||
|
|
||||||
|
iteration_node = IterationNode(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
graph_init_params=init_params,
|
||||||
|
graph=graph,
|
||||||
|
graph_runtime_state=GraphRuntimeState(
|
||||||
|
variable_pool=pool,
|
||||||
|
start_at=time.perf_counter()
|
||||||
|
),
|
||||||
|
config={
|
||||||
|
"data": {
|
||||||
|
"iterator_selector": ["pe", "list_output"],
|
||||||
|
"output_selector": ["tt", "output"],
|
||||||
|
"output_type": "array[string]",
|
||||||
|
"startNodeType": "template-transform",
|
||||||
|
"title": "迭代",
|
||||||
|
"type": "iteration",
|
||||||
|
},
|
||||||
|
"id": "iteration-1",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def tt_generator(self):
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
inputs={
|
||||||
|
'iterator_selector': 'dify'
|
||||||
|
},
|
||||||
|
outputs={
|
||||||
|
'output': 'dify 123'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# print("")
|
||||||
|
|
||||||
|
with patch.object(TemplateTransformNode, '_run', new=tt_generator):
|
||||||
|
# execute node
|
||||||
|
result = iteration_node._run()
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for item in result:
|
||||||
|
# print(type(item), item)
|
||||||
|
count += 1
|
||||||
|
if isinstance(item, RunCompletedEvent):
|
||||||
|
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||||
|
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
|
||||||
|
|
||||||
|
assert count == 32
|
||||||
|
Loading…
x
Reference in New Issue
Block a user