mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 20:15:58 +08:00
refactor: tool entities
This commit is contained in:
parent
63206a7967
commit
65a4cb769b
@ -1,23 +1,36 @@
|
||||
import enum
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.entities.parameter_entities import CommonParameterType
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
as_normal_type,
|
||||
cast_parameter_value,
|
||||
init_frontend_parameter,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolIdentity,
|
||||
ToolParameterOption,
|
||||
ToolProviderIdentity,
|
||||
)
|
||||
|
||||
|
||||
class AgentStrategyProviderIdentity(ToolProviderIdentity):
|
||||
"""
|
||||
Inherits from ToolProviderIdentity, without any additional fields.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentStrategyParameter(BaseModel):
|
||||
class AgentStrategyParameter(PluginParameter):
|
||||
class AgentStrategyParameterType(enum.StrEnum):
|
||||
"""
|
||||
Keep all the types from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = CommonParameterType.STRING.value
|
||||
NUMBER = CommonParameterType.NUMBER.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
@ -26,7 +39,6 @@ class AgentStrategyParameter(BaseModel):
|
||||
FILE = CommonParameterType.FILE.value
|
||||
FILES = CommonParameterType.FILES.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
TOOL_SELECTOR = CommonParameterType.TOOL_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
|
||||
|
||||
@ -34,131 +46,15 @@ class AgentStrategyParameter(BaseModel):
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||
|
||||
def as_normal_type(self):
|
||||
if self in {
|
||||
AgentStrategyParameter.AgentStrategyParameterType.SECRET_INPUT,
|
||||
AgentStrategyParameter.AgentStrategyParameterType.SELECT,
|
||||
}:
|
||||
return "string"
|
||||
return self.value
|
||||
return as_normal_type(self)
|
||||
|
||||
def cast_value(self, value: Any, /):
|
||||
try:
|
||||
match self:
|
||||
case (
|
||||
AgentStrategyParameter.AgentStrategyParameterType.STRING
|
||||
| AgentStrategyParameter.AgentStrategyParameterType.SECRET_INPUT
|
||||
| AgentStrategyParameter.AgentStrategyParameterType.SELECT
|
||||
):
|
||||
if value is None:
|
||||
return ""
|
||||
else:
|
||||
return value if isinstance(value, str) else str(value)
|
||||
def cast_value(self, value: Any):
|
||||
return cast_parameter_value(self, value)
|
||||
|
||||
case AgentStrategyParameter.AgentStrategyParameterType.BOOLEAN:
|
||||
if value is None:
|
||||
return False
|
||||
elif isinstance(value, str):
|
||||
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
||||
# and also '0' for False and '1' for True
|
||||
match value.lower():
|
||||
case "true" | "yes" | "y" | "1":
|
||||
return True
|
||||
case "false" | "no" | "n" | "0":
|
||||
return False
|
||||
case _:
|
||||
return bool(value)
|
||||
else:
|
||||
return value if isinstance(value, bool) else bool(value)
|
||||
|
||||
case AgentStrategyParameter.AgentStrategyParameterType.NUMBER:
|
||||
if isinstance(value, int | float):
|
||||
return value
|
||||
elif isinstance(value, str) and value:
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
case (
|
||||
AgentStrategyParameter.AgentStrategyParameterType.SYSTEM_FILES
|
||||
| AgentStrategyParameter.AgentStrategyParameterType.FILES
|
||||
):
|
||||
if not isinstance(value, list):
|
||||
return [value]
|
||||
return value
|
||||
case AgentStrategyParameter.AgentStrategyParameterType.FILE:
|
||||
if isinstance(value, list):
|
||||
if len(value) != 1:
|
||||
raise ValueError(
|
||||
"This parameter only accepts one file but got multiple files while invoking."
|
||||
)
|
||||
else:
|
||||
return value[0]
|
||||
return value
|
||||
case (
|
||||
AgentStrategyParameter.AgentStrategyParameterType.TOOL_SELECTOR
|
||||
| AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR
|
||||
| AgentStrategyParameter.AgentStrategyParameterType.APP_SELECTOR
|
||||
| AgentStrategyParameter.AgentStrategyParameterType.TOOLS_SELECTOR
|
||||
):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError("The selector must be a dictionary.")
|
||||
return value
|
||||
case _:
|
||||
return str(value)
|
||||
|
||||
except Exception:
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {self.as_normal_type()}.")
|
||||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user")
|
||||
type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
|
||||
scope: str | None = None
|
||||
required: Optional[bool] = False
|
||||
default: Optional[Union[float, int, str]] = None
|
||||
min: Optional[Union[float, int]] = None
|
||||
max: Optional[Union[float, int]] = None
|
||||
options: list[ToolParameterOption] = Field(default_factory=list)
|
||||
|
||||
@field_validator("options", mode="before")
|
||||
@classmethod
|
||||
def transform_options(cls, v):
|
||||
if not isinstance(v, list):
|
||||
return []
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def get_simple_instance(
|
||||
cls,
|
||||
name: str,
|
||||
type: AgentStrategyParameterType,
|
||||
required: bool,
|
||||
options: Optional[list[str]] = None,
|
||||
):
|
||||
"""
|
||||
get a simple tool parameter
|
||||
|
||||
:param name: the name of the parameter
|
||||
:param llm_description: the description presented to the LLM
|
||||
:param type: the type of the parameter
|
||||
:param required: if the parameter is required
|
||||
:param options: the options of the parameter
|
||||
"""
|
||||
# convert options to ToolParameterOption
|
||||
if options:
|
||||
option_objs = [
|
||||
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options
|
||||
]
|
||||
else:
|
||||
option_objs = []
|
||||
return cls(
|
||||
name=name,
|
||||
label=I18nObject(en_US="", zh_Hans=""),
|
||||
placeholder=None,
|
||||
type=type,
|
||||
required=required,
|
||||
options=option_objs,
|
||||
)
|
||||
def init_frontend_parameter(self, value: Any):
|
||||
return init_frontend_parameter(self, self.type, value)
|
||||
|
||||
|
||||
class AgentStrategyProviderEntity(BaseModel):
|
||||
@ -167,6 +63,10 @@ class AgentStrategyProviderEntity(BaseModel):
|
||||
|
||||
|
||||
class AgentStrategyIdentity(ToolIdentity):
|
||||
"""
|
||||
Inherits from ToolIdentity, without any additional fields.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
|
@ -4,7 +4,7 @@ from core.agent.entities import AgentInvokeMessage
|
||||
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
|
||||
from core.agent.strategy.base import BaseAgentStrategy
|
||||
from core.plugin.manager.agent import PluginAgentManager
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
|
||||
|
||||
class PluginAgentStrategy(BaseAgentStrategy):
|
||||
@ -24,6 +24,14 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
return self.declaration.parameters
|
||||
|
||||
def initialize_parameters(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Initialize the parameters for the agent strategy.
|
||||
"""
|
||||
for parameter in self.declaration.parameters:
|
||||
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
|
||||
return params
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
@ -37,8 +45,8 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
||||
"""
|
||||
manager = PluginAgentManager()
|
||||
|
||||
# convert agent parameters with File type to PluginFileEntity
|
||||
params = PluginTool._transform_image_parameters(params)
|
||||
initialized_params = self.initialize_parameters(params)
|
||||
params = convert_parameters_to_plugin_format(initialized_params)
|
||||
|
||||
yield from manager.invoke(
|
||||
tenant_id=self.tenant_id,
|
||||
|
@ -12,7 +12,7 @@ class CommonParameterType(Enum):
|
||||
SYSTEM_FILES = "system-files"
|
||||
BOOLEAN = "boolean"
|
||||
APP_SELECTOR = "app-selector"
|
||||
TOOL_SELECTOR = "tool-selector"
|
||||
# TOOL_SELECTOR = "tool-selector"
|
||||
MODEL_SELECTOR = "model-selector"
|
||||
TOOLS_SELECTOR = "array[tools]"
|
||||
|
||||
|
@ -146,7 +146,6 @@ class BasicProviderConfig(BaseModel):
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOL_SELECTOR = CommonParameterType.TOOL_SELECTOR.value
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ProviderConfig.Type":
|
||||
|
@ -1,5 +1,5 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
@ -92,6 +92,17 @@ class File(BaseModel):
|
||||
tool_file_id=self.related_id, extension=self.extension
|
||||
)
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
return {
|
||||
"dify_model_identity": FILE_MODEL_IDENTITY,
|
||||
"mime_type": self.mime_type,
|
||||
"filename": self.filename,
|
||||
"extension": self.extension,
|
||||
"size": self.size,
|
||||
"type": self.type,
|
||||
"url": self.generate_url(),
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_after(self):
|
||||
match self.transfer_method:
|
||||
|
146
api/core/plugin/entities/parameters.py
Normal file
146
api/core/plugin/entities/parameters.py
Normal file
@ -0,0 +1,146 @@
|
||||
import enum
|
||||
from typing import Any, Optional, Union
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.entities.parameter_entities import CommonParameterType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class PluginParameterOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def transform_id_to_str(cls, value) -> str:
|
||||
if not isinstance(value, str):
|
||||
return str(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
class PluginParameterType(enum.StrEnum):
|
||||
"""
|
||||
all available parameter types
|
||||
"""
|
||||
|
||||
STRING = CommonParameterType.STRING.value
|
||||
NUMBER = CommonParameterType.NUMBER.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
FILE = CommonParameterType.FILE.value
|
||||
FILES = CommonParameterType.FILES.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||
|
||||
|
||||
class PluginParameter(BaseModel):
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user")
|
||||
scope: str | None = None
|
||||
required: bool = False
|
||||
default: Optional[Union[float, int, str]] = None
|
||||
min: Optional[Union[float, int]] = None
|
||||
max: Optional[Union[float, int]] = None
|
||||
options: list[PluginParameterOption] = Field(default_factory=list)
|
||||
|
||||
@field_validator("options", mode="before")
|
||||
@classmethod
|
||||
def transform_options(cls, v):
|
||||
if not isinstance(v, list):
|
||||
return []
|
||||
return v
|
||||
|
||||
|
||||
def as_normal_type(typ: enum.Enum):
|
||||
if typ.value in {
|
||||
PluginParameterType.SECRET_INPUT,
|
||||
PluginParameterType.SELECT,
|
||||
}:
|
||||
return "string"
|
||||
return typ.value
|
||||
|
||||
|
||||
def cast_parameter_value(typ: enum.Enum, value: Any, /):
|
||||
try:
|
||||
match typ.value:
|
||||
case PluginParameterType.STRING | PluginParameterType.SECRET_INPUT | PluginParameterType.SELECT:
|
||||
if value is None:
|
||||
return ""
|
||||
else:
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
case PluginParameterType.BOOLEAN:
|
||||
if value is None:
|
||||
return False
|
||||
elif isinstance(value, str):
|
||||
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
||||
# and also '0' for False and '1' for True
|
||||
match value.lower():
|
||||
case "true" | "yes" | "y" | "1":
|
||||
return True
|
||||
case "false" | "no" | "n" | "0":
|
||||
return False
|
||||
case _:
|
||||
return bool(value)
|
||||
else:
|
||||
return value if isinstance(value, bool) else bool(value)
|
||||
|
||||
case PluginParameterType.NUMBER:
|
||||
if isinstance(value, int | float):
|
||||
return value
|
||||
elif isinstance(value, str) and value:
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
case PluginParameterType.SYSTEM_FILES | PluginParameterType.FILES:
|
||||
if not isinstance(value, list):
|
||||
return [value]
|
||||
return value
|
||||
case PluginParameterType.FILE:
|
||||
if isinstance(value, list):
|
||||
if len(value) != 1:
|
||||
raise ValueError("This parameter only accepts one file but got multiple files while invoking.")
|
||||
else:
|
||||
return value[0]
|
||||
return value
|
||||
case PluginParameterType.MODEL_SELECTOR | PluginParameterType.APP_SELECTOR:
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError("The selector must be a dictionary.")
|
||||
return value
|
||||
case PluginParameterType.TOOLS_SELECTOR:
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("The tools selector must be a list.")
|
||||
return value
|
||||
case _:
|
||||
return str(value)
|
||||
|
||||
except Exception:
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.")
|
||||
|
||||
|
||||
def init_frontend_parameter(rule: PluginParameter, type: enum.Enum, value: Any):
|
||||
"""
|
||||
init frontend parameter by rule
|
||||
"""
|
||||
parameter_value = value
|
||||
if not parameter_value and parameter_value != 0:
|
||||
# get default value
|
||||
parameter_value = rule.default
|
||||
if not parameter_value and rule.required:
|
||||
raise ValueError(f"tool parameter {rule.name} not found in tool config")
|
||||
|
||||
if type == PluginParameterType.SELECT:
|
||||
# check if tool_parameter_config in options
|
||||
options = [x.value for x in rule.options]
|
||||
if parameter_value is not None and parameter_value not in options:
|
||||
raise ValueError(f"tool parameter {rule.name} value {parameter_value} not in options {options}")
|
||||
|
||||
return cast_parameter_value(type, parameter_value)
|
21
api/core/plugin/utils/converter.py
Normal file
21
api/core/plugin/utils/converter.py
Normal file
@ -0,0 +1,21 @@
|
||||
from typing import Any
|
||||
|
||||
from core.file.models import File
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
|
||||
|
||||
def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
for parameter_name, parameter in parameters.items():
|
||||
if isinstance(parameter, File):
|
||||
parameters[parameter_name] = parameter.to_plugin_parameter()
|
||||
elif isinstance(parameter, list) and all(isinstance(p, File) for p in parameter):
|
||||
parameters[parameter_name] = []
|
||||
for p in parameter:
|
||||
parameters[parameter_name].append(p.to_plugin_parameter())
|
||||
elif isinstance(parameter, ToolSelector):
|
||||
parameters[parameter_name] = parameter.to_plugin_parameter()
|
||||
elif isinstance(parameter, list) and all(isinstance(p, ToolSelector) for p in parameter):
|
||||
parameters[parameter_name] = []
|
||||
for p in parameter:
|
||||
parameters[parameter_name].append(p.to_plugin_parameter())
|
||||
return parameters
|
@ -6,9 +6,10 @@ from core.file.enums import FileType
|
||||
from core.file.file_manager import download
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
@ -51,7 +52,7 @@ class ASRTool(BuiltinTool):
|
||||
|
||||
options = []
|
||||
for provider, model in self.get_available_models():
|
||||
option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
|
||||
option = PluginParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
|
||||
options.append(option)
|
||||
|
||||
parameters.append(
|
||||
|
@ -4,9 +4,10 @@ from typing import Any
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from services.model_provider_service import ModelProviderService
|
||||
|
||||
|
||||
@ -54,7 +55,7 @@ class TTSTool(BuiltinTool):
|
||||
|
||||
options = []
|
||||
for provider, model, voices in self.get_available_models():
|
||||
option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
|
||||
option = PluginParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
|
||||
options.append(option)
|
||||
parameters.append(
|
||||
ToolParameter(
|
||||
@ -63,7 +64,7 @@ class TTSTool(BuiltinTool):
|
||||
type=ToolParameter.ToolParameterType.SELECT,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
options=[
|
||||
ToolParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name")))
|
||||
PluginParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name")))
|
||||
for voice in voices
|
||||
],
|
||||
)
|
||||
|
1
api/core/tools/entities/constants.py
Normal file
1
api/core/tools/entities/constants.py
Normal file
@ -0,0 +1 @@
|
||||
TOOL_SELECTOR_MODEL_IDENTITY = "__dify__tool_selector__"
|
@ -1,18 +1 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.file.constants import FILE_MODEL_IDENTITY
|
||||
from core.file.enums import FileType
|
||||
|
||||
|
||||
class PluginFileEntity(BaseModel):
|
||||
"""
|
||||
File entity for plugin tool.
|
||||
"""
|
||||
|
||||
dify_model_identity: str = FILE_MODEL_IDENTITY
|
||||
mime_type: str | None
|
||||
filename: str | None
|
||||
extension: str | None
|
||||
size: int | None
|
||||
type: FileType
|
||||
url: str
|
||||
|
@ -5,14 +5,17 @@ from typing import Any, Mapping, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator
|
||||
|
||||
from core.entities.parameter_entities import (
|
||||
AppSelectorScope,
|
||||
CommonParameterType,
|
||||
ModelSelectorScope,
|
||||
ToolSelectorScope,
|
||||
)
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
PluginParameterOption,
|
||||
PluginParameterType,
|
||||
as_normal_type,
|
||||
cast_parameter_value,
|
||||
init_frontend_parameter,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
|
||||
|
||||
|
||||
class ToolLabelEnum(Enum):
|
||||
@ -204,139 +207,51 @@ class ToolInvokeMessageBinary(BaseModel):
|
||||
file_var: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class ToolParameterOption(BaseModel):
|
||||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
class ToolParameter(PluginParameter):
|
||||
"""
|
||||
Overrides type
|
||||
"""
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def transform_id_to_str(cls, value) -> str:
|
||||
if not isinstance(value, str):
|
||||
return str(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
class ToolParameter(BaseModel):
|
||||
class ToolParameterType(enum.StrEnum):
|
||||
STRING = CommonParameterType.STRING.value
|
||||
NUMBER = CommonParameterType.NUMBER.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
SELECT = CommonParameterType.SELECT.value
|
||||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
FILE = CommonParameterType.FILE.value
|
||||
FILES = CommonParameterType.FILES.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
TOOL_SELECTOR = CommonParameterType.TOOL_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
"""
|
||||
removes TOOLS_SELECTOR from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = PluginParameterType.STRING.value
|
||||
NUMBER = PluginParameterType.NUMBER.value
|
||||
BOOLEAN = PluginParameterType.BOOLEAN.value
|
||||
SELECT = PluginParameterType.SELECT.value
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
|
||||
FILE = PluginParameterType.FILE.value
|
||||
FILES = PluginParameterType.FILES.value
|
||||
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
|
||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||
|
||||
def as_normal_type(self):
|
||||
if self in {
|
||||
ToolParameter.ToolParameterType.SECRET_INPUT,
|
||||
ToolParameter.ToolParameterType.SELECT,
|
||||
}:
|
||||
return "string"
|
||||
return self.value
|
||||
return as_normal_type(self)
|
||||
|
||||
def cast_value(self, value: Any, /):
|
||||
try:
|
||||
match self:
|
||||
case (
|
||||
ToolParameter.ToolParameterType.STRING
|
||||
| ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
| ToolParameter.ToolParameterType.SELECT
|
||||
):
|
||||
if value is None:
|
||||
return ""
|
||||
else:
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
case ToolParameter.ToolParameterType.BOOLEAN:
|
||||
if value is None:
|
||||
return False
|
||||
elif isinstance(value, str):
|
||||
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
||||
# and also '0' for False and '1' for True
|
||||
match value.lower():
|
||||
case "true" | "yes" | "y" | "1":
|
||||
return True
|
||||
case "false" | "no" | "n" | "0":
|
||||
return False
|
||||
case _:
|
||||
return bool(value)
|
||||
else:
|
||||
return value if isinstance(value, bool) else bool(value)
|
||||
|
||||
case ToolParameter.ToolParameterType.NUMBER:
|
||||
if isinstance(value, int | float):
|
||||
return value
|
||||
elif isinstance(value, str) and value:
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
case ToolParameter.ToolParameterType.SYSTEM_FILES | ToolParameter.ToolParameterType.FILES:
|
||||
if not isinstance(value, list):
|
||||
return [value]
|
||||
return value
|
||||
case ToolParameter.ToolParameterType.FILE:
|
||||
if isinstance(value, list):
|
||||
if len(value) != 1:
|
||||
raise ValueError(
|
||||
"This parameter only accepts one file but got multiple files while invoking."
|
||||
)
|
||||
else:
|
||||
return value[0]
|
||||
return value
|
||||
case (
|
||||
ToolParameter.ToolParameterType.TOOL_SELECTOR
|
||||
| ToolParameter.ToolParameterType.MODEL_SELECTOR
|
||||
| ToolParameter.ToolParameterType.APP_SELECTOR
|
||||
):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError("The selector must be a dictionary.")
|
||||
return value
|
||||
case _:
|
||||
return str(value)
|
||||
|
||||
except Exception:
|
||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {self.as_normal_type()}.")
|
||||
def cast_value(self, value: Any):
|
||||
return cast_parameter_value(self, value)
|
||||
|
||||
class ToolParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
FORM = "form" # should be set before invoking tool
|
||||
LLM = "llm" # will be set by LLM
|
||||
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
label: I18nObject = Field(..., description="The label presented to the user")
|
||||
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
|
||||
placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user")
|
||||
type: ToolParameterType = Field(..., description="The type of the parameter")
|
||||
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
|
||||
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
|
||||
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
||||
llm_description: Optional[str] = None
|
||||
required: Optional[bool] = False
|
||||
default: Optional[Union[float, int, str]] = None
|
||||
min: Optional[Union[float, int]] = None
|
||||
max: Optional[Union[float, int]] = None
|
||||
options: list[ToolParameterOption] = Field(default_factory=list)
|
||||
|
||||
@field_validator("options", mode="before")
|
||||
@classmethod
|
||||
def transform_options(cls, v):
|
||||
if not isinstance(v, list):
|
||||
return []
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def get_simple_instance(
|
||||
cls,
|
||||
name: str,
|
||||
llm_description: str,
|
||||
type: ToolParameterType,
|
||||
typ: ToolParameterType,
|
||||
required: bool,
|
||||
options: Optional[list[str]] = None,
|
||||
) -> "ToolParameter":
|
||||
@ -352,22 +267,27 @@ class ToolParameter(BaseModel):
|
||||
# convert options to ToolParameterOption
|
||||
if options:
|
||||
option_objs = [
|
||||
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options
|
||||
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in options
|
||||
]
|
||||
else:
|
||||
option_objs = []
|
||||
|
||||
return cls(
|
||||
name=name,
|
||||
label=I18nObject(en_US="", zh_Hans=""),
|
||||
placeholder=None,
|
||||
human_description=I18nObject(en_US="", zh_Hans=""),
|
||||
type=type,
|
||||
type=typ,
|
||||
form=cls.ToolParameterForm.LLM,
|
||||
llm_description=llm_description,
|
||||
required=required,
|
||||
options=option_objs,
|
||||
)
|
||||
|
||||
def init_frontend_parameter(self, value: Any):
|
||||
return init_frontend_parameter(self, self.type, value)
|
||||
|
||||
|
||||
class ToolProviderIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
@ -412,7 +332,7 @@ class ToolEntity(BaseModel):
|
||||
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
plugin_id: Optional[str] = Field(None, description="The id of the plugin")
|
||||
plugin_id: Optional[str] = None
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
@ -479,3 +399,24 @@ class ToolInvokeFrom(Enum):
|
||||
WORKFLOW = "workflow"
|
||||
AGENT = "agent"
|
||||
PLUGIN = "plugin"
|
||||
|
||||
|
||||
class ToolSelector(BaseModel):
|
||||
dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY
|
||||
|
||||
class Parameter(BaseModel):
|
||||
name: str = Field(..., description="The name of the parameter")
|
||||
type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter")
|
||||
required: bool = Field(..., description="Whether the parameter is required")
|
||||
description: str = Field(..., description="The description of the parameter")
|
||||
default: Optional[Union[int, float, str]] = None
|
||||
options: Optional[list[PluginParameterOption]] = None
|
||||
|
||||
provider_id: str = Field(..., description="The id of the provider")
|
||||
tool_name: str = Field(..., description="The name of the tool")
|
||||
tool_description: str = Field(..., description="The description of the tool")
|
||||
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
|
||||
tool_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm")
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
return self.model_dump()
|
||||
|
@ -2,11 +2,10 @@ from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.plugin.manager.tool import PluginToolManager
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.file_entities import PluginFileEntity
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
from models.model import File
|
||||
|
||||
|
||||
class PluginTool(Tool):
|
||||
@ -27,40 +26,6 @@ class PluginTool(Tool):
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
@classmethod
|
||||
def _transform_image_parameters(cls, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
for parameter_name, parameter in parameters.items():
|
||||
if isinstance(parameter, File):
|
||||
url = parameter.generate_url()
|
||||
if url is None:
|
||||
raise ValueError(f"File {parameter.id} does not have a valid URL")
|
||||
parameters[parameter_name] = PluginFileEntity(
|
||||
url=url,
|
||||
mime_type=parameter.mime_type,
|
||||
type=parameter.type,
|
||||
filename=parameter.filename,
|
||||
extension=parameter.extension,
|
||||
size=parameter.size,
|
||||
).model_dump()
|
||||
elif isinstance(parameter, list) and all(isinstance(p, File) for p in parameter):
|
||||
parameters[parameter_name] = []
|
||||
for p in parameter:
|
||||
assert isinstance(p, File)
|
||||
url = p.generate_url()
|
||||
if url is None:
|
||||
raise ValueError(f"File {p.id} does not have a valid URL")
|
||||
parameters[parameter_name].append(
|
||||
PluginFileEntity(
|
||||
url=url,
|
||||
mime_type=p.mime_type,
|
||||
type=p.type,
|
||||
filename=p.filename,
|
||||
extension=p.extension,
|
||||
size=p.size,
|
||||
).model_dump()
|
||||
)
|
||||
return parameters
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
@ -71,8 +36,7 @@ class PluginTool(Tool):
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
manager = PluginToolManager()
|
||||
|
||||
# convert tool parameters with File type to PluginFileEntity
|
||||
tool_parameters = self._transform_image_parameters(tool_parameters)
|
||||
tool_parameters = convert_parameters_to_plugin_format(tool_parameters)
|
||||
|
||||
yield from manager.invoke(
|
||||
tenant_id=self.tenant_id,
|
||||
|
@ -285,28 +285,6 @@ class ToolManager:
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
|
||||
|
||||
@classmethod
|
||||
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict):
|
||||
"""
|
||||
init runtime parameter
|
||||
"""
|
||||
parameter_value = parameters.get(parameter_rule.name)
|
||||
if not parameter_value and parameter_value != 0:
|
||||
# get default value
|
||||
parameter_value = parameter_rule.default
|
||||
if not parameter_value and parameter_rule.required:
|
||||
raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
|
||||
|
||||
if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
|
||||
# check if tool_parameter_config in options
|
||||
options = [x.value for x in parameter_rule.options]
|
||||
if parameter_value is not None and parameter_value not in options:
|
||||
raise ValueError(
|
||||
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
|
||||
)
|
||||
|
||||
return parameter_rule.type.cast_value(parameter_value)
|
||||
|
||||
@classmethod
|
||||
def get_agent_tool_runtime(
|
||||
cls,
|
||||
@ -343,7 +321,7 @@ class ToolManager:
|
||||
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# save tool parameter to tool entity memory
|
||||
value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
|
||||
value = parameter.init_frontend_parameter(agent_tool.tool_parameters.get(parameter.name))
|
||||
runtime_parameters[parameter.name] = value
|
||||
|
||||
# decrypt runtime parameters
|
||||
@ -356,9 +334,6 @@ class ToolManager:
|
||||
)
|
||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||
|
||||
if not tool_entity.runtime:
|
||||
raise Exception("tool missing runtime")
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
|
||||
@ -388,7 +363,7 @@ class ToolManager:
|
||||
for parameter in parameters:
|
||||
# save tool parameter to tool entity memory
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
|
||||
value = parameter.init_frontend_parameter(workflow_tool.tool_configurations.get(parameter.name))
|
||||
runtime_parameters[parameter.name] = value
|
||||
|
||||
# decrypt runtime parameters
|
||||
@ -403,9 +378,6 @@ class ToolManager:
|
||||
if runtime_parameters:
|
||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||
|
||||
if not tool_runtime.runtime:
|
||||
raise Exception("tool missing runtime")
|
||||
|
||||
tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_runtime
|
||||
|
||||
@ -434,12 +406,9 @@ class ToolManager:
|
||||
for parameter in parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# save tool parameter to tool entity memory
|
||||
value = cls._init_runtime_parameter(parameter, tool_parameters)
|
||||
value = parameter.init_frontend_parameter(tool_parameters.get(parameter.name))
|
||||
runtime_parameters[parameter.name] = value
|
||||
|
||||
if not tool_entity.runtime:
|
||||
raise Exception("tool missing runtime")
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
|
||||
@ -608,9 +577,8 @@ class ToolManager:
|
||||
tool_provider_id = GenericProviderID(db_provider.provider)
|
||||
db_provider.provider = tool_provider_id.to_string()
|
||||
|
||||
find_db_builtin_provider = lambda provider: next(
|
||||
(x for x in db_builtin_providers if x.provider == provider), None
|
||||
)
|
||||
def find_db_builtin_provider(provider):
|
||||
return next((x for x in db_builtin_providers if x.provider == provider), None)
|
||||
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
|
@ -5,6 +5,7 @@ from pydantic import Field
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.plugin.entities.parameters import PluginParameterOption
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@ -13,7 +14,6 @@ from core.tools.entities.tool_entities import (
|
||||
ToolEntity,
|
||||
ToolIdentity,
|
||||
ToolParameter,
|
||||
ToolParameterOption,
|
||||
ToolProviderEntity,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
@ -116,7 +116,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
|
||||
if variable.type == VariableEntityType.SELECT and variable.options:
|
||||
options = [
|
||||
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in variable.options
|
||||
]
|
||||
|
||||
|
@ -117,7 +117,7 @@ class AgentNode(ToolNode):
|
||||
continue
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
if agent_input.type == "variable":
|
||||
variable = variable_pool.get(agent_input.value)
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore
|
||||
if variable is None:
|
||||
raise ValueError(f"Variable {agent_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
|
@ -2,6 +2,7 @@ from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, ValidationInfo, field_validator
|
||||
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
|
||||
|
||||
@ -20,8 +21,21 @@ class AgentEntity(BaseModel):
|
||||
|
||||
for key in values.data.get("agent_configurations", {}):
|
||||
value = values.data.get("agent_configurations", {}).get(key)
|
||||
if not isinstance(value, str | int | float | bool):
|
||||
raise ValueError(f"{key} must be a string")
|
||||
if isinstance(value, dict):
|
||||
# convert dict to ToolSelector
|
||||
return ToolSelector(**value)
|
||||
elif isinstance(value, ToolSelector):
|
||||
return value
|
||||
elif isinstance(value, list):
|
||||
# convert list[ToolSelector] to ToolSelector
|
||||
if all(isinstance(val, dict) for val in value):
|
||||
return [ToolSelector(**val) for val in value]
|
||||
elif all(isinstance(val, ToolSelector) for val in value):
|
||||
return value
|
||||
else:
|
||||
raise ValueError("value must be a list of ToolSelector")
|
||||
else:
|
||||
raise ValueError("value must be a dictionary or ToolSelector")
|
||||
|
||||
return value
|
||||
|
||||
@ -29,7 +43,7 @@ class AgentEntity(BaseModel):
|
||||
class AgentNodeData(BaseNodeData, AgentEntity):
|
||||
class AgentInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
value: Union[list[str], list[ToolSelector], Any]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@ -45,8 +59,23 @@ class AgentNodeData(BaseNodeData, AgentEntity):
|
||||
for val in value:
|
||||
if not isinstance(val, str):
|
||||
raise ValueError("value must be a list of strings")
|
||||
elif typ == "constant" and not isinstance(value, str | int | float | bool):
|
||||
raise ValueError("value must be a string, int, float, or bool")
|
||||
elif typ == "constant":
|
||||
if isinstance(value, list):
|
||||
# convert dict to ToolSelector
|
||||
if all(isinstance(val, dict) for val in value):
|
||||
return value
|
||||
elif all(isinstance(val, ToolSelector) for val in value):
|
||||
return value
|
||||
else:
|
||||
raise ValueError("value must be a list of ToolSelector")
|
||||
elif isinstance(value, dict):
|
||||
# convert dict to ToolSelector
|
||||
return ToolSelector(**value)
|
||||
elif isinstance(value, ToolSelector):
|
||||
return value
|
||||
else:
|
||||
raise ValueError("value must be a list of ToolSelector")
|
||||
|
||||
return typ
|
||||
|
||||
agent_parameters: dict[str, AgentInput]
|
||||
|
Loading…
x
Reference in New Issue
Block a user