diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py index a6e953829e..b7ca250ff2 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -33,6 +33,15 @@ class Segment(BaseModel): def markdown(self) -> str: return str(self.value) + def to_object(self) -> Any: + if isinstance(self.value, Segment): + return self.value.to_object() + if isinstance(self.value, list): + return [v.to_object() for v in self.value] + if isinstance(self.value, dict): + return {k: v.to_object() for k, v in self.value.items()} + return self.value + class StringSegment(Segment): value_type: SegmentType = SegmentType.STRING diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 23076d5ca4..480f5cce92 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -4,7 +4,7 @@ from typing import Any, Union from typing_extensions import deprecated -from core.app.segments import ArrayVariable, ObjectVariable, Variable, factory +from core.app.segments import Variable, factory from core.file.file_obj import FileVar from core.workflow.entities.node_entities import SystemVariable @@ -113,14 +113,7 @@ class VariablePool: raise ValueError('Invalid selector') hash_key = hash(tuple(selector[1:])) value = self._variable_dictionary[selector[0]].get(hash_key) - - if value is None: - return value - if isinstance(value, ArrayVariable): - return [element.value for element in value.value] - if isinstance(value, ObjectVariable): - return {k: v.value for k, v in value.value.items()} - return value.value if value else None + return value.to_object() if value else None def remove(self, selector: Sequence[str], /): """ diff --git a/api/tests/unit_tests/app/test_variables.py b/api/tests/unit_tests/app/test_variables.py index 65db88a4a8..05b080bdcf 100644 --- a/api/tests/unit_tests/app/test_variables.py +++ b/api/tests/unit_tests/app/test_variables.py @@ -9,6 +9,7 @@ from core.app.segments import ( StringVariable, factory, ) +from core.app.segments.variables import ArrayVariable, ObjectVariable def test_string_variable(): @@ -89,3 +90,47 @@ def test_build_a_blank_string(): ) assert isinstance(result, StringVariable) assert result.value == '' + + +def test_object_variable_to_object(): + var = ObjectVariable( + name='object', + value={ + 'key1': ObjectVariable( + name='object', + value={ + 'key2': StringVariable(name='key2', value='value2'), + }, + ), + 'key2': ArrayVariable( + name='array', + value=[ + StringVariable(name='key5_1', value='value5_1'), + IntegerVariable(name='key5_2', value=42), + ObjectVariable(name='key5_3', value={}), + ], + ), + }, + ) + + assert var.to_object() == { + 'key1': { + 'key2': 'value2', + }, + 'key2': [ + 'value5_1', + 42, + {}, + ], + } + + +def test_variable_to_object(): + var = StringVariable(name='text', value='text') + assert var.to_object() == 'text' + var = IntegerVariable(name='integer', value=42) + assert var.to_object() == 42 + var = FloatVariable(name='float', value=3.14) + assert var.to_object() == 3.14 + var = SecretVariable(name='secret', value='secret_value') + assert var.to_object() == 'secret_value'