mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 20:59:00 +08:00
refactor(api/core/app/segments): Support more kinds of Segments. (#6706)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
6b50bb0fe6
commit
c6996a48a4
@ -1,5 +1,14 @@
|
|||||||
from .segment_group import SegmentGroup
|
from .segment_group import SegmentGroup
|
||||||
from .segments import NoneSegment, Segment
|
from .segments import (
|
||||||
|
ArraySegment,
|
||||||
|
FileSegment,
|
||||||
|
FloatSegment,
|
||||||
|
IntegerSegment,
|
||||||
|
NoneSegment,
|
||||||
|
ObjectSegment,
|
||||||
|
Segment,
|
||||||
|
StringSegment,
|
||||||
|
)
|
||||||
from .types import SegmentType
|
from .types import SegmentType
|
||||||
from .variables import (
|
from .variables import (
|
||||||
ArrayVariable,
|
ArrayVariable,
|
||||||
@ -27,4 +36,10 @@ __all__ = [
|
|||||||
'Segment',
|
'Segment',
|
||||||
'NoneSegment',
|
'NoneSegment',
|
||||||
'NoneVariable',
|
'NoneVariable',
|
||||||
|
'IntegerSegment',
|
||||||
|
'FloatSegment',
|
||||||
|
'ObjectSegment',
|
||||||
|
'ArraySegment',
|
||||||
|
'FileSegment',
|
||||||
|
'StringSegment',
|
||||||
]
|
]
|
||||||
|
@ -3,15 +3,20 @@ from typing import Any
|
|||||||
|
|
||||||
from core.file.file_obj import FileVar
|
from core.file.file_obj import FileVar
|
||||||
|
|
||||||
from .segments import Segment, StringSegment
|
from .segments import (
|
||||||
|
ArraySegment,
|
||||||
|
FileSegment,
|
||||||
|
FloatSegment,
|
||||||
|
IntegerSegment,
|
||||||
|
NoneSegment,
|
||||||
|
ObjectSegment,
|
||||||
|
Segment,
|
||||||
|
StringSegment,
|
||||||
|
)
|
||||||
from .types import SegmentType
|
from .types import SegmentType
|
||||||
from .variables import (
|
from .variables import (
|
||||||
ArrayVariable,
|
|
||||||
FileVariable,
|
|
||||||
FloatVariable,
|
FloatVariable,
|
||||||
IntegerVariable,
|
IntegerVariable,
|
||||||
NoneVariable,
|
|
||||||
ObjectVariable,
|
|
||||||
SecretVariable,
|
SecretVariable,
|
||||||
StringVariable,
|
StringVariable,
|
||||||
Variable,
|
Variable,
|
||||||
@ -39,29 +44,23 @@ def build_variable_from_mapping(m: Mapping[str, Any], /) -> Variable:
|
|||||||
raise ValueError(f'not supported value type {value_type}')
|
raise ValueError(f'not supported value type {value_type}')
|
||||||
|
|
||||||
|
|
||||||
def build_anonymous_variable(value: Any, /) -> Variable:
|
|
||||||
if value is None:
|
|
||||||
return NoneVariable(name='anonymous')
|
|
||||||
if isinstance(value, str):
|
|
||||||
return StringVariable(name='anonymous', value=value)
|
|
||||||
if isinstance(value, int):
|
|
||||||
return IntegerVariable(name='anonymous', value=value)
|
|
||||||
if isinstance(value, float):
|
|
||||||
return FloatVariable(name='anonymous', value=value)
|
|
||||||
if isinstance(value, dict):
|
|
||||||
# TODO: Limit the depth of the object
|
|
||||||
obj = {k: build_anonymous_variable(v) for k, v in value.items()}
|
|
||||||
return ObjectVariable(name='anonymous', value=obj)
|
|
||||||
if isinstance(value, list):
|
|
||||||
# TODO: Limit the depth of the array
|
|
||||||
elements = [build_anonymous_variable(v) for v in value]
|
|
||||||
return ArrayVariable(name='anonymous', value=elements)
|
|
||||||
if isinstance(value, FileVar):
|
|
||||||
return FileVariable(name='anonymous', value=value)
|
|
||||||
raise ValueError(f'not supported value {value}')
|
|
||||||
|
|
||||||
|
|
||||||
def build_segment(value: Any, /) -> Segment:
|
def build_segment(value: Any, /) -> Segment:
|
||||||
|
if value is None:
|
||||||
|
return NoneSegment()
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
return StringSegment(value=value)
|
return StringSegment(value=value)
|
||||||
|
if isinstance(value, int):
|
||||||
|
return IntegerSegment(value=value)
|
||||||
|
if isinstance(value, float):
|
||||||
|
return FloatSegment(value=value)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
# TODO: Limit the depth of the object
|
||||||
|
obj = {k: build_segment(v) for k, v in value.items()}
|
||||||
|
return ObjectSegment(value=obj)
|
||||||
|
if isinstance(value, list):
|
||||||
|
# TODO: Limit the depth of the array
|
||||||
|
elements = [build_segment(v) for v in value]
|
||||||
|
return ArraySegment(value=elements)
|
||||||
|
if isinstance(value, FileVar):
|
||||||
|
return FileSegment(value=value)
|
||||||
raise ValueError(f'not supported value {value}')
|
raise ValueError(f'not supported value {value}')
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
from core.app.segments import SegmentGroup, factory
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
|
||||||
|
from . import SegmentGroup, factory
|
||||||
|
|
||||||
VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')
|
VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')
|
||||||
|
|
||||||
|
|
||||||
@ -14,4 +15,4 @@ def convert_template(*, template: str, variable_pool: VariablePool):
|
|||||||
segments.append(value)
|
segments.append(value)
|
||||||
else:
|
else:
|
||||||
segments.append(factory.build_segment(part))
|
segments.append(factory.build_segment(part))
|
||||||
return SegmentGroup(segments=segments)
|
return SegmentGroup(value=segments)
|
||||||
|
@ -1,19 +1,22 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from .segments import Segment
|
from .segments import Segment
|
||||||
|
from .types import SegmentType
|
||||||
|
|
||||||
|
|
||||||
class SegmentGroup(BaseModel):
|
class SegmentGroup(Segment):
|
||||||
segments: list[Segment]
|
value_type: SegmentType = SegmentType.GROUP
|
||||||
|
value: list[Segment]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self):
|
def text(self):
|
||||||
return ''.join([segment.text for segment in self.segments])
|
return ''.join([segment.text for segment in self.value])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def log(self):
|
def log(self):
|
||||||
return ''.join([segment.log for segment in self.segments])
|
return ''.join([segment.log for segment in self.value])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def markdown(self):
|
def markdown(self):
|
||||||
return ''.join([segment.markdown for segment in self.segments])
|
return ''.join([segment.markdown for segment in self.value])
|
||||||
|
|
||||||
|
def to_object(self):
|
||||||
|
return [segment.to_object() for segment in self.value]
|
||||||
|
@ -1,7 +1,11 @@
|
|||||||
|
import json
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, field_validator
|
from pydantic import BaseModel, ConfigDict, field_validator
|
||||||
|
|
||||||
|
from core.file.file_obj import FileVar
|
||||||
|
|
||||||
from .types import SegmentType
|
from .types import SegmentType
|
||||||
|
|
||||||
|
|
||||||
@ -57,3 +61,57 @@ class NoneSegment(Segment):
|
|||||||
class StringSegment(Segment):
|
class StringSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.STRING
|
value_type: SegmentType = SegmentType.STRING
|
||||||
value: str
|
value: str
|
||||||
|
|
||||||
|
class FloatSegment(Segment):
|
||||||
|
value_type: SegmentType = SegmentType.NUMBER
|
||||||
|
value: float
|
||||||
|
|
||||||
|
|
||||||
|
class IntegerSegment(Segment):
|
||||||
|
value_type: SegmentType = SegmentType.NUMBER
|
||||||
|
value: int
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectSegment(Segment):
|
||||||
|
value_type: SegmentType = SegmentType.OBJECT
|
||||||
|
value: Mapping[str, Segment]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self) -> str:
|
||||||
|
# TODO: Process variables.
|
||||||
|
return json.dumps(self.model_dump()['value'], ensure_ascii=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def log(self) -> str:
|
||||||
|
# TODO: Process variables.
|
||||||
|
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def markdown(self) -> str:
|
||||||
|
# TODO: Use markdown code block
|
||||||
|
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
def to_object(self):
|
||||||
|
return {k: v.to_object() for k, v in self.value.items()}
|
||||||
|
|
||||||
|
|
||||||
|
class ArraySegment(Segment):
|
||||||
|
value_type: SegmentType = SegmentType.ARRAY
|
||||||
|
value: Sequence[Segment]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def markdown(self) -> str:
|
||||||
|
return '\n'.join(['- ' + item.markdown for item in self.value])
|
||||||
|
|
||||||
|
def to_object(self):
|
||||||
|
return [v.to_object() for v in self.value]
|
||||||
|
|
||||||
|
|
||||||
|
class FileSegment(Segment):
|
||||||
|
value_type: SegmentType = SegmentType.FILE
|
||||||
|
# TODO: embed FileVar in this model.
|
||||||
|
value: FileVar
|
||||||
|
|
||||||
|
@property
|
||||||
|
def markdown(self) -> str:
|
||||||
|
return self.value.to_markdown()
|
||||||
|
@ -9,3 +9,5 @@ class SegmentType(str, Enum):
|
|||||||
ARRAY = 'array'
|
ARRAY = 'array'
|
||||||
OBJECT = 'object'
|
OBJECT = 'object'
|
||||||
FILE = 'file'
|
FILE = 'file'
|
||||||
|
|
||||||
|
GROUP = 'group'
|
||||||
|
@ -1,12 +1,18 @@
|
|||||||
import json
|
|
||||||
from collections.abc import Mapping, Sequence
|
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from core.file.file_obj import FileVar
|
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
|
|
||||||
from .segments import NoneSegment, Segment, StringSegment
|
from .segments import (
|
||||||
|
ArraySegment,
|
||||||
|
FileSegment,
|
||||||
|
FloatSegment,
|
||||||
|
IntegerSegment,
|
||||||
|
NoneSegment,
|
||||||
|
ObjectSegment,
|
||||||
|
Segment,
|
||||||
|
StringSegment,
|
||||||
|
)
|
||||||
from .types import SegmentType
|
from .types import SegmentType
|
||||||
|
|
||||||
|
|
||||||
@ -27,59 +33,24 @@ class StringVariable(StringSegment, Variable):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FloatVariable(Variable):
|
class FloatVariable(FloatSegment, Variable):
|
||||||
value_type: SegmentType = SegmentType.NUMBER
|
pass
|
||||||
value: float
|
|
||||||
|
|
||||||
|
|
||||||
class IntegerVariable(Variable):
|
class IntegerVariable(IntegerSegment, Variable):
|
||||||
value_type: SegmentType = SegmentType.NUMBER
|
pass
|
||||||
value: int
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectVariable(Variable):
|
class ObjectVariable(ObjectSegment, Variable):
|
||||||
value_type: SegmentType = SegmentType.OBJECT
|
pass
|
||||||
value: Mapping[str, Variable]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def text(self) -> str:
|
|
||||||
# TODO: Process variables.
|
|
||||||
return json.dumps(self.model_dump()['value'], ensure_ascii=False)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def log(self) -> str:
|
|
||||||
# TODO: Process variables.
|
|
||||||
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def markdown(self) -> str:
|
|
||||||
# TODO: Use markdown code block
|
|
||||||
return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
def to_object(self):
|
|
||||||
return {k: v.to_object() for k, v in self.value.items()}
|
|
||||||
|
|
||||||
|
|
||||||
class ArrayVariable(Variable):
|
class ArrayVariable(ArraySegment, Variable):
|
||||||
value_type: SegmentType = SegmentType.ARRAY
|
pass
|
||||||
value: Sequence[Variable]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def markdown(self) -> str:
|
|
||||||
return '\n'.join(['- ' + item.markdown for item in self.value])
|
|
||||||
|
|
||||||
def to_object(self):
|
|
||||||
return [v.to_object() for v in self.value]
|
|
||||||
|
|
||||||
|
|
||||||
class FileVariable(Variable):
|
class FileVariable(FileSegment, Variable):
|
||||||
value_type: SegmentType = SegmentType.FILE
|
pass
|
||||||
# TODO: embed FileVar in this model.
|
|
||||||
value: FileVar
|
|
||||||
|
|
||||||
@property
|
|
||||||
def markdown(self) -> str:
|
|
||||||
return self.value.to_markdown()
|
|
||||||
|
|
||||||
|
|
||||||
class SecretVariable(StringVariable):
|
class SecretVariable(StringVariable):
|
||||||
|
@ -4,7 +4,7 @@ from typing import Any, Union
|
|||||||
|
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
from core.app.segments import Variable, factory
|
from core.app.segments import Segment, Variable, factory
|
||||||
from core.file.file_obj import FileVar
|
from core.file.file_obj import FileVar
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
from core.workflow.entities.node_entities import SystemVariable
|
||||||
|
|
||||||
@ -33,7 +33,7 @@ class VariablePool:
|
|||||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||||
# elements of the selector except the first one.
|
# elements of the selector except the first one.
|
||||||
self._variable_dictionary: dict[str, dict[int, Variable]] = defaultdict(dict)
|
self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)
|
||||||
|
|
||||||
# TODO: This user inputs is not used for pool.
|
# TODO: This user inputs is not used for pool.
|
||||||
self.user_inputs = user_inputs
|
self.user_inputs = user_inputs
|
||||||
@ -67,15 +67,15 @@ class VariablePool:
|
|||||||
if value is None:
|
if value is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not isinstance(value, Variable):
|
if isinstance(value, Segment):
|
||||||
v = factory.build_anonymous_variable(value)
|
|
||||||
else:
|
|
||||||
v = value
|
v = value
|
||||||
|
else:
|
||||||
|
v = factory.build_segment(value)
|
||||||
|
|
||||||
hash_key = hash(tuple(selector[1:]))
|
hash_key = hash(tuple(selector[1:]))
|
||||||
self._variable_dictionary[selector[0]][hash_key] = v
|
self._variable_dictionary[selector[0]][hash_key] = v
|
||||||
|
|
||||||
def get(self, selector: Sequence[str], /) -> Variable | None:
|
def get(self, selector: Sequence[str], /) -> Segment | None:
|
||||||
"""
|
"""
|
||||||
Retrieves the value from the variable pool based on the given selector.
|
Retrieves the value from the variable pool based on the given selector.
|
||||||
|
|
||||||
|
@ -126,6 +126,7 @@ class ToolNode(BaseNode):
|
|||||||
else:
|
else:
|
||||||
tool_input = node_data.tool_parameters[parameter_name]
|
tool_input = node_data.tool_parameters[parameter_name]
|
||||||
if tool_input.type == 'variable':
|
if tool_input.type == 'variable':
|
||||||
|
# TODO: check if the variable exists in the variable pool
|
||||||
parameter_value = variable_pool.get(tool_input.value).value
|
parameter_value = variable_pool.get(tool_input.value).value
|
||||||
else:
|
else:
|
||||||
segment_group = parser.convert_template(
|
segment_group = parser.convert_template(
|
||||||
|
@ -5,7 +5,8 @@ from core.app.segments import (
|
|||||||
ArrayVariable,
|
ArrayVariable,
|
||||||
FloatVariable,
|
FloatVariable,
|
||||||
IntegerVariable,
|
IntegerVariable,
|
||||||
NoneVariable,
|
NoneSegment,
|
||||||
|
ObjectSegment,
|
||||||
ObjectVariable,
|
ObjectVariable,
|
||||||
SecretVariable,
|
SecretVariable,
|
||||||
SegmentType,
|
SegmentType,
|
||||||
@ -139,10 +140,10 @@ def test_variable_to_object():
|
|||||||
|
|
||||||
|
|
||||||
def test_build_a_object_variable_with_none_value():
|
def test_build_a_object_variable_with_none_value():
|
||||||
var = factory.build_anonymous_variable(
|
var = factory.build_segment(
|
||||||
{
|
{
|
||||||
'key1': None,
|
'key1': None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
assert isinstance(var, ObjectVariable)
|
assert isinstance(var, ObjectSegment)
|
||||||
assert isinstance(var.value['key1'], NoneVariable)
|
assert isinstance(var.value['key1'], NoneSegment)
|
Loading…
x
Reference in New Issue
Block a user