From dedc1b0c3a13750b6b7335ebf75cc476c2c525ff Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 12 Dec 2024 19:16:06 +0800 Subject: [PATCH] refactor: agent strategy parameter --- api/core/agent/base_agent_runner.py | 10 +- api/core/agent/plugin_entities.py | 155 +++++++++++++++++- api/core/entities/parameter_entities.py | 1 + api/core/tools/tool_engine.py | 9 +- api/core/tools/tool_manager.py | 4 +- api/core/tools/workflow_as_tool/tool.py | 2 +- api/services/tools/tools_transform_service.py | 2 +- 7 files changed, 168 insertions(+), 15 deletions(-) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 4d7ae33280..6d03b09c87 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -340,7 +340,7 @@ class BaseAgentRunner(AppRunner): if isinstance(tool_input, dict): try: tool_input = json.dumps(tool_input, ensure_ascii=False) - except Exception as e: + except Exception: tool_input = json.dumps(tool_input) updated_agent_thought.tool_input = tool_input @@ -349,7 +349,7 @@ class BaseAgentRunner(AppRunner): if isinstance(observation, dict): try: observation = json.dumps(observation, ensure_ascii=False) - except Exception as e: + except Exception: observation = json.dumps(observation) updated_agent_thought.observation = observation @@ -389,7 +389,7 @@ class BaseAgentRunner(AppRunner): if isinstance(tool_invoke_meta, dict): try: tool_invoke_meta = json.dumps(tool_invoke_meta, ensure_ascii=False) - except Exception as e: + except Exception: tool_invoke_meta = json.dumps(tool_invoke_meta) updated_agent_thought.tool_meta_str = tool_invoke_meta @@ -433,11 +433,11 @@ class BaseAgentRunner(AppRunner): tool_call_response: list[ToolPromptMessage] = [] try: tool_inputs = json.loads(agent_thought.tool_input) - except Exception as e: + except Exception: tool_inputs = {tool: {} for tool in tools} try: tool_responses = json.loads(agent_thought.observation) - except Exception as e: + except Exception: tool_responses = dict.fromkeys(tools, agent_thought.observation) for tool in tools: diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py index d24c5e8336..c115d6a7af 100644 --- a/api/core/agent/plugin_entities.py +++ b/api/core/agent/plugin_entities.py @@ -1,17 +1,164 @@ -from typing import Optional +import enum +from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator +from core.entities.parameter_entities import CommonParameterType from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderIdentity +from core.tools.entities.tool_entities import ( + ToolIdentity, + ToolParameterOption, + ToolProviderIdentity, +) class AgentStrategyProviderIdentity(ToolProviderIdentity): pass -class AgentStrategyParameter(ToolParameter): - pass +class AgentStrategyParameter(BaseModel): + class AgentStrategyParameterType(enum.StrEnum): + STRING = CommonParameterType.STRING.value + NUMBER = CommonParameterType.NUMBER.value + BOOLEAN = CommonParameterType.BOOLEAN.value + SELECT = CommonParameterType.SELECT.value + SECRET_INPUT = CommonParameterType.SECRET_INPUT.value + FILE = CommonParameterType.FILE.value + FILES = CommonParameterType.FILES.value + APP_SELECTOR = CommonParameterType.APP_SELECTOR.value + TOOL_SELECTOR = CommonParameterType.TOOL_SELECTOR.value + MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value + TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value + + # deprecated, should not use. + SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value + + def as_normal_type(self): + if self in { + AgentStrategyParameter.AgentStrategyParameterType.SECRET_INPUT, + AgentStrategyParameter.AgentStrategyParameterType.SELECT, + }: + return "string" + return self.value + + def cast_value(self, value: Any, /): + try: + match self: + case ( + AgentStrategyParameter.AgentStrategyParameterType.STRING + | AgentStrategyParameter.AgentStrategyParameterType.SECRET_INPUT + | AgentStrategyParameter.AgentStrategyParameterType.SELECT + ): + if value is None: + return "" + else: + return value if isinstance(value, str) else str(value) + + case AgentStrategyParameter.AgentStrategyParameterType.BOOLEAN: + if value is None: + return False + elif isinstance(value, str): + # Allowed YAML boolean value strings: https://yaml.org/type/bool.html + # and also '0' for False and '1' for True + match value.lower(): + case "true" | "yes" | "y" | "1": + return True + case "false" | "no" | "n" | "0": + return False + case _: + return bool(value) + else: + return value if isinstance(value, bool) else bool(value) + + case AgentStrategyParameter.AgentStrategyParameterType.NUMBER: + if isinstance(value, int | float): + return value + elif isinstance(value, str) and value: + if "." in value: + return float(value) + else: + return int(value) + case ( + AgentStrategyParameter.AgentStrategyParameterType.SYSTEM_FILES + | AgentStrategyParameter.AgentStrategyParameterType.FILES + ): + if not isinstance(value, list): + return [value] + return value + case AgentStrategyParameter.AgentStrategyParameterType.FILE: + if isinstance(value, list): + if len(value) != 1: + raise ValueError( + "This parameter only accepts one file but got multiple files while invoking." + ) + else: + return value[0] + return value + case ( + AgentStrategyParameter.AgentStrategyParameterType.TOOL_SELECTOR + | AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR + | AgentStrategyParameter.AgentStrategyParameterType.APP_SELECTOR + | AgentStrategyParameter.AgentStrategyParameterType.TOOLS_SELECTOR + ): + if not isinstance(value, dict): + raise ValueError("The selector must be a dictionary.") + return value + case _: + return str(value) + + except Exception: + raise ValueError(f"The tool parameter value {value} is not in correct type of {self.as_normal_type()}.") + + name: str = Field(..., description="The name of the parameter") + label: I18nObject = Field(..., description="The label presented to the user") + placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user") + type: AgentStrategyParameterType = Field(..., description="The type of the parameter") + scope: str | None = None + required: Optional[bool] = False + default: Optional[Union[float, int, str]] = None + min: Optional[Union[float, int]] = None + max: Optional[Union[float, int]] = None + options: list[ToolParameterOption] = Field(default_factory=list) + + @field_validator("options", mode="before") + @classmethod + def transform_options(cls, v): + if not isinstance(v, list): + return [] + return v + + @classmethod + def get_simple_instance( + cls, + name: str, + type: AgentStrategyParameterType, + required: bool, + options: Optional[list[str]] = None, + ): + """ + get a simple tool parameter + + :param name: the name of the parameter + :param llm_description: the description presented to the LLM + :param type: the type of the parameter + :param required: if the parameter is required + :param options: the options of the parameter + """ + # convert options to ToolParameterOption + if options: + option_objs = [ + ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options + ] + else: + option_objs = [] + return cls( + name=name, + label=I18nObject(en_US="", zh_Hans=""), + placeholder=None, + type=type, + required=required, + options=option_objs, + ) class AgentStrategyProviderEntity(BaseModel): diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py index 990e84867e..7aec38389c 100644 --- a/api/core/entities/parameter_entities.py +++ b/api/core/entities/parameter_entities.py @@ -14,6 +14,7 @@ class CommonParameterType(Enum): APP_SELECTOR = "app-selector" TOOL_SELECTOR = "tool-selector" MODEL_SELECTOR = "model-selector" + TOOLS_SELECTOR = "array[tools]" class AppSelectorScope(Enum): diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index c27c149bfd..0556384bd2 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -14,7 +14,12 @@ from core.file import FileType from core.file.models import FileTransferMethod from core.ops.ops_trace_manager import TraceQueueManager from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter +from core.tools.entities.tool_entities import ( + ToolInvokeMessage, + ToolInvokeMessageBinary, + ToolInvokeMeta, + ToolParameter, +) from core.tools.errors import ( ToolEngineInvokeError, ToolInvokeError, @@ -66,7 +71,7 @@ class ToolEngine: else: try: tool_parameters = json.loads(tool_parameters) - except Exception as e: + except Exception: pass if not isinstance(tool_parameters, dict): raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index c20abc45e4..032c9f2454 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -548,7 +548,7 @@ class ToolManager: cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label yield provider - except Exception as e: + except Exception: logger.exception(f"load builtin provider {provider}") continue # set builtin providers loaded @@ -670,7 +670,7 @@ class ToolManager: workflow_provider_controllers.append( ToolTransformService.workflow_provider_to_controller(db_provider=provider) ) - except Exception as e: + except Exception: # app has been deleted pass diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 2f88ff10ad..2998fb8ce2 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -202,7 +202,7 @@ class WorkflowTool(Tool): file_dict["url"] = file.generate_url() files.append(file_dict) - except Exception as e: + except Exception: logger.exception(f"Failed to transform file {file}") else: parameters_result[parameter.name] = tool_parameters.get(parameter.name) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 885a7f20e2..7bd1d711f3 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -197,7 +197,7 @@ class ToolTransformService: raise ValueError("user not found") username = user.name - except Exception as e: + except Exception: logger.exception(f"failed to get user name for api provider {db_provider.id}") # add provider into providers credentials = db_provider.credentials