refactor(api/core/app/segments): Support more kinds of Segments. (#6706)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-07-26 15:03:56 +08:00 committed by GitHub
parent 6b50bb0fe6
commit c6996a48a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 147 additions and 96 deletions

View File

@ -1,5 +1,14 @@
from .segment_group import SegmentGroup
from .segments import NoneSegment, Segment
from .segments import (
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from .types import SegmentType
from .variables import (
ArrayVariable,
@ -27,4 +36,10 @@ __all__ = [
'Segment',
'NoneSegment',
'NoneVariable',
'IntegerSegment',
'FloatSegment',
'ObjectSegment',
'ArraySegment',
'FileSegment',
'StringSegment',
]

View File

@ -3,15 +3,20 @@ from typing import Any
from core.file.file_obj import FileVar
from .segments import Segment, StringSegment
from .segments import (
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from .types import SegmentType
from .variables import (
ArrayVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
@ -39,29 +44,23 @@ def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable:
raise ValueError(f'not supported value type {value_type}')
def build_anonymous_variable(value: Any, /) -> Variable:
    """Wrap a plain Python value in the matching Variable named 'anonymous'.

    Dicts and lists are converted recursively, so every nested value becomes
    an anonymous variable as well.

    :param value: ``None``, ``str``, ``int``, ``float``, ``dict``, ``list`` or ``FileVar``.
    :return: the Variable subclass that corresponds to the value's type.
    :raises ValueError: if the value is of none of the types listed above.
    """
    if value is None:
        return NoneVariable(name='anonymous')

    # Scalar kinds, tested in this exact order.
    # NOTE: bool is a subclass of int, so True/False take the IntegerVariable entry.
    scalar_kinds = (
        (str, StringVariable),
        (int, IntegerVariable),
        (float, FloatVariable),
    )
    for python_type, variable_cls in scalar_kinds:
        if isinstance(value, python_type):
            return variable_cls(name='anonymous', value=value)

    if isinstance(value, dict):
        # TODO: Limit the depth of the object
        members = {key: build_anonymous_variable(item) for key, item in value.items()}
        return ObjectVariable(name='anonymous', value=members)

    if isinstance(value, list):
        # TODO: Limit the depth of the array
        items = [build_anonymous_variable(item) for item in value]
        return ArrayVariable(name='anonymous', value=items)

    if isinstance(value, FileVar):
        return FileVariable(name='anonymous', value=value)

    raise ValueError(f'not supported value {value}')
def build_segment(value: Any, /) -> Segment:
    """Wrap a plain Python value in the matching Segment.

    Dicts and lists are converted recursively, so every nested value becomes
    a segment as well.

    :param value: ``None``, ``str``, ``int``, ``float``, ``dict``, ``list`` or ``FileVar``.
    :return: the Segment subclass that corresponds to the value's type.
    :raises ValueError: if the value is of none of the types listed above.
    """
    if value is None:
        return NoneSegment()

    # Scalar kinds, tested in this exact order.
    # NOTE: bool is a subclass of int, so True/False take the IntegerSegment entry.
    scalar_kinds = (
        (str, StringSegment),
        (int, IntegerSegment),
        (float, FloatSegment),
    )
    for python_type, segment_cls in scalar_kinds:
        if isinstance(value, python_type):
            return segment_cls(value=value)

    if isinstance(value, dict):
        # TODO: Limit the depth of the object
        members = {key: build_segment(item) for key, item in value.items()}
        return ObjectSegment(value=members)

    if isinstance(value, list):
        # TODO: Limit the depth of the array
        items = [build_segment(item) for item in value]
        return ArraySegment(value=items)

    if isinstance(value, FileVar):
        return FileSegment(value=value)

    raise ValueError(f'not supported value {value}')

View File

@ -1,8 +1,9 @@
import re
from core.app.segments import SegmentGroup, factory
from core.workflow.entities.variable_pool import VariablePool
from . import SegmentGroup, factory
# Matches a template placeholder of the form `{{#<head>.<key>[.<key>...]#}}`.
# Group 1 captures the dotted selector: a head of 1-50 characters from
# [a-zA-Z0-9_], followed by 1 to 10 `.key` parts, each starting with a letter or
# underscore and at most 30 characters long.
VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')
@ -14,4 +15,4 @@ def convert_template(*, template: str, variable_pool: VariablePool):
segments.append(value)
else:
segments.append(factory.build_segment(part))
return SegmentGroup(segments=segments)
return SegmentGroup(value=segments)

View File

@ -1,19 +1,22 @@
from pydantic import BaseModel
from .segments import Segment
from .types import SegmentType
# Container that renders a list of segments back-to-back.
# NOTE(review): this view lists both sides of the change without +/- markers. The
# `BaseModel` / `self.segments` lines look like the removed side and the
# `Segment` / `self.value` lines like the added side - confirm against the raw diff.
class SegmentGroup(BaseModel):
segments: list[Segment]
class SegmentGroup(Segment):
value_type: SegmentType = SegmentType.GROUP
# Member segments; the renderers below join them in list order.
value: list[Segment]
# Concatenation of every member's `text`, with no separator.
@property
def text(self):
return ''.join([segment.text for segment in self.segments])
return ''.join([segment.text for segment in self.value])
# Concatenation of every member's `log`, with no separator.
@property
def log(self):
return ''.join([segment.log for segment in self.segments])
return ''.join([segment.log for segment in self.value])
# Concatenation of every member's `markdown`, with no separator.
@property
def markdown(self):
return ''.join([segment.markdown for segment in self.segments])
return ''.join([segment.markdown for segment in self.value])
# List holding each member's `to_object()` result, in member order.
def to_object(self):
return [segment.to_object() for segment in self.value]

View File

@ -1,7 +1,11 @@
import json
from collections.abc import Mapping, Sequence
from typing import Any
from pydantic import BaseModel, ConfigDict, field_validator
from core.file.file_obj import FileVar
from .types import SegmentType
@ -57,3 +61,57 @@ class NoneSegment(Segment):
class StringSegment(Segment):
    """Segment whose payload is a string; tagged as ``SegmentType.STRING``."""

    value_type: SegmentType = SegmentType.STRING
    value: str
class FloatSegment(Segment):
    """Segment whose payload is a float.

    Shares the ``SegmentType.NUMBER`` tag with ``IntegerSegment``.
    """

    value_type: SegmentType = SegmentType.NUMBER
    value: float
class IntegerSegment(Segment):
    """Segment whose payload is an integer.

    Shares the ``SegmentType.NUMBER`` tag with ``FloatSegment``.
    """

    value_type: SegmentType = SegmentType.NUMBER
    value: int
class ObjectSegment(Segment):
    """Segment holding a string-keyed mapping of other segments.

    All three renderers serialise the ``value`` entry of ``model_dump()`` as
    JSON, leaving non-ASCII characters unescaped.
    """

    value_type: SegmentType = SegmentType.OBJECT
    value: Mapping[str, Segment]

    @property
    def text(self) -> str:
        """Single-line JSON rendering of the dumped ``value``."""
        # TODO: Process variables.
        dumped_value = self.model_dump()['value']
        return json.dumps(dumped_value, ensure_ascii=False)

    @property
    def log(self) -> str:
        """JSON rendering of the dumped ``value``, indented by two spaces."""
        # TODO: Process variables.
        dumped_value = self.model_dump()['value']
        return json.dumps(dumped_value, indent=2, ensure_ascii=False)

    @property
    def markdown(self) -> str:
        """JSON rendering of the dumped ``value``, indented by two spaces."""
        # TODO: Use markdown code block
        dumped_value = self.model_dump()['value']
        return json.dumps(dumped_value, indent=2, ensure_ascii=False)

    def to_object(self):
        """Return a dict mapping each key to its member's ``to_object()`` result."""
        plain = {}
        for key, member in self.value.items():
            plain[key] = member.to_object()
        return plain
class ArraySegment(Segment):
    """Segment holding an ordered sequence of other segments."""

    value_type: SegmentType = SegmentType.ARRAY
    value: Sequence[Segment]

    @property
    def markdown(self) -> str:
        """Render one ``- <item markdown>`` line per element, newline-separated."""
        bullets = []
        for element in self.value:
            bullets.append('- ' + element.markdown)
        return '\n'.join(bullets)

    def to_object(self):
        """Return a list of each element's ``to_object()`` result, in order."""
        plain = []
        for element in self.value:
            plain.append(element.to_object())
        return plain
class FileSegment(Segment):
    """Segment whose payload is a ``FileVar``; tagged as ``SegmentType.FILE``."""

    value_type: SegmentType = SegmentType.FILE
    # TODO: embed FileVar in this model.
    value: FileVar

    @property
    def markdown(self) -> str:
        """Delegate markdown rendering to the wrapped ``FileVar``."""
        file_var = self.value
        return file_var.to_markdown()

View File

@ -9,3 +9,5 @@ class SegmentType(str, Enum):
ARRAY = 'array'
OBJECT = 'object'
FILE = 'file'
GROUP = 'group'

View File

@ -1,12 +1,18 @@
import json
from collections.abc import Mapping, Sequence
from pydantic import Field
from core.file.file_obj import FileVar
from core.helper import encrypter
from .segments import NoneSegment, Segment, StringSegment
from .segments import (
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
StringSegment,
)
from .types import SegmentType
@ -27,59 +33,24 @@ class StringVariable(StringSegment, Variable):
pass
class FloatVariable(Variable):
    """Variable whose payload is a float; tagged as ``SegmentType.NUMBER``."""

    value_type: SegmentType = SegmentType.NUMBER
    value: float
class FloatVariable(FloatSegment, Variable):
    """Float variable: ``FloatSegment`` combined with ``Variable``, adding nothing of its own."""
class IntegerVariable(Variable):
    """Variable whose payload is an integer; tagged as ``SegmentType.NUMBER``."""

    value_type: SegmentType = SegmentType.NUMBER
    value: int
class IntegerVariable(IntegerSegment, Variable):
    """Integer variable: ``IntegerSegment`` combined with ``Variable``, adding nothing of its own."""
class ObjectVariable(Variable):
    """Variable holding a string-keyed mapping of other variables.

    All three renderers serialise the ``value`` entry of ``model_dump()`` as
    JSON, leaving non-ASCII characters unescaped.
    """

    value_type: SegmentType = SegmentType.OBJECT
    value: Mapping[str, Variable]

    @property
    def text(self) -> str:
        """Single-line JSON rendering of the dumped ``value``."""
        # TODO: Process variables.
        dumped_value = self.model_dump()['value']
        return json.dumps(dumped_value, ensure_ascii=False)

    @property
    def log(self) -> str:
        """JSON rendering of the dumped ``value``, indented by two spaces."""
        # TODO: Process variables.
        dumped_value = self.model_dump()['value']
        return json.dumps(dumped_value, indent=2, ensure_ascii=False)

    @property
    def markdown(self) -> str:
        """JSON rendering of the dumped ``value``, indented by two spaces."""
        # TODO: Use markdown code block
        dumped_value = self.model_dump()['value']
        return json.dumps(dumped_value, indent=2, ensure_ascii=False)

    def to_object(self):
        """Return a dict mapping each key to its member's ``to_object()`` result."""
        plain = {}
        for key, member in self.value.items():
            plain[key] = member.to_object()
        return plain
class ObjectVariable(ObjectSegment, Variable):
    """Object variable: ``ObjectSegment`` combined with ``Variable``, adding nothing of its own."""
class ArrayVariable(Variable):
    """Variable holding an ordered sequence of other variables."""

    value_type: SegmentType = SegmentType.ARRAY
    value: Sequence[Variable]

    @property
    def markdown(self) -> str:
        """Render one ``- <item markdown>`` line per element, newline-separated."""
        bullets = []
        for element in self.value:
            bullets.append('- ' + element.markdown)
        return '\n'.join(bullets)

    def to_object(self):
        """Return a list of each element's ``to_object()`` result, in order."""
        plain = []
        for element in self.value:
            plain.append(element.to_object())
        return plain
class ArrayVariable(ArraySegment, Variable):
    """Array variable: ``ArraySegment`` combined with ``Variable``, adding nothing of its own."""
class FileVariable(Variable):
    """Variable whose payload is a ``FileVar``; tagged as ``SegmentType.FILE``."""

    value_type: SegmentType = SegmentType.FILE
    # TODO: embed FileVar in this model.
    value: FileVar

    @property
    def markdown(self) -> str:
        """Delegate markdown rendering to the wrapped ``FileVar``."""
        file_var = self.value
        return file_var.to_markdown()
class FileVariable(FileSegment, Variable):
    """File variable: ``FileSegment`` combined with ``Variable``, adding nothing of its own."""
class SecretVariable(StringVariable):

View File

@ -4,7 +4,7 @@ from typing import Any, Union
from typing_extensions import deprecated
from core.app.segments import Variable, factory
from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable
@ -33,7 +33,7 @@ class VariablePool:
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
self._variable_dictionary: dict[str, dict[int, Variable]] = defaultdict(dict)
self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)
# TODO: This user inputs is not used for pool.
self.user_inputs = user_inputs
@ -67,15 +67,15 @@ class VariablePool:
if value is None:
return
if not isinstance(value, Variable):
v = factory.build_anonymous_variable(value)
else:
if isinstance(value, Segment):
v = value
else:
v = factory.build_segment(value)
hash_key = hash(tuple(selector[1:]))
self._variable_dictionary[selector[0]][hash_key] = v
def get(self, selector: Sequence[str], /) -> Variable | None:
def get(self, selector: Sequence[str], /) -> Segment | None:
"""
Retrieves the value from the variable pool based on the given selector.

View File

@ -126,6 +126,7 @@ class ToolNode(BaseNode):
else:
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == 'variable':
# TODO: check if the variable exists in the variable pool
parameter_value = variable_pool.get(tool_input.value).value
else:
segment_group = parser.convert_template(

View File

@ -5,7 +5,8 @@ from core.app.segments import (
ArrayVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
NoneSegment,
ObjectSegment,
ObjectVariable,
SecretVariable,
SegmentType,
@ -139,10 +140,10 @@ def test_variable_to_object():
# Builds from a dict whose only entry maps 'key1' to None, then checks that the
# result is an object-typed wrapper and that 'key1' holds a None-typed element.
# NOTE(review): this view lists both sides of the change without +/- markers. The
# `build_anonymous_variable` / `ObjectVariable` / `NoneVariable` lines look like
# the removed side and the `build_segment` / `ObjectSegment` / `NoneSegment`
# lines like the added side - confirm against the raw diff.
def test_build_a_object_variable_with_none_value():
var = factory.build_anonymous_variable(
var = factory.build_segment(
{
'key1': None,
}
)
assert isinstance(var, ObjectVariable)
assert isinstance(var.value['key1'], NoneVariable)
assert isinstance(var, ObjectSegment)
assert isinstance(var.value['key1'], NoneSegment)