diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index f3c041271c..ca541019d1 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -39,6 +39,7 @@ from core.tools.entities.tool_entities import ( from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tools.tool.tool import Tool from core.tools.tool_manager import ToolManager +from core.tools.utils.tool_parameter_converter import ToolParameterConverter from extensions.ext_database import db from models.model import Conversation, Message, MessageAgentThought from models.tools import ToolConversationVariables @@ -186,21 +187,11 @@ class BaseAgentRunner(AppRunner): if parameter.form != ToolParameter.ToolParameterForm.LLM: continue - parameter_type = 'string' + parameter_type = ToolParameterConverter.get_parameter_type(parameter.type) enum = [] - if parameter.type == ToolParameter.ToolParameterType.STRING: - parameter_type = 'string' - elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: - parameter_type = 'boolean' - elif parameter.type == ToolParameter.ToolParameterType.NUMBER: - parameter_type = 'number' - elif parameter.type == ToolParameter.ToolParameterType.SELECT: - for option in parameter.options: - enum.append(option.value) - parameter_type = 'string' - else: - raise ValueError(f"parameter type {parameter.type} is not supported") - + if parameter.type == ToolParameter.ToolParameterType.SELECT: + enum = [option.value for option in parameter.options] + message_tool.parameters['properties'][parameter.name] = { "type": parameter_type, "description": parameter.llm_description or '', @@ -281,20 +272,10 @@ class BaseAgentRunner(AppRunner): if parameter.form != ToolParameter.ToolParameterForm.LLM: continue - parameter_type = 'string' + parameter_type = ToolParameterConverter.get_parameter_type(parameter.type) enum = [] - if parameter.type == ToolParameter.ToolParameterType.STRING: - parameter_type = 'string' - elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: - parameter_type = 'boolean' - elif parameter.type == ToolParameter.ToolParameterType.NUMBER: - parameter_type = 'number' - elif parameter.type == ToolParameter.ToolParameterType.SELECT: - for option in parameter.options: - enum.append(option.value) - parameter_type = 'string' - else: - raise ValueError(f"parameter type {parameter.type} is not supported") + if parameter.type == ToolParameter.ToolParameterType.SELECT: + enum = [option.value for option in parameter.options] prompt_tool.parameters['properties'][parameter.name] = { "type": parameter_type, diff --git a/api/core/tools/__init__.py b/api/core/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index b7917fca7f..55ef8e8291 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -116,8 +116,9 @@ class ToolParameterOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") + class ToolParameter(BaseModel): - class ToolParameterType(Enum): + class ToolParameterType(str, Enum): STRING = "string" NUMBER = "number" BOOLEAN = "boolean" diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index fdcedec260..a7aa62b1ba 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -12,6 +12,7 @@ from core.tools.errors import ( from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.tool import Tool +from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.tools.utils.yaml_utils import load_yaml_file from core.utils.module_import_helper import load_single_subclass_from_source @@ -200,16 +201,8 @@ class BuiltinToolProviderController(ToolProviderController): # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: - default_value = parameter_schema.default - # parse default value into the correct type - if parameter_schema.type == ToolParameter.ToolParameterType.STRING or \ - parameter_schema.type == ToolParameter.ToolParameterType.SELECT: - default_value = str(default_value) - elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: - default_value = float(default_value) - elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: - default_value = bool(default_value) - + default_value = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default, + parameter_schema.type) tool_parameters[parameter] = default_value def validate_credentials(self, credentials: dict[str, Any]) -> None: diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index 7c064689d0..ef1ace9c7c 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -11,6 +11,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError from core.tools.tool.tool import Tool +from core.tools.utils.tool_parameter_converter import ToolParameterConverter class ToolProviderController(BaseModel, ABC): @@ -122,17 +123,8 @@ class ToolProviderController(BaseModel, ABC): # the parameter is not set currently, set the default value if needed if parameter_schema.default is not None: - default_value = parameter_schema.default - # parse default value into the correct type - if parameter_schema.type == ToolParameter.ToolParameterType.STRING or \ - parameter_schema.type == ToolParameter.ToolParameterType.SELECT: - default_value = str(default_value) - elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER: - default_value = float(default_value) - elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN: - default_value = bool(default_value) - - tool_parameters[parameter] = default_value + tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default, + parameter_schema.type) def validate_credentials_format(self, credentials: dict[str, Any]) -> None: """ diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 7f2adf68f7..290d80c7d1 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -18,6 +18,7 @@ from core.tools.entities.tool_entities import ( ToolRuntimeVariablePool, ) from core.tools.tool_file_manager import ToolFileManager +from core.tools.utils.tool_parameter_converter import ToolParameterConverter class Tool(BaseModel, ABC): @@ -228,46 +229,8 @@ class Tool(BaseModel, ABC): """ Transform tool parameters type """ - for parameter in self.parameters: - if parameter.name in tool_parameters: - if parameter.type in [ - ToolParameter.ToolParameterType.SECRET_INPUT, - ToolParameter.ToolParameterType.STRING, - ToolParameter.ToolParameterType.SELECT, - ] and not isinstance(tool_parameters[parameter.name], str): - if tool_parameters[parameter.name] is None: - tool_parameters[parameter.name] = '' - else: - tool_parameters[parameter.name] = str(tool_parameters[parameter.name]) - elif parameter.type == ToolParameter.ToolParameterType.NUMBER \ - and not isinstance(tool_parameters[parameter.name], int | float): - if isinstance(tool_parameters[parameter.name], str): - try: - tool_parameters[parameter.name] = int(tool_parameters[parameter.name]) - except ValueError: - tool_parameters[parameter.name] = float(tool_parameters[parameter.name]) - elif isinstance(tool_parameters[parameter.name], bool): - tool_parameters[parameter.name] = int(tool_parameters[parameter.name]) - elif tool_parameters[parameter.name] is None: - tool_parameters[parameter.name] = 0 - elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: - if not isinstance(tool_parameters[parameter.name], bool): - # check if it is a string - if isinstance(tool_parameters[parameter.name], str): - # check true false - if tool_parameters[parameter.name].lower() in ['true', 'false']: - tool_parameters[parameter.name] = tool_parameters[parameter.name].lower() == 'true' - # check 1 0 - elif tool_parameters[parameter.name] in ['1', '0']: - tool_parameters[parameter.name] = tool_parameters[parameter.name] == '1' - else: - tool_parameters[parameter.name] = bool(tool_parameters[parameter.name]) - elif isinstance(tool_parameters[parameter.name], int | float): - tool_parameters[parameter.name] = tool_parameters[parameter.name] != 0 - else: - tool_parameters[parameter.name] = bool(tool_parameters[parameter.name]) - - return tool_parameters + return {p.name: ToolParameterConverter.cast_parameter_by_type(tool_parameters[p.name], p.type) + for p in self.parameters if p.name in tool_parameters} @abstractmethod def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 5ae5b6622d..9def1f4740 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -11,7 +11,6 @@ from flask import current_app from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools import * from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( @@ -31,6 +30,7 @@ from core.tools.utils.configuration import ( ToolConfigurationManager, ToolParameterConfigurationManager, ) +from core.tools.utils.tool_parameter_converter import ToolParameterConverter from core.utils.module_import_helper import load_single_subclass_from_source from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db @@ -214,30 +214,7 @@ class ToolManager: raise ValueError( f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}") - # convert tool parameter config to correct type - try: - if parameter_rule.type == ToolParameter.ToolParameterType.NUMBER: - # check if tool parameter is integer - if isinstance(parameter_value, int): - parameter_value = parameter_value - elif isinstance(parameter_value, float): - parameter_value = parameter_value - elif isinstance(parameter_value, str): - if '.' in parameter_value: - parameter_value = float(parameter_value) - else: - parameter_value = int(parameter_value) - elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN: - parameter_value = bool(parameter_value) - elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT, - ToolParameter.ToolParameterType.STRING]: - parameter_value = str(parameter_value) - elif parameter_rule.type == ToolParameter.ToolParameterType: - parameter_value = str(parameter_value) - except Exception as e: - raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} is not correct type") - - return parameter_value + return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type) @classmethod def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: diff --git a/api/core/tools/utils/__init__.py b/api/core/tools/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/tools/utils/tool_parameter_converter.py b/api/core/tools/utils/tool_parameter_converter.py new file mode 100644 index 0000000000..55535be930 --- /dev/null +++ b/api/core/tools/utils/tool_parameter_converter.py @@ -0,0 +1,66 @@ +from typing import Any + +from core.tools.entities.tool_entities import ToolParameter + + +class ToolParameterConverter: + @staticmethod + def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str: + match parameter_type: + case ToolParameter.ToolParameterType.STRING \ + | ToolParameter.ToolParameterType.SECRET_INPUT \ + | ToolParameter.ToolParameterType.SELECT: + return 'string' + + case ToolParameter.ToolParameterType.BOOLEAN: + return 'boolean' + + case ToolParameter.ToolParameterType.NUMBER: + return 'number' + + case _: + raise ValueError(f"Unsupported parameter type {parameter_type}") + + @staticmethod + def cast_parameter_by_type(value: Any, parameter_type: str) -> Any: + # convert tool parameter config to correct type + try: + match parameter_type: + case ToolParameter.ToolParameterType.STRING \ + | ToolParameter.ToolParameterType.SECRET_INPUT \ + | ToolParameter.ToolParameterType.SELECT: + if value is None: + return '' + else: + return value if isinstance(value, str) else str(value) + + case ToolParameter.ToolParameterType.BOOLEAN: + if value is None: + return False + elif isinstance(value, str): + # Allowed YAML boolean value strings: https://yaml.org/type/bool.html + # and also '0' for False and '1' for True + match value.lower(): + case 'true' | 'yes' | 'y' | '1': + return True + case 'false' | 'no' | 'n' | '0': + return False + case _: + return bool(value) + else: + return value if isinstance(value, bool) else bool(value) + + case ToolParameter.ToolParameterType.NUMBER: + if isinstance(value, int) | isinstance(value, float): + return value + elif isinstance(value, str): + if '.' in value: + return float(value) + else: + return int(value) + + case _: + return str(value) + + except Exception: + raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.") diff --git a/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py new file mode 100644 index 0000000000..9addeeadca --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_parameter_converter.py @@ -0,0 +1,56 @@ +import pytest + +from core.tools.entities.tool_entities import ToolParameter +from core.tools.utils.tool_parameter_converter import ToolParameterConverter + + +def test_get_parameter_type(): + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == 'string' + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == 'string' + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == 'boolean' + assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == 'number' + with pytest.raises(ValueError): + ToolParameterConverter.get_parameter_type('unsupported_type') + + +def test_cast_parameter_by_type(): + # string + assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.STRING) == 'test' + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == '1' + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == '1.0' + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == '' + + # secret input + assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SECRET_INPUT) == 'test' + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == '1' + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == '1.0' + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == '' + + # select + assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SELECT) == 'test' + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == '1' + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == '1.0' + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == '' + + # boolean + true_values = [True, 'True', 'true', '1', 'YES', 'Yes', 'yes', 'y', 'something'] + for value in true_values: + assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True + + false_values = [False, 'False', 'false', '0', 'NO', 'No', 'no', 'n', None, ''] + for value in false_values: + assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False + + # number + assert ToolParameterConverter.cast_parameter_by_type('1', ToolParameter.ToolParameterType.NUMBER) == 1 + assert ToolParameterConverter.cast_parameter_by_type('1.0', ToolParameter.ToolParameterType.NUMBER) == 1.0 + assert ToolParameterConverter.cast_parameter_by_type('-1.0', ToolParameter.ToolParameterType.NUMBER) == -1.0 + assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1 + assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0 + assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0 + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None + + # unknown + assert ToolParameterConverter.cast_parameter_by_type('1', 'unknown_type') == '1' + assert ToolParameterConverter.cast_parameter_by_type(1, 'unknown_type') == '1' + assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None