fix: Introduce ArrayVariable and update iteration node to handle it (#12001)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-12-23 15:52:50 +08:00 committed by GitHub
parent 8978a6a3ff
commit 9cfd1c67b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 20 additions and 10 deletions

View File

@ -21,6 +21,7 @@ from .variables import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
ArrayVariable,
FileVariable,
FloatVariable,
IntegerVariable,
@ -43,6 +44,7 @@ __all__ = [
"ArraySegment",
"ArrayStringSegment",
"ArrayStringVariable",
"ArrayVariable",
"FileSegment",
"FileVariable",
"FloatSegment",

View File

@ -10,6 +10,7 @@ from .segments import (
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArraySegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
@ -52,19 +53,23 @@ class ObjectVariable(ObjectSegment, Variable):
pass
class ArrayAnyVariable(ArrayAnySegment, Variable):
class ArrayVariable(ArraySegment, Variable):
pass
class ArrayStringVariable(ArrayStringSegment, Variable):
class ArrayAnyVariable(ArrayAnySegment, ArrayVariable):
pass
class ArrayNumberVariable(ArrayNumberSegment, Variable):
class ArrayStringVariable(ArrayStringSegment, ArrayVariable):
pass
class ArrayObjectVariable(ArrayObjectSegment, Variable):
class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable):
pass
class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable):
pass

View File

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
from flask import Flask, current_app
from configs import dify_config
from core.variables import IntegerVariable
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeRunResult,
@ -75,12 +75,15 @@ class IterationNode(BaseNode[IterationNodeData]):
"""
Run the node.
"""
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not iterator_list_segment:
raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found")
if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
if len(iterator_list_segment.value) == 0:
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -89,7 +92,7 @@ class IterationNode(BaseNode[IterationNodeData]):
)
return
iterator_list_value = iterator_list_segment.to_object()
iterator_list_value = variable.to_object()
if not isinstance(iterator_list_value, list):
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")