From c6996a48a45a79b3dfa719fa0402de1662a9c0a3 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 26 Jul 2024 15:03:56 +0800 Subject: [PATCH] refactor(api/core/app/segments): Support more kinds of Segments. (#6706) Signed-off-by: -LAN- --- api/core/app/segments/__init__.py | 17 ++++- api/core/app/segments/factory.py | 53 +++++++------- api/core/app/segments/parser.py | 5 +- api/core/app/segments/segment_group.py | 17 +++-- api/core/app/segments/segments.py | 58 ++++++++++++++++ api/core/app/segments/types.py | 2 + api/core/app/segments/variables.py | 69 ++++++------------- api/core/workflow/entities/variable_pool.py | 12 ++-- api/core/workflow/nodes/tool/tool_node.py | 1 + .../unit_tests/{ => core}/app/test_segment.py | 0 .../{ => core}/app/test_variables.py | 9 +-- 11 files changed, 147 insertions(+), 96 deletions(-) rename api/tests/unit_tests/{ => core}/app/test_segment.py (100%) rename api/tests/unit_tests/{ => core}/app/test_variables.py (96%) diff --git a/api/core/app/segments/__init__.py b/api/core/app/segments/__init__.py index 0179d28887..b5d36bff3b 100644 --- a/api/core/app/segments/__init__.py +++ b/api/core/app/segments/__init__.py @@ -1,5 +1,14 @@ from .segment_group import SegmentGroup -from .segments import NoneSegment, Segment +from .segments import ( + ArraySegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + Segment, + StringSegment, +) from .types import SegmentType from .variables import ( ArrayVariable, @@ -27,4 +36,10 @@ __all__ = [ 'Segment', 'NoneSegment', 'NoneVariable', + 'IntegerSegment', + 'FloatSegment', + 'ObjectSegment', + 'ArraySegment', + 'FileSegment', + 'StringSegment', ] diff --git a/api/core/app/segments/factory.py b/api/core/app/segments/factory.py index 187042ec03..8e77b43bb7 100644 --- a/api/core/app/segments/factory.py +++ b/api/core/app/segments/factory.py @@ -3,15 +3,20 @@ from typing import Any from core.file.file_obj import FileVar -from .segments import Segment, StringSegment +from .segments import ( + ArraySegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + Segment, + StringSegment, +) from .types import SegmentType from .variables import ( - ArrayVariable, - FileVariable, FloatVariable, IntegerVariable, - NoneVariable, - ObjectVariable, SecretVariable, StringVariable, Variable, @@ -39,29 +44,23 @@ def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable: raise ValueError(f'not supported value type {value_type}') -def build_anonymous_variable(value: Any, /) -> Variable: - if value is None: - return NoneVariable(name='anonymous') - if isinstance(value, str): - return StringVariable(name='anonymous', value=value) - if isinstance(value, int): - return IntegerVariable(name='anonymous', value=value) - if isinstance(value, float): - return FloatVariable(name='anonymous', value=value) - if isinstance(value, dict): - # TODO: Limit the depth of the object - obj = {k: build_anonymous_variable(v) for k, v in value.items()} - return ObjectVariable(name='anonymous', value=obj) - if isinstance(value, list): - # TODO: Limit the depth of the array - elements = [build_anonymous_variable(v) for v in value] - return ArrayVariable(name='anonymous', value=elements) - if isinstance(value, FileVar): - return FileVariable(name='anonymous', value=value) - raise ValueError(f'not supported value {value}') - - def build_segment(value: Any, /) -> Segment: + if value is None: + return NoneSegment() if isinstance(value, str): return StringSegment(value=value) + if isinstance(value, int): + return IntegerSegment(value=value) + if isinstance(value, float): + return FloatSegment(value=value) + if isinstance(value, dict): + # TODO: Limit the depth of the object + obj = {k: build_segment(v) for k, v in value.items()} + return ObjectSegment(value=obj) + if isinstance(value, list): + # TODO: Limit the depth of the array + elements = [build_segment(v) for v in value] + return ArraySegment(value=elements) + if isinstance(value, FileVar): + return FileSegment(value=value) raise ValueError(f'not supported value {value}') diff --git a/api/core/app/segments/parser.py b/api/core/app/segments/parser.py index 21d1b89541..7061a4d878 100644 --- a/api/core/app/segments/parser.py +++ b/api/core/app/segments/parser.py @@ -1,8 +1,9 @@ import re -from core.app.segments import SegmentGroup, factory from core.workflow.entities.variable_pool import VariablePool +from . import SegmentGroup, factory + VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}') @@ -14,4 +15,4 @@ def convert_template(*, template: str, variable_pool: VariablePool): segments.append(value) else: segments.append(factory.build_segment(part)) - return SegmentGroup(segments=segments) + return SegmentGroup(value=segments) diff --git a/api/core/app/segments/segment_group.py b/api/core/app/segments/segment_group.py index 0d5176b885..b4ff09b6d3 100644 --- a/api/core/app/segments/segment_group.py +++ b/api/core/app/segments/segment_group.py @@ -1,19 +1,22 @@ -from pydantic import BaseModel - from .segments import Segment +from .types import SegmentType -class SegmentGroup(BaseModel): - segments: list[Segment] +class SegmentGroup(Segment): + value_type: SegmentType = SegmentType.GROUP + value: list[Segment] @property def text(self): - return ''.join([segment.text for segment in self.segments]) + return ''.join([segment.text for segment in self.value]) @property def log(self): - return ''.join([segment.log for segment in self.segments]) + return ''.join([segment.log for segment in self.value]) @property def markdown(self): - return ''.join([segment.markdown for segment in self.segments]) \ No newline at end of file + return ''.join([segment.markdown for segment in self.value]) + + def to_object(self): + return [segment.to_object() for segment in self.value] diff --git a/api/core/app/segments/segments.py b/api/core/app/segments/segments.py index afd383880f..f317054bc7 100644 --- a/api/core/app/segments/segments.py +++ b/api/core/app/segments/segments.py @@ -1,7 +1,11 @@ +import json +from collections.abc import Mapping, Sequence from typing import Any from pydantic import BaseModel, ConfigDict, field_validator +from core.file.file_obj import FileVar + from .types import SegmentType @@ -57,3 +61,57 @@ class NoneSegment(Segment): class StringSegment(Segment): value_type: SegmentType = SegmentType.STRING value: str + +class FloatSegment(Segment): + value_type: SegmentType = SegmentType.NUMBER + value: float + + +class IntegerSegment(Segment): + value_type: SegmentType = SegmentType.NUMBER + value: int + + +class ObjectSegment(Segment): + value_type: SegmentType = SegmentType.OBJECT + value: Mapping[str, Segment] + + @property + def text(self) -> str: + # TODO: Process variables. + return json.dumps(self.model_dump()['value'], ensure_ascii=False) + + @property + def log(self) -> str: + # TODO: Process variables. + return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) + + @property + def markdown(self) -> str: + # TODO: Use markdown code block + return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) + + def to_object(self): + return {k: v.to_object() for k, v in self.value.items()} + + +class ArraySegment(Segment): + value_type: SegmentType = SegmentType.ARRAY + value: Sequence[Segment] + + @property + def markdown(self) -> str: + return '\n'.join(['- ' + item.markdown for item in self.value]) + + def to_object(self): + return [v.to_object() for v in self.value] + + +class FileSegment(Segment): + value_type: SegmentType = SegmentType.FILE + # TODO: embed FileVar in this model. + value: FileVar + + @property + def markdown(self) -> str: + return self.value.to_markdown() diff --git a/api/core/app/segments/types.py b/api/core/app/segments/types.py index ebcbf507c6..133755bbc6 100644 --- a/api/core/app/segments/types.py +++ b/api/core/app/segments/types.py @@ -9,3 +9,5 @@ class SegmentType(str, Enum): ARRAY = 'array' OBJECT = 'object' FILE = 'file' + + GROUP = 'group' diff --git a/api/core/app/segments/variables.py b/api/core/app/segments/variables.py index 5edaccc4d6..ba55022726 100644 --- a/api/core/app/segments/variables.py +++ b/api/core/app/segments/variables.py @@ -1,12 +1,18 @@ -import json -from collections.abc import Mapping, Sequence from pydantic import Field -from core.file.file_obj import FileVar from core.helper import encrypter -from .segments import NoneSegment, Segment, StringSegment +from .segments import ( + ArraySegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + Segment, + StringSegment, +) from .types import SegmentType @@ -27,59 +33,24 @@ class StringVariable(StringSegment, Variable): pass -class FloatVariable(Variable): - value_type: SegmentType = SegmentType.NUMBER - value: float +class FloatVariable(FloatSegment, Variable): + pass -class IntegerVariable(Variable): - value_type: SegmentType = SegmentType.NUMBER - value: int +class IntegerVariable(IntegerSegment, Variable): + pass -class ObjectVariable(Variable): - value_type: SegmentType = SegmentType.OBJECT - value: Mapping[str, Variable] - - @property - def text(self) -> str: - # TODO: Process variables. - return json.dumps(self.model_dump()['value'], ensure_ascii=False) - - @property - def log(self) -> str: - # TODO: Process variables. - return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) - - @property - def markdown(self) -> str: - # TODO: Use markdown code block - return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2) - - def to_object(self): - return {k: v.to_object() for k, v in self.value.items()} +class ObjectVariable(ObjectSegment, Variable): + pass -class ArrayVariable(Variable): - value_type: SegmentType = SegmentType.ARRAY - value: Sequence[Variable] - - @property - def markdown(self) -> str: - return '\n'.join(['- ' + item.markdown for item in self.value]) - - def to_object(self): - return [v.to_object() for v in self.value] +class ArrayVariable(ArraySegment, Variable): + pass -class FileVariable(Variable): - value_type: SegmentType = SegmentType.FILE - # TODO: embed FileVar in this model. - value: FileVar - - @property - def markdown(self) -> str: - return self.value.to_markdown() +class FileVariable(FileSegment, Variable): + pass class SecretVariable(StringVariable): diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 480f5cce92..a27b4261e4 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -4,7 +4,7 @@ from typing import Any, Union from typing_extensions import deprecated -from core.app.segments import Variable, factory +from core.app.segments import Segment, Variable, factory from core.file.file_obj import FileVar from core.workflow.entities.node_entities import SystemVariable @@ -33,7 +33,7 @@ class VariablePool: # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. - self._variable_dictionary: dict[str, dict[int, Variable]] = defaultdict(dict) + self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict) # TODO: This user inputs is not used for pool. self.user_inputs = user_inputs @@ -67,15 +67,15 @@ class VariablePool: if value is None: return - if not isinstance(value, Variable): - v = factory.build_anonymous_variable(value) - else: + if isinstance(value, Segment): v = value + else: + v = factory.build_segment(value) hash_key = hash(tuple(selector[1:])) self._variable_dictionary[selector[0]][hash_key] = v - def get(self, selector: Sequence[str], /) -> Variable | None: + def get(self, selector: Sequence[str], /) -> Segment | None: """ Retrieves the value from the variable pool based on the given selector. diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 238477117d..c03a17468a 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -126,6 +126,7 @@ class ToolNode(BaseNode): else: tool_input = node_data.tool_parameters[parameter_name] if tool_input.type == 'variable': + # TODO: check if the variable exists in the variable pool parameter_value = variable_pool.get(tool_input.value).value else: segment_group = parser.convert_template( diff --git a/api/tests/unit_tests/app/test_segment.py b/api/tests/unit_tests/core/app/test_segment.py similarity index 100% rename from api/tests/unit_tests/app/test_segment.py rename to api/tests/unit_tests/core/app/test_segment.py diff --git a/api/tests/unit_tests/app/test_variables.py b/api/tests/unit_tests/core/app/test_variables.py similarity index 96% rename from api/tests/unit_tests/app/test_variables.py rename to api/tests/unit_tests/core/app/test_variables.py index 40872c8d53..afed29e3cb 100644 --- a/api/tests/unit_tests/app/test_variables.py +++ b/api/tests/unit_tests/core/app/test_variables.py @@ -5,7 +5,8 @@ from core.app.segments import ( ArrayVariable, FloatVariable, IntegerVariable, - NoneVariable, + NoneSegment, + ObjectSegment, ObjectVariable, SecretVariable, SegmentType, @@ -139,10 +140,10 @@ def test_variable_to_object(): def test_build_a_object_variable_with_none_value(): - var = factory.build_anonymous_variable( + var = factory.build_segment( { 'key1': None, } ) - assert isinstance(var, ObjectVariable) - assert isinstance(var.value['key1'], NoneVariable) + assert isinstance(var, ObjectSegment) + assert isinstance(var.value['key1'], NoneSegment)