mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 12:29:00 +08:00
refactor(api/core/app/segments): Support more kinds of Segments. (#6706)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
6b50bb0fe6
commit
c6996a48a4
@ -1,5 +1,14 @@
|
||||
from .segment_group import SegmentGroup
|
||||
from .segments import NoneSegment, Segment
|
||||
from .segments import (
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
StringSegment,
|
||||
)
|
||||
from .types import SegmentType
|
||||
from .variables import (
|
||||
ArrayVariable,
|
||||
@ -27,4 +36,10 @@ __all__ = [
|
||||
'Segment',
|
||||
'NoneSegment',
|
||||
'NoneVariable',
|
||||
'IntegerSegment',
|
||||
'FloatSegment',
|
||||
'ObjectSegment',
|
||||
'ArraySegment',
|
||||
'FileSegment',
|
||||
'StringSegment',
|
||||
]
|
||||
|
@ -3,15 +3,20 @@ from typing import Any
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
from .segments import Segment, StringSegment
|
||||
from .segments import (
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
StringSegment,
|
||||
)
|
||||
from .types import SegmentType
|
||||
from .variables import (
|
||||
ArrayVariable,
|
||||
FileVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
NoneVariable,
|
||||
ObjectVariable,
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
Variable,
|
||||
@ -39,29 +44,23 @@ def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable:
|
||||
raise ValueError(f'not supported value type {value_type}')
|
||||
|
||||
|
||||
def build_anonymous_variable(value: Any, /) -> Variable:
|
||||
if value is None:
|
||||
return NoneVariable(name='anonymous')
|
||||
if isinstance(value, str):
|
||||
return StringVariable(name='anonymous', value=value)
|
||||
if isinstance(value, int):
|
||||
return IntegerVariable(name='anonymous', value=value)
|
||||
if isinstance(value, float):
|
||||
return FloatVariable(name='anonymous', value=value)
|
||||
if isinstance(value, dict):
|
||||
# TODO: Limit the depth of the object
|
||||
obj = {k: build_anonymous_variable(v) for k, v in value.items()}
|
||||
return ObjectVariable(name='anonymous', value=obj)
|
||||
if isinstance(value, list):
|
||||
# TODO: Limit the depth of the array
|
||||
elements = [build_anonymous_variable(v) for v in value]
|
||||
return ArrayVariable(name='anonymous', value=elements)
|
||||
if isinstance(value, FileVar):
|
||||
return FileVariable(name='anonymous', value=value)
|
||||
raise ValueError(f'not supported value {value}')
|
||||
|
||||
|
||||
def build_segment(value: Any, /) -> Segment:
|
||||
if value is None:
|
||||
return NoneSegment()
|
||||
if isinstance(value, str):
|
||||
return StringSegment(value=value)
|
||||
if isinstance(value, int):
|
||||
return IntegerSegment(value=value)
|
||||
if isinstance(value, float):
|
||||
return FloatSegment(value=value)
|
||||
if isinstance(value, dict):
|
||||
# TODO: Limit the depth of the object
|
||||
obj = {k: build_segment(v) for k, v in value.items()}
|
||||
return ObjectSegment(value=obj)
|
||||
if isinstance(value, list):
|
||||
# TODO: Limit the depth of the array
|
||||
elements = [build_segment(v) for v in value]
|
||||
return ArraySegment(value=elements)
|
||||
if isinstance(value, FileVar):
|
||||
return FileSegment(value=value)
|
||||
raise ValueError(f'not supported value {value}')
|
||||
|
@ -1,8 +1,9 @@
|
||||
import re
|
||||
|
||||
from core.app.segments import SegmentGroup, factory
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
from . import SegmentGroup, factory
|
||||
|
||||
VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')
|
||||
|
||||
|
||||
@ -14,4 +15,4 @@ def convert_template(*, template: str, variable_pool: VariablePool):
|
||||
segments.append(value)
|
||||
else:
|
||||
segments.append(factory.build_segment(part))
|
||||
return SegmentGroup(segments=segments)
|
||||
return SegmentGroup(value=segments)
|
||||
|
@ -1,19 +1,22 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .segments import Segment
|
||||
from .types import SegmentType
|
||||
|
||||
|
||||
class SegmentGroup(BaseModel):
|
||||
segments: list[Segment]
|
||||
class SegmentGroup(Segment):
|
||||
value_type: SegmentType = SegmentType.GROUP
|
||||
value: list[Segment]
|
||||
|
||||
@property
|
||||
def text(self):
|
||||
return ''.join([segment.text for segment in self.segments])
|
||||
return ''.join([segment.text for segment in self.value])
|
||||
|
||||
@property
|
||||
def log(self):
|
||||
return ''.join([segment.log for segment in self.segments])
|
||||
return ''.join([segment.log for segment in self.value])
|
||||
|
||||
@property
|
||||
def markdown(self):
|
||||
return ''.join([segment.markdown for segment in self.segments])
|
||||
return ''.join([segment.markdown for segment in self.value])
|
||||
|
||||
def to_object(self):
|
||||
return [segment.to_object() for segment in self.value]
|
||||
|
@ -1,7 +1,11 @@
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
from .types import SegmentType
|
||||
|
||||
|
||||
@ -57,3 +61,57 @@ class NoneSegment(Segment):
|
||||
class StringSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.STRING
|
||||
value: str
|
||||
|
||||
class FloatSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.NUMBER
|
||||
value: float
|
||||
|
||||
|
||||
class IntegerSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.NUMBER
|
||||
value: int
|
||||
|
||||
|
||||
class ObjectSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.OBJECT
|
||||
value: Mapping[str, Segment]
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
# TODO: Process variables.
|
||||
return json.dumps(self.model_dump()['value'], ensure_ascii=False)
|
||||
|
||||
@property
|
||||
def log(self) -> str:
|
||||
# TODO: Process variables.
|
||||
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
# TODO: Use markdown code block
|
||||
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
|
||||
|
||||
def to_object(self):
|
||||
return {k: v.to_object() for k, v in self.value.items()}
|
||||
|
||||
|
||||
class ArraySegment(Segment):
|
||||
value_type: SegmentType = SegmentType.ARRAY
|
||||
value: Sequence[Segment]
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
return '\n'.join(['- ' + item.markdown for item in self.value])
|
||||
|
||||
def to_object(self):
|
||||
return [v.to_object() for v in self.value]
|
||||
|
||||
|
||||
class FileSegment(Segment):
|
||||
value_type: SegmentType = SegmentType.FILE
|
||||
# TODO: embed FileVar in this model.
|
||||
value: FileVar
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
return self.value.to_markdown()
|
||||
|
@ -9,3 +9,5 @@ class SegmentType(str, Enum):
|
||||
ARRAY = 'array'
|
||||
OBJECT = 'object'
|
||||
FILE = 'file'
|
||||
|
||||
GROUP = 'group'
|
||||
|
@ -1,12 +1,18 @@
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
from core.helper import encrypter
|
||||
|
||||
from .segments import NoneSegment, Segment, StringSegment
|
||||
from .segments import (
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
FloatSegment,
|
||||
IntegerSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
StringSegment,
|
||||
)
|
||||
from .types import SegmentType
|
||||
|
||||
|
||||
@ -27,59 +33,24 @@ class StringVariable(StringSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class FloatVariable(Variable):
|
||||
value_type: SegmentType = SegmentType.NUMBER
|
||||
value: float
|
||||
class FloatVariable(FloatSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class IntegerVariable(Variable):
|
||||
value_type: SegmentType = SegmentType.NUMBER
|
||||
value: int
|
||||
class IntegerVariable(IntegerSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class ObjectVariable(Variable):
|
||||
value_type: SegmentType = SegmentType.OBJECT
|
||||
value: Mapping[str, Variable]
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
# TODO: Process variables.
|
||||
return json.dumps(self.model_dump()['value'], ensure_ascii=False)
|
||||
|
||||
@property
|
||||
def log(self) -> str:
|
||||
# TODO: Process variables.
|
||||
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
# TODO: Use markdown code block
|
||||
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
|
||||
|
||||
def to_object(self):
|
||||
return {k: v.to_object() for k, v in self.value.items()}
|
||||
class ObjectVariable(ObjectSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayVariable(Variable):
|
||||
value_type: SegmentType = SegmentType.ARRAY
|
||||
value: Sequence[Variable]
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
return '\n'.join(['- ' + item.markdown for item in self.value])
|
||||
|
||||
def to_object(self):
|
||||
return [v.to_object() for v in self.value]
|
||||
class ArrayVariable(ArraySegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class FileVariable(Variable):
|
||||
value_type: SegmentType = SegmentType.FILE
|
||||
# TODO: embed FileVar in this model.
|
||||
value: FileVar
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
return self.value.to_markdown()
|
||||
class FileVariable(FileSegment, Variable):
|
||||
pass
|
||||
|
||||
|
||||
class SecretVariable(StringVariable):
|
||||
|
@ -4,7 +4,7 @@ from typing import Any, Union
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.app.segments import Variable, factory
|
||||
from core.app.segments import Segment, Variable, factory
|
||||
from core.file.file_obj import FileVar
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
|
||||
@ -33,7 +33,7 @@ class VariablePool:
|
||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||
# elements of the selector except the first one.
|
||||
self._variable_dictionary: dict[str, dict[int, Variable]] = defaultdict(dict)
|
||||
self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)
|
||||
|
||||
# TODO: This user inputs is not used for pool.
|
||||
self.user_inputs = user_inputs
|
||||
@ -67,15 +67,15 @@ class VariablePool:
|
||||
if value is None:
|
||||
return
|
||||
|
||||
if not isinstance(value, Variable):
|
||||
v = factory.build_anonymous_variable(value)
|
||||
else:
|
||||
if isinstance(value, Segment):
|
||||
v = value
|
||||
else:
|
||||
v = factory.build_segment(value)
|
||||
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self._variable_dictionary[selector[0]][hash_key] = v
|
||||
|
||||
def get(self, selector: Sequence[str], /) -> Variable | None:
|
||||
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||
"""
|
||||
Retrieves the value from the variable pool based on the given selector.
|
||||
|
||||
|
@ -126,6 +126,7 @@ class ToolNode(BaseNode):
|
||||
else:
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == 'variable':
|
||||
# TODO: check if the variable exists in the variable pool
|
||||
parameter_value = variable_pool.get(tool_input.value).value
|
||||
else:
|
||||
segment_group = parser.convert_template(
|
||||
|
@ -5,7 +5,8 @@ from core.app.segments import (
|
||||
ArrayVariable,
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
NoneVariable,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
ObjectVariable,
|
||||
SecretVariable,
|
||||
SegmentType,
|
||||
@ -139,10 +140,10 @@ def test_variable_to_object():
|
||||
|
||||
|
||||
def test_build_a_object_variable_with_none_value():
|
||||
var = factory.build_anonymous_variable(
|
||||
var = factory.build_segment(
|
||||
{
|
||||
'key1': None,
|
||||
}
|
||||
)
|
||||
assert isinstance(var, ObjectVariable)
|
||||
assert isinstance(var.value['key1'], NoneVariable)
|
||||
assert isinstance(var, ObjectSegment)
|
||||
assert isinstance(var.value['key1'], NoneSegment)
|
Loading…
x
Reference in New Issue
Block a user