From 65a4cb769be2be2d0f8be660f7e67edc3152db0a Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Fri, 13 Dec 2024 19:50:54 +0800 Subject: [PATCH] refactor: tool entities --- api/core/agent/plugin_entities.py | 150 +++------------ api/core/agent/strategy/plugin.py | 14 +- api/core/entities/parameter_entities.py | 2 +- api/core/entities/provider_entities.py | 1 - api/core/file/models.py | 13 +- api/core/plugin/entities/parameters.py | 146 ++++++++++++++ api/core/plugin/utils/converter.py | 21 ++ .../builtin_tool/providers/audio/tools/asr.py | 5 +- .../builtin_tool/providers/audio/tools/tts.py | 7 +- api/core/tools/entities/constants.py | 1 + api/core/tools/entities/file_entities.py | 17 -- api/core/tools/entities/tool_entities.py | 181 ++++++------------ api/core/tools/plugin_tool/tool.py | 40 +--- api/core/tools/tool_manager.py | 42 +--- api/core/tools/workflow_as_tool/provider.py | 4 +- api/core/workflow/nodes/agent/agent_node.py | 2 +- api/core/workflow/nodes/agent/entities.py | 39 +++- 17 files changed, 329 insertions(+), 356 deletions(-) create mode 100644 api/core/plugin/entities/parameters.py create mode 100644 api/core/plugin/utils/converter.py create mode 100644 api/core/tools/entities/constants.py diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py index c115d6a7af..92bd5500ef 100644 --- a/api/core/agent/plugin_entities.py +++ b/api/core/agent/plugin_entities.py @@ -1,23 +1,36 @@ import enum -from typing import Any, Optional, Union +from typing import Any, Optional from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from core.entities.parameter_entities import CommonParameterType +from core.plugin.entities.parameters import ( + PluginParameter, + as_normal_type, + cast_parameter_value, + init_frontend_parameter, +) from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( ToolIdentity, - ToolParameterOption, ToolProviderIdentity, ) class AgentStrategyProviderIdentity(ToolProviderIdentity): + """ + Inherits from ToolProviderIdentity, without any additional fields. + """ + pass -class AgentStrategyParameter(BaseModel): +class AgentStrategyParameter(PluginParameter): class AgentStrategyParameterType(enum.StrEnum): + """ + Keep all the types from PluginParameterType + """ + STRING = CommonParameterType.STRING.value NUMBER = CommonParameterType.NUMBER.value BOOLEAN = CommonParameterType.BOOLEAN.value @@ -26,7 +39,6 @@ class AgentStrategyParameter(BaseModel): FILE = CommonParameterType.FILE.value FILES = CommonParameterType.FILES.value APP_SELECTOR = CommonParameterType.APP_SELECTOR.value - TOOL_SELECTOR = CommonParameterType.TOOL_SELECTOR.value MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value @@ -34,131 +46,15 @@ class AgentStrategyParameter(BaseModel): SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value def as_normal_type(self): - if self in { - AgentStrategyParameter.AgentStrategyParameterType.SECRET_INPUT, - AgentStrategyParameter.AgentStrategyParameterType.SELECT, - }: - return "string" - return self.value + return as_normal_type(self) - def cast_value(self, value: Any, /): - try: - match self: - case ( - AgentStrategyParameter.AgentStrategyParameterType.STRING - | AgentStrategyParameter.AgentStrategyParameterType.SECRET_INPUT - | AgentStrategyParameter.AgentStrategyParameterType.SELECT - ): - if value is None: - return "" - else: - return value if isinstance(value, str) else str(value) + def cast_value(self, value: Any): + return cast_parameter_value(self, value) - case AgentStrategyParameter.AgentStrategyParameterType.BOOLEAN: - if value is None: - return False - elif isinstance(value, str): - # Allowed YAML boolean value strings: https://yaml.org/type/bool.html - # and also '0' for False and '1' for True - match value.lower(): - case "true" | "yes" | "y" | "1": - return True - case "false" | "no" | "n" | "0": - return False - case _: - return bool(value) - else: - return value if isinstance(value, bool) else bool(value) - - case AgentStrategyParameter.AgentStrategyParameterType.NUMBER: - if isinstance(value, int | float): - return value - elif isinstance(value, str) and value: - if "." in value: - return float(value) - else: - return int(value) - case ( - AgentStrategyParameter.AgentStrategyParameterType.SYSTEM_FILES - | AgentStrategyParameter.AgentStrategyParameterType.FILES - ): - if not isinstance(value, list): - return [value] - return value - case AgentStrategyParameter.AgentStrategyParameterType.FILE: - if isinstance(value, list): - if len(value) != 1: - raise ValueError( - "This parameter only accepts one file but got multiple files while invoking." - ) - else: - return value[0] - return value - case ( - AgentStrategyParameter.AgentStrategyParameterType.TOOL_SELECTOR - | AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR - | AgentStrategyParameter.AgentStrategyParameterType.APP_SELECTOR - | AgentStrategyParameter.AgentStrategyParameterType.TOOLS_SELECTOR - ): - if not isinstance(value, dict): - raise ValueError("The selector must be a dictionary.") - return value - case _: - return str(value) - - except Exception: - raise ValueError(f"The tool parameter value {value} is not in correct type of {self.as_normal_type()}.") - - name: str = Field(..., description="The name of the parameter") - label: I18nObject = Field(..., description="The label presented to the user") - placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user") type: AgentStrategyParameterType = Field(..., description="The type of the parameter") - scope: str | None = None - required: Optional[bool] = False - default: Optional[Union[float, int, str]] = None - min: Optional[Union[float, int]] = None - max: Optional[Union[float, int]] = None - options: list[ToolParameterOption] = Field(default_factory=list) - @field_validator("options", mode="before") - @classmethod - def transform_options(cls, v): - if not isinstance(v, list): - return [] - return v - - @classmethod - def get_simple_instance( - cls, - name: str, - type: AgentStrategyParameterType, - required: bool, - options: Optional[list[str]] = None, - ): - """ - get a simple tool parameter - - :param name: the name of the parameter - :param llm_description: the description presented to the LLM - :param type: the type of the parameter - :param required: if the parameter is required - :param options: the options of the parameter - """ - # convert options to ToolParameterOption - if options: - option_objs = [ - ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options - ] - else: - option_objs = [] - return cls( - name=name, - label=I18nObject(en_US="", zh_Hans=""), - placeholder=None, - type=type, - required=required, - options=option_objs, - ) + def init_frontend_parameter(self, value: Any): + return init_frontend_parameter(self, self.type, value) class AgentStrategyProviderEntity(BaseModel): @@ -167,6 +63,10 @@ class AgentStrategyProviderEntity(BaseModel): class AgentStrategyIdentity(ToolIdentity): + """ + Inherits from ToolIdentity, without any additional fields. + """ + pass diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py index 7cbd4503a7..979096c154 100644 --- a/api/core/agent/strategy/plugin.py +++ b/api/core/agent/strategy/plugin.py @@ -4,7 +4,7 @@ from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter from core.agent.strategy.base import BaseAgentStrategy from core.plugin.manager.agent import PluginAgentManager -from core.tools.plugin_tool.tool import PluginTool +from core.plugin.utils.converter import convert_parameters_to_plugin_format class PluginAgentStrategy(BaseAgentStrategy): @@ -24,6 +24,14 @@ class PluginAgentStrategy(BaseAgentStrategy): def get_parameters(self) -> Sequence[AgentStrategyParameter]: return self.declaration.parameters + def initialize_parameters(self, params: dict[str, Any]) -> dict[str, Any]: + """ + Initialize the parameters for the agent strategy. + """ + for parameter in self.declaration.parameters: + params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name)) + return params + def _invoke( self, params: dict[str, Any], @@ -37,8 +45,8 @@ class PluginAgentStrategy(BaseAgentStrategy): """ manager = PluginAgentManager() - # convert agent parameters with File type to PluginFileEntity - params = PluginTool._transform_image_parameters(params) + initialized_params = self.initialize_parameters(params) + params = convert_parameters_to_plugin_format(initialized_params) yield from manager.invoke( tenant_id=self.tenant_id, diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py index 7aec38389c..03a6c270c6 100644 --- a/api/core/entities/parameter_entities.py +++ b/api/core/entities/parameter_entities.py @@ -12,7 +12,7 @@ class CommonParameterType(Enum): SYSTEM_FILES = "system-files" BOOLEAN = "boolean" APP_SELECTOR = "app-selector" - TOOL_SELECTOR = "tool-selector" + # TOOL_SELECTOR = "tool-selector" MODEL_SELECTOR = "model-selector" TOOLS_SELECTOR = "array[tools]" diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index b5de05d2fb..e04e2a42fd 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -146,7 +146,6 @@ class BasicProviderConfig(BaseModel): BOOLEAN = CommonParameterType.BOOLEAN.value APP_SELECTOR = CommonParameterType.APP_SELECTOR.value MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value - TOOL_SELECTOR = CommonParameterType.TOOL_SELECTOR.value @classmethod def value_of(cls, value: str) -> "ProviderConfig.Type": diff --git a/api/core/file/models.py b/api/core/file/models.py index 3e7e189c62..85eb4a4823 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel, Field, model_validator @@ -92,6 +92,17 @@ class File(BaseModel): tool_file_id=self.related_id, extension=self.extension ) + def to_plugin_parameter(self) -> dict[str, Any]: + return { + "dify_model_identity": FILE_MODEL_IDENTITY, + "mime_type": self.mime_type, + "filename": self.filename, + "extension": self.extension, + "size": self.size, + "type": self.type, + "url": self.generate_url(), + } + @model_validator(mode="after") def validate_after(self): match self.transfer_method: diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py new file mode 100644 index 0000000000..868e2fc1e4 --- /dev/null +++ b/api/core/plugin/entities/parameters.py @@ -0,0 +1,146 @@ +import enum +from typing import Any, Optional, Union +from pydantic import BaseModel, Field, field_validator + +from core.entities.parameter_entities import CommonParameterType +from core.tools.entities.common_entities import I18nObject + + +class PluginParameterOption(BaseModel): + value: str = Field(..., description="The value of the option") + label: I18nObject = Field(..., description="The label of the option") + + @field_validator("value", mode="before") + @classmethod + def transform_id_to_str(cls, value) -> str: + if not isinstance(value, str): + return str(value) + else: + return value + + +class PluginParameterType(enum.StrEnum): + """ + all available parameter types + """ + + STRING = CommonParameterType.STRING.value + NUMBER = CommonParameterType.NUMBER.value + BOOLEAN = CommonParameterType.BOOLEAN.value + SELECT = CommonParameterType.SELECT.value + SECRET_INPUT = CommonParameterType.SECRET_INPUT.value + FILE = CommonParameterType.FILE.value + FILES = CommonParameterType.FILES.value + APP_SELECTOR = CommonParameterType.APP_SELECTOR.value + MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value + TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value + + # deprecated, should not use. + SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value + + +class PluginParameter(BaseModel): + name: str = Field(..., description="The name of the parameter") + label: I18nObject = Field(..., description="The label presented to the user") + placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user") + scope: str | None = None + required: bool = False + default: Optional[Union[float, int, str]] = None + min: Optional[Union[float, int]] = None + max: Optional[Union[float, int]] = None + options: list[PluginParameterOption] = Field(default_factory=list) + + @field_validator("options", mode="before") + @classmethod + def transform_options(cls, v): + if not isinstance(v, list): + return [] + return v + + +def as_normal_type(typ: enum.Enum): + if typ.value in { + PluginParameterType.SECRET_INPUT, + PluginParameterType.SELECT, + }: + return "string" + return typ.value + + +def cast_parameter_value(typ: enum.Enum, value: Any, /): + try: + match typ.value: + case PluginParameterType.STRING | PluginParameterType.SECRET_INPUT | PluginParameterType.SELECT: + if value is None: + return "" + else: + return value if isinstance(value, str) else str(value) + + case PluginParameterType.BOOLEAN: + if value is None: + return False + elif isinstance(value, str): + # Allowed YAML boolean value strings: https://yaml.org/type/bool.html + # and also '0' for False and '1' for True + match value.lower(): + case "true" | "yes" | "y" | "1": + return True + case "false" | "no" | "n" | "0": + return False + case _: + return bool(value) + else: + return value if isinstance(value, bool) else bool(value) + + case PluginParameterType.NUMBER: + if isinstance(value, int | float): + return value + elif isinstance(value, str) and value: + if "." in value: + return float(value) + else: + return int(value) + case PluginParameterType.SYSTEM_FILES | PluginParameterType.FILES: + if not isinstance(value, list): + return [value] + return value + case PluginParameterType.FILE: + if isinstance(value, list): + if len(value) != 1: + raise ValueError("This parameter only accepts one file but got multiple files while invoking.") + else: + return value[0] + return value + case PluginParameterType.MODEL_SELECTOR | PluginParameterType.APP_SELECTOR: + if not isinstance(value, dict): + raise ValueError("The selector must be a dictionary.") + return value + case PluginParameterType.TOOLS_SELECTOR: + if not isinstance(value, list): + raise ValueError("The tools selector must be a list.") + return value + case _: + return str(value) + + except Exception: + raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.") + + +def init_frontend_parameter(rule: PluginParameter, type: enum.Enum, value: Any): + """ + init frontend parameter by rule + """ + parameter_value = value + if not parameter_value and parameter_value != 0: + # get default value + parameter_value = rule.default + if not parameter_value and rule.required: + raise ValueError(f"tool parameter {rule.name} not found in tool config") + + if type == PluginParameterType.SELECT: + # check if tool_parameter_config in options + options = [x.value for x in rule.options] + if parameter_value is not None and parameter_value not in options: + raise ValueError(f"tool parameter {rule.name} value {parameter_value} not in options {options}") + + return cast_parameter_value(type, parameter_value) diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py new file mode 100644 index 0000000000..6876285b31 --- /dev/null +++ b/api/core/plugin/utils/converter.py @@ -0,0 +1,21 @@ +from typing import Any + +from core.file.models import File +from core.tools.entities.tool_entities import ToolSelector + + +def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]: + for parameter_name, parameter in parameters.items(): + if isinstance(parameter, File): + parameters[parameter_name] = parameter.to_plugin_parameter() + elif isinstance(parameter, list) and all(isinstance(p, File) for p in parameter): + parameters[parameter_name] = [] + for p in parameter: + parameters[parameter_name].append(p.to_plugin_parameter()) + elif isinstance(parameter, ToolSelector): + parameters[parameter_name] = parameter.to_plugin_parameter() + elif isinstance(parameter, list) and all(isinstance(p, ToolSelector) for p in parameter): + parameters[parameter_name] = [] + for p in parameter: + parameters[parameter_name].append(p.to_plugin_parameter()) + return parameters diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index 6af0430d01..f517ac63f9 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -6,9 +6,10 @@ from core.file.enums import FileType from core.file.file_manager import download from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType +from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from services.model_provider_service import ModelProviderService @@ -51,7 +52,7 @@ class ASRTool(BuiltinTool): options = [] for provider, model in self.get_available_models(): - option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})")) + option = PluginParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})")) options.append(option) parameters.append( diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index 9d083b35b3..3801dcab6a 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -4,9 +4,10 @@ from typing import Any from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from services.model_provider_service import ModelProviderService @@ -54,7 +55,7 @@ class TTSTool(BuiltinTool): options = [] for provider, model, voices in self.get_available_models(): - option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})")) + option = PluginParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})")) options.append(option) parameters.append( ToolParameter( @@ -63,7 +64,7 @@ class TTSTool(BuiltinTool): type=ToolParameter.ToolParameterType.SELECT, form=ToolParameter.ToolParameterForm.FORM, options=[ - ToolParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name"))) + PluginParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name"))) for voice in voices ], ) diff --git a/api/core/tools/entities/constants.py b/api/core/tools/entities/constants.py new file mode 100644 index 0000000000..199c9f0d53 --- /dev/null +++ b/api/core/tools/entities/constants.py @@ -0,0 +1 @@ +TOOL_SELECTOR_MODEL_IDENTITY = "__dify__tool_selector__" diff --git a/api/core/tools/entities/file_entities.py b/api/core/tools/entities/file_entities.py index 9a38eadda5..8b13789179 100644 --- a/api/core/tools/entities/file_entities.py +++ b/api/core/tools/entities/file_entities.py @@ -1,18 +1 @@ -from pydantic import BaseModel -from core.file.constants import FILE_MODEL_IDENTITY -from core.file.enums import FileType - - -class PluginFileEntity(BaseModel): - """ - File entity for plugin tool. - """ - - dify_model_identity: str = FILE_MODEL_IDENTITY - mime_type: str | None - filename: str | None - extension: str | None - size: int | None - type: FileType - url: str diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 5be5643c9a..d53c70c4c9 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -5,14 +5,17 @@ from typing import Any, Mapping, Optional, Union from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator -from core.entities.parameter_entities import ( - AppSelectorScope, - CommonParameterType, - ModelSelectorScope, - ToolSelectorScope, -) from core.entities.provider_entities import ProviderConfig +from core.plugin.entities.parameters import ( + PluginParameter, + PluginParameterOption, + PluginParameterType, + as_normal_type, + cast_parameter_value, + init_frontend_parameter, +) from core.tools.entities.common_entities import I18nObject +from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY class ToolLabelEnum(Enum): @@ -204,139 +207,51 @@ class ToolInvokeMessageBinary(BaseModel): file_var: Optional[dict[str, Any]] = None -class ToolParameterOption(BaseModel): - value: str = Field(..., description="The value of the option") - label: I18nObject = Field(..., description="The label of the option") +class ToolParameter(PluginParameter): + """ + Overrides type + """ - @field_validator("value", mode="before") - @classmethod - def transform_id_to_str(cls, value) -> str: - if not isinstance(value, str): - return str(value) - else: - return value - - -class ToolParameter(BaseModel): class ToolParameterType(enum.StrEnum): - STRING = CommonParameterType.STRING.value - NUMBER = CommonParameterType.NUMBER.value - BOOLEAN = CommonParameterType.BOOLEAN.value - SELECT = CommonParameterType.SELECT.value - SECRET_INPUT = CommonParameterType.SECRET_INPUT.value - FILE = CommonParameterType.FILE.value - FILES = CommonParameterType.FILES.value - APP_SELECTOR = CommonParameterType.APP_SELECTOR.value - TOOL_SELECTOR = CommonParameterType.TOOL_SELECTOR.value - MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value + """ + removes TOOLS_SELECTOR from PluginParameterType + """ + + STRING = PluginParameterType.STRING.value + NUMBER = PluginParameterType.NUMBER.value + BOOLEAN = PluginParameterType.BOOLEAN.value + SELECT = PluginParameterType.SELECT.value + SECRET_INPUT = PluginParameterType.SECRET_INPUT.value + FILE = PluginParameterType.FILE.value + FILES = PluginParameterType.FILES.value + APP_SELECTOR = PluginParameterType.APP_SELECTOR.value + MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value # deprecated, should not use. - SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value + SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value def as_normal_type(self): - if self in { - ToolParameter.ToolParameterType.SECRET_INPUT, - ToolParameter.ToolParameterType.SELECT, - }: - return "string" - return self.value + return as_normal_type(self) - def cast_value(self, value: Any, /): - try: - match self: - case ( - ToolParameter.ToolParameterType.STRING - | ToolParameter.ToolParameterType.SECRET_INPUT - | ToolParameter.ToolParameterType.SELECT - ): - if value is None: - return "" - else: - return value if isinstance(value, str) else str(value) - - case ToolParameter.ToolParameterType.BOOLEAN: - if value is None: - return False - elif isinstance(value, str): - # Allowed YAML boolean value strings: https://yaml.org/type/bool.html - # and also '0' for False and '1' for True - match value.lower(): - case "true" | "yes" | "y" | "1": - return True - case "false" | "no" | "n" | "0": - return False - case _: - return bool(value) - else: - return value if isinstance(value, bool) else bool(value) - - case ToolParameter.ToolParameterType.NUMBER: - if isinstance(value, int | float): - return value - elif isinstance(value, str) and value: - if "." in value: - return float(value) - else: - return int(value) - case ToolParameter.ToolParameterType.SYSTEM_FILES | ToolParameter.ToolParameterType.FILES: - if not isinstance(value, list): - return [value] - return value - case ToolParameter.ToolParameterType.FILE: - if isinstance(value, list): - if len(value) != 1: - raise ValueError( - "This parameter only accepts one file but got multiple files while invoking." - ) - else: - return value[0] - return value - case ( - ToolParameter.ToolParameterType.TOOL_SELECTOR - | ToolParameter.ToolParameterType.MODEL_SELECTOR - | ToolParameter.ToolParameterType.APP_SELECTOR - ): - if not isinstance(value, dict): - raise ValueError("The selector must be a dictionary.") - return value - case _: - return str(value) - - except Exception: - raise ValueError(f"The tool parameter value {value} is not in correct type of {self.as_normal_type()}.") + def cast_value(self, value: Any): + return cast_parameter_value(self, value) class ToolParameterForm(Enum): SCHEMA = "schema" # should be set while adding tool FORM = "form" # should be set before invoking tool LLM = "llm" # will be set by LLM - name: str = Field(..., description="The name of the parameter") - label: I18nObject = Field(..., description="The label presented to the user") - human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") - placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user") type: ToolParameterType = Field(..., description="The type of the parameter") - scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None + human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") llm_description: Optional[str] = None - required: Optional[bool] = False - default: Optional[Union[float, int, str]] = None - min: Optional[Union[float, int]] = None - max: Optional[Union[float, int]] = None - options: list[ToolParameterOption] = Field(default_factory=list) - - @field_validator("options", mode="before") - @classmethod - def transform_options(cls, v): - if not isinstance(v, list): - return [] - return v @classmethod def get_simple_instance( cls, name: str, llm_description: str, - type: ToolParameterType, + typ: ToolParameterType, required: bool, options: Optional[list[str]] = None, ) -> "ToolParameter": @@ -352,22 +267,27 @@ class ToolParameter(BaseModel): # convert options to ToolParameterOption if options: option_objs = [ - ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options + PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) + for option in options ] else: option_objs = [] + return cls( name=name, label=I18nObject(en_US="", zh_Hans=""), placeholder=None, human_description=I18nObject(en_US="", zh_Hans=""), - type=type, + type=typ, form=cls.ToolParameterForm.LLM, llm_description=llm_description, required=required, options=option_objs, ) + def init_frontend_parameter(self, value: Any): + return init_frontend_parameter(self, self.type, value) + class ToolProviderIdentity(BaseModel): author: str = Field(..., description="The author of the tool") @@ -412,7 +332,7 @@ class ToolEntity(BaseModel): class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity - plugin_id: Optional[str] = Field(None, description="The id of the plugin") + plugin_id: Optional[str] = None credentials_schema: list[ProviderConfig] = Field(default_factory=list) @@ -479,3 +399,24 @@ class ToolInvokeFrom(Enum): WORKFLOW = "workflow" AGENT = "agent" PLUGIN = "plugin" + + +class ToolSelector(BaseModel): + dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY + + class Parameter(BaseModel): + name: str = Field(..., description="The name of the parameter") + type: ToolParameter.ToolParameterType = Field(..., description="The type of the parameter") + required: bool = Field(..., description="Whether the parameter is required") + description: str = Field(..., description="The description of the parameter") + default: Optional[Union[int, float, str]] = None + options: Optional[list[PluginParameterOption]] = None + + provider_id: str = Field(..., description="The id of the provider") + tool_name: str = Field(..., description="The name of the tool") + tool_description: str = Field(..., description="The description of the tool") + tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form") + tool_parameters: Mapping[str, Parameter] = Field(..., description="Parameters, type llm") + + def to_plugin_parameter(self) -> dict[str, Any]: + return self.model_dump() diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index 9ac2f60ab2..8c6dd8894b 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -2,11 +2,10 @@ from collections.abc import Generator from typing import Any, Optional from core.plugin.manager.tool import PluginToolManager +from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime -from core.tools.entities.file_entities import PluginFileEntity from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType -from models.model import File class PluginTool(Tool): @@ -27,40 +26,6 @@ class PluginTool(Tool): def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.PLUGIN - @classmethod - def _transform_image_parameters(cls, parameters: dict[str, Any]) -> dict[str, Any]: - for parameter_name, parameter in parameters.items(): - if isinstance(parameter, File): - url = parameter.generate_url() - if url is None: - raise ValueError(f"File {parameter.id} does not have a valid URL") - parameters[parameter_name] = PluginFileEntity( - url=url, - mime_type=parameter.mime_type, - type=parameter.type, - filename=parameter.filename, - extension=parameter.extension, - size=parameter.size, - ).model_dump() - elif isinstance(parameter, list) and all(isinstance(p, File) for p in parameter): - parameters[parameter_name] = [] - for p in parameter: - assert isinstance(p, File) - url = p.generate_url() - if url is None: - raise ValueError(f"File {p.id} does not have a valid URL") - parameters[parameter_name].append( - PluginFileEntity( - url=url, - mime_type=p.mime_type, - type=p.type, - filename=p.filename, - extension=p.extension, - size=p.size, - ).model_dump() - ) - return parameters - def _invoke( self, user_id: str, @@ -71,8 +36,7 @@ class PluginTool(Tool): ) -> Generator[ToolInvokeMessage, None, None]: manager = PluginToolManager() - # convert tool parameters with File type to PluginFileEntity - tool_parameters = self._transform_image_parameters(tool_parameters) + tool_parameters = convert_parameters_to_plugin_format(tool_parameters) yield from manager.invoke( tenant_id=self.tenant_id, diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 032c9f2454..7e67c06873 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -285,28 +285,6 @@ class ToolManager: else: raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") - @classmethod - def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict): - """ - init runtime parameter - """ - parameter_value = parameters.get(parameter_rule.name) - if not parameter_value and parameter_value != 0: - # get default value - parameter_value = parameter_rule.default - if not parameter_value and parameter_rule.required: - raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config") - - if parameter_rule.type == ToolParameter.ToolParameterType.SELECT: - # check if tool_parameter_config in options - options = [x.value for x in parameter_rule.options] - if parameter_value is not None and parameter_value not in options: - raise ValueError( - f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}" - ) - - return parameter_rule.type.cast_value(parameter_value) - @classmethod def get_agent_tool_runtime( cls, @@ -343,7 +321,7 @@ class ToolManager: if parameter.form == ToolParameter.ToolParameterForm.FORM: # save tool parameter to tool entity memory - value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters) + value = parameter.init_frontend_parameter(agent_tool.tool_parameters.get(parameter.name)) runtime_parameters[parameter.name] = value # decrypt runtime parameters @@ -356,9 +334,6 @@ class ToolManager: ) runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) - if not tool_entity.runtime: - raise Exception("tool missing runtime") - tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @@ -388,7 +363,7 @@ class ToolManager: for parameter in parameters: # save tool parameter to tool entity memory if parameter.form == ToolParameter.ToolParameterForm.FORM: - value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations) + value = parameter.init_frontend_parameter(workflow_tool.tool_configurations.get(parameter.name)) runtime_parameters[parameter.name] = value # decrypt runtime parameters @@ -403,9 +378,6 @@ class ToolManager: if runtime_parameters: runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) - if not tool_runtime.runtime: - raise Exception("tool missing runtime") - tool_runtime.runtime.runtime_parameters.update(runtime_parameters) return tool_runtime @@ -434,12 +406,9 @@ class ToolManager: for parameter in parameters: if parameter.form == ToolParameter.ToolParameterForm.FORM: # save tool parameter to tool entity memory - value = cls._init_runtime_parameter(parameter, tool_parameters) + value = parameter.init_frontend_parameter(tool_parameters.get(parameter.name)) runtime_parameters[parameter.name] = value - if not tool_entity.runtime: - raise Exception("tool missing runtime") - tool_entity.runtime.runtime_parameters.update(runtime_parameters) return tool_entity @@ -608,9 +577,8 @@ class ToolManager: tool_provider_id = GenericProviderID(db_provider.provider) db_provider.provider = tool_provider_id.to_string() - find_db_builtin_provider = lambda provider: next( - (x for x in db_builtin_providers if x.provider == provider), None - ) + def find_db_builtin_provider(provider): + return next((x for x in db_builtin_providers if x.provider == provider), None) # append builtin providers for provider in builtin_providers: diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index e80e286e12..c40ea0a0b0 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -5,6 +5,7 @@ from pydantic import Field from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from core.plugin.entities.parameters import PluginParameterOption from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject @@ -13,7 +14,6 @@ from core.tools.entities.tool_entities import ( ToolEntity, ToolIdentity, ToolParameter, - ToolParameterOption, ToolProviderEntity, ToolProviderIdentity, ToolProviderType, @@ -116,7 +116,7 @@ class WorkflowToolProviderController(ToolProviderController): if variable.type == VariableEntityType.SELECT and variable.options: options = [ - ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) + PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in variable.options ] diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 597a5d905e..5347df99b9 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -117,7 +117,7 @@ class AgentNode(ToolNode): continue agent_input = node_data.agent_parameters[parameter_name] if agent_input.type == "variable": - variable = variable_pool.get(agent_input.value) + variable = variable_pool.get(agent_input.value) # type: ignore if variable is None: raise ValueError(f"Variable {agent_input.value} does not exist") parameter_value = variable.value diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index ba7bba8252..66fb8773f9 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -2,6 +2,7 @@ from typing import Any, Literal, Union from pydantic import BaseModel, ValidationInfo, field_validator +from core.tools.entities.tool_entities import ToolSelector from core.workflow.nodes.base.entities import BaseNodeData @@ -20,8 +21,21 @@ class AgentEntity(BaseModel): for key in values.data.get("agent_configurations", {}): value = values.data.get("agent_configurations", {}).get(key) - if not isinstance(value, str | int | float | bool): - raise ValueError(f"{key} must be a string") + if isinstance(value, dict): + # convert dict to ToolSelector + return ToolSelector(**value) + elif isinstance(value, ToolSelector): + return value + elif isinstance(value, list): + # convert list[ToolSelector] to ToolSelector + if all(isinstance(val, dict) for val in value): + return [ToolSelector(**val) for val in value] + elif all(isinstance(val, ToolSelector) for val in value): + return value + else: + raise ValueError("value must be a list of ToolSelector") + else: + raise ValueError("value must be a dictionary or ToolSelector") return value @@ -29,7 +43,7 @@ class AgentEntity(BaseModel): class AgentNodeData(BaseNodeData, AgentEntity): class AgentInput(BaseModel): # TODO: check this type - value: Union[Any, list[str]] + value: Union[list[str], list[ToolSelector], Any] type: Literal["mixed", "variable", "constant"] @field_validator("type", mode="before") @@ -45,8 +59,23 @@ class AgentNodeData(BaseNodeData, AgentEntity): for val in value: if not isinstance(val, str): raise ValueError("value must be a list of strings") - elif typ == "constant" and not isinstance(value, str | int | float | bool): - raise ValueError("value must be a string, int, float, or bool") + elif typ == "constant": + if isinstance(value, list): + # convert dict to ToolSelector + if all(isinstance(val, dict) for val in value): + return value + elif all(isinstance(val, ToolSelector) for val in value): + return value + else: + raise ValueError("value must be a list of ToolSelector") + elif isinstance(value, dict): + # convert dict to ToolSelector + return ToolSelector(**value) + elif isinstance(value, ToolSelector): + return value + else: + raise ValueError("value must be a list of ToolSelector") + return typ agent_parameters: dict[str, AgentInput]