refactor(api/core/app/segments): Support more kinds of Segments. (#6706)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-07-26 15:03:56 +08:00 committed by GitHub
parent 6b50bb0fe6
commit c6996a48a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 147 additions and 96 deletions

View File

@ -1,5 +1,14 @@
from .segment_group import SegmentGroup from .segment_group import SegmentGroup
from .segments import NoneSegment, Segment from .segments import (
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from .types import SegmentType from .types import SegmentType
from .variables import ( from .variables import (
ArrayVariable, ArrayVariable,
@ -27,4 +36,10 @@ __all__ = [
'Segment', 'Segment',
'NoneSegment', 'NoneSegment',
'NoneVariable', 'NoneVariable',
'IntegerSegment',
'FloatSegment',
'ObjectSegment',
'ArraySegment',
'FileSegment',
'StringSegment',
] ]

View File

@ -3,15 +3,20 @@ from typing import Any
from core.file.file_obj import FileVar from core.file.file_obj import FileVar
from .segments import Segment, StringSegment from .segments import (
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from .types import SegmentType from .types import SegmentType
from .variables import ( from .variables import (
ArrayVariable,
FileVariable,
FloatVariable, FloatVariable,
IntegerVariable, IntegerVariable,
NoneVariable,
ObjectVariable,
SecretVariable, SecretVariable,
StringVariable, StringVariable,
Variable, Variable,
@ -39,29 +44,23 @@ def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable:
raise ValueError(f'not supported value type {value_type}') raise ValueError(f'not supported value type {value_type}')
def build_anonymous_variable(value: Any, /) -> Variable:
if value is None:
return NoneVariable(name='anonymous')
if isinstance(value, str):
return StringVariable(name='anonymous', value=value)
if isinstance(value, int):
return IntegerVariable(name='anonymous', value=value)
if isinstance(value, float):
return FloatVariable(name='anonymous', value=value)
if isinstance(value, dict):
# TODO: Limit the depth of the object
obj = {k: build_anonymous_variable(v) for k, v in value.items()}
return ObjectVariable(name='anonymous', value=obj)
if isinstance(value, list):
# TODO: Limit the depth of the array
elements = [build_anonymous_variable(v) for v in value]
return ArrayVariable(name='anonymous', value=elements)
if isinstance(value, FileVar):
return FileVariable(name='anonymous', value=value)
raise ValueError(f'not supported value {value}')
def build_segment(value: Any, /) -> Segment: def build_segment(value: Any, /) -> Segment:
if value is None:
return NoneSegment()
if isinstance(value, str): if isinstance(value, str):
return StringSegment(value=value) return StringSegment(value=value)
if isinstance(value, int):
return IntegerSegment(value=value)
if isinstance(value, float):
return FloatSegment(value=value)
if isinstance(value, dict):
# TODO: Limit the depth of the object
obj = {k: build_segment(v) for k, v in value.items()}
return ObjectSegment(value=obj)
if isinstance(value, list):
# TODO: Limit the depth of the array
elements = [build_segment(v) for v in value]
return ArraySegment(value=elements)
if isinstance(value, FileVar):
return FileSegment(value=value)
raise ValueError(f'not supported value {value}') raise ValueError(f'not supported value {value}')

View File

@ -1,8 +1,9 @@
import re import re
from core.app.segments import SegmentGroup, factory
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from . import SegmentGroup, factory
VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}') VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')
@ -14,4 +15,4 @@ def convert_template(*, template: str, variable_pool: VariablePool):
segments.append(value) segments.append(value)
else: else:
segments.append(factory.build_segment(part)) segments.append(factory.build_segment(part))
return SegmentGroup(segments=segments) return SegmentGroup(value=segments)

View File

@ -1,19 +1,22 @@
from pydantic import BaseModel
from .segments import Segment from .segments import Segment
from .types import SegmentType
class SegmentGroup(BaseModel): class SegmentGroup(Segment):
segments: list[Segment] value_type: SegmentType = SegmentType.GROUP
value: list[Segment]
@property @property
def text(self): def text(self):
return ''.join([segment.text for segment in self.segments]) return ''.join([segment.text for segment in self.value])
@property @property
def log(self): def log(self):
return ''.join([segment.log for segment in self.segments]) return ''.join([segment.log for segment in self.value])
@property @property
def markdown(self): def markdown(self):
return ''.join([segment.markdown for segment in self.segments]) return ''.join([segment.markdown for segment in self.value])
def to_object(self):
return [segment.to_object() for segment in self.value]

View File

@ -1,7 +1,11 @@
import json
from collections.abc import Mapping, Sequence
from typing import Any from typing import Any
from pydantic import BaseModel, ConfigDict, field_validator from pydantic import BaseModel, ConfigDict, field_validator
from core.file.file_obj import FileVar
from .types import SegmentType from .types import SegmentType
@ -57,3 +61,57 @@ class NoneSegment(Segment):
class StringSegment(Segment): class StringSegment(Segment):
value_type: SegmentType = SegmentType.STRING value_type: SegmentType = SegmentType.STRING
value: str value: str
class FloatSegment(Segment):
value_type: SegmentType = SegmentType.NUMBER
value: float
class IntegerSegment(Segment):
value_type: SegmentType = SegmentType.NUMBER
value: int
class ObjectSegment(Segment):
value_type: SegmentType = SegmentType.OBJECT
value: Mapping[str, Segment]
@property
def text(self) -> str:
# TODO: Process variables.
return json.dumps(self.model_dump()['value'], ensure_ascii=False)
@property
def log(self) -> str:
# TODO: Process variables.
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
@property
def markdown(self) -> str:
# TODO: Use markdown code block
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
def to_object(self):
return {k: v.to_object() for k, v in self.value.items()}
class ArraySegment(Segment):
value_type: SegmentType = SegmentType.ARRAY
value: Sequence[Segment]
@property
def markdown(self) -> str:
return '\n'.join(['- ' + item.markdown for item in self.value])
def to_object(self):
return [v.to_object() for v in self.value]
class FileSegment(Segment):
value_type: SegmentType = SegmentType.FILE
# TODO: embed FileVar in this model.
value: FileVar
@property
def markdown(self) -> str:
return self.value.to_markdown()

View File

@ -9,3 +9,5 @@ class SegmentType(str, Enum):
ARRAY = 'array' ARRAY = 'array'
OBJECT = 'object' OBJECT = 'object'
FILE = 'file' FILE = 'file'
GROUP = 'group'

View File

@ -1,12 +1,18 @@
import json
from collections.abc import Mapping, Sequence
from pydantic import Field from pydantic import Field
from core.file.file_obj import FileVar
from core.helper import encrypter from core.helper import encrypter
from .segments import NoneSegment, Segment, StringSegment from .segments import (
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from .types import SegmentType from .types import SegmentType
@ -27,59 +33,24 @@ class StringVariable(StringSegment, Variable):
pass pass
class FloatVariable(Variable): class FloatVariable(FloatSegment, Variable):
value_type: SegmentType = SegmentType.NUMBER pass
value: float
class IntegerVariable(Variable): class IntegerVariable(IntegerSegment, Variable):
value_type: SegmentType = SegmentType.NUMBER pass
value: int
class ObjectVariable(Variable): class ObjectVariable(ObjectSegment, Variable):
value_type: SegmentType = SegmentType.OBJECT pass
value: Mapping[str, Variable]
@property
def text(self) -> str:
# TODO: Process variables.
return json.dumps(self.model_dump()['value'], ensure_ascii=False)
@property
def log(self) -> str:
# TODO: Process variables.
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
@property
def markdown(self) -> str:
# TODO: Use markdown code block
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
def to_object(self):
return {k: v.to_object() for k, v in self.value.items()}
class ArrayVariable(Variable): class ArrayVariable(ArraySegment, Variable):
value_type: SegmentType = SegmentType.ARRAY pass
value: Sequence[Variable]
@property
def markdown(self) -> str:
return '\n'.join(['- ' + item.markdown for item in self.value])
def to_object(self):
return [v.to_object() for v in self.value]
class FileVariable(Variable): class FileVariable(FileSegment, Variable):
value_type: SegmentType = SegmentType.FILE pass
# TODO: embed FileVar in this model.
value: FileVar
@property
def markdown(self) -> str:
return self.value.to_markdown()
class SecretVariable(StringVariable): class SecretVariable(StringVariable):

View File

@ -4,7 +4,7 @@ from typing import Any, Union
from typing_extensions import deprecated from typing_extensions import deprecated
from core.app.segments import Variable, factory from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.node_entities import SystemVariable
@ -33,7 +33,7 @@ class VariablePool:
# The first element of the selector is the node id, it's the first-level key in the dictionary. # The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one. # elements of the selector except the first one.
self._variable_dictionary: dict[str, dict[int, Variable]] = defaultdict(dict) self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)
# TODO: This user inputs is not used for pool. # TODO: This user inputs is not used for pool.
self.user_inputs = user_inputs self.user_inputs = user_inputs
@ -67,15 +67,15 @@ class VariablePool:
if value is None: if value is None:
return return
if not isinstance(value, Variable): if isinstance(value, Segment):
v = factory.build_anonymous_variable(value)
else:
v = value v = value
else:
v = factory.build_segment(value)
hash_key = hash(tuple(selector[1:])) hash_key = hash(tuple(selector[1:]))
self._variable_dictionary[selector[0]][hash_key] = v self._variable_dictionary[selector[0]][hash_key] = v
def get(self, selector: Sequence[str], /) -> Variable | None: def get(self, selector: Sequence[str], /) -> Segment | None:
""" """
Retrieves the value from the variable pool based on the given selector. Retrieves the value from the variable pool based on the given selector.

View File

@ -126,6 +126,7 @@ class ToolNode(BaseNode):
else: else:
tool_input = node_data.tool_parameters[parameter_name] tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == 'variable': if tool_input.type == 'variable':
# TODO: check if the variable exists in the variable pool
parameter_value = variable_pool.get(tool_input.value).value parameter_value = variable_pool.get(tool_input.value).value
else: else:
segment_group = parser.convert_template( segment_group = parser.convert_template(

View File

@ -5,7 +5,8 @@ from core.app.segments import (
ArrayVariable, ArrayVariable,
FloatVariable, FloatVariable,
IntegerVariable, IntegerVariable,
NoneVariable, NoneSegment,
ObjectSegment,
ObjectVariable, ObjectVariable,
SecretVariable, SecretVariable,
SegmentType, SegmentType,
@ -139,10 +140,10 @@ def test_variable_to_object():
def test_build_a_object_variable_with_none_value(): def test_build_a_object_variable_with_none_value():
var = factory.build_anonymous_variable( var = factory.build_segment(
{ {
'key1': None, 'key1': None,
} }
) )
assert isinstance(var, ObjectVariable) assert isinstance(var, ObjectSegment)
assert isinstance(var.value['key1'], NoneVariable) assert isinstance(var.value['key1'], NoneSegment)