fix: Introduce ArrayVariable and update iteration node to handle it (#12001)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-12-23 15:52:50 +08:00 committed by GitHub
parent 8978a6a3ff
commit 9cfd1c67b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 20 additions and 10 deletions

View File

@ -21,6 +21,7 @@ from .variables import (
ArrayNumberVariable, ArrayNumberVariable,
ArrayObjectVariable, ArrayObjectVariable,
ArrayStringVariable, ArrayStringVariable,
ArrayVariable,
FileVariable, FileVariable,
FloatVariable, FloatVariable,
IntegerVariable, IntegerVariable,
@ -43,6 +44,7 @@ __all__ = [
"ArraySegment", "ArraySegment",
"ArrayStringSegment", "ArrayStringSegment",
"ArrayStringVariable", "ArrayStringVariable",
"ArrayVariable",
"FileSegment", "FileSegment",
"FileVariable", "FileVariable",
"FloatSegment", "FloatSegment",

View File

@ -10,6 +10,7 @@ from .segments import (
ArrayFileSegment, ArrayFileSegment,
ArrayNumberSegment, ArrayNumberSegment,
ArrayObjectSegment, ArrayObjectSegment,
ArraySegment,
ArrayStringSegment, ArrayStringSegment,
FileSegment, FileSegment,
FloatSegment, FloatSegment,
@ -52,19 +53,23 @@ class ObjectVariable(ObjectSegment, Variable):
pass pass
class ArrayAnyVariable(ArrayAnySegment, Variable): class ArrayVariable(ArraySegment, Variable):
pass pass
class ArrayStringVariable(ArrayStringSegment, Variable): class ArrayAnyVariable(ArrayAnySegment, ArrayVariable):
pass pass
class ArrayNumberVariable(ArrayNumberSegment, Variable): class ArrayStringVariable(ArrayStringSegment, ArrayVariable):
pass pass
class ArrayObjectVariable(ArrayObjectSegment, Variable): class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable):
pass
class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable):
pass pass

View File

@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
from flask import Flask, current_app from flask import Flask, current_app
from configs import dify_config from configs import dify_config
from core.variables import IntegerVariable from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.workflow.entities.node_entities import ( from core.workflow.entities.node_entities import (
NodeRunMetadataKey, NodeRunMetadataKey,
NodeRunResult, NodeRunResult,
@ -75,12 +75,15 @@ class IterationNode(BaseNode[IterationNodeData]):
""" """
Run the node. Run the node.
""" """
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not iterator_list_segment: if not variable:
raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found") raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
if len(iterator_list_segment.value) == 0: if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -89,7 +92,7 @@ class IterationNode(BaseNode[IterationNodeData]):
) )
return return
iterator_list_value = iterator_list_segment.to_object() iterator_list_value = variable.to_object()
if not isinstance(iterator_list_value, list): if not isinstance(iterator_list_value, list):
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")