diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 3628992816..ad035e80e3 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -655,18 +655,16 @@ class ToolProviderMCPApi(Resource): parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - return jsonable_encoder( - MCPToolManageService.update_mcp_provider( - tenant_id=current_user.current_tenant_id, - name=args["name"], - server_url=args["server_url"], - icon=args["icon"], - icon_type=args["icon_type"], - icon_background=args["icon_background"], - provider_id=args["provider_id"], - encrypted_credentials={}, - ) + MCPToolManageService.update_mcp_provider( + tenant_id=current_user.current_tenant_id, + name=args["name"], + server_url=args["server_url"], + icon=args["icon"], + icon_type=args["icon_type"], + icon_background=args["icon_background"], + provider_id=args["provider_id"], ) + return {"result": "success"} @setup_required @login_required @@ -675,11 +673,8 @@ class ToolProviderMCPApi(Resource): parser = reqparse.RequestParser() parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - return jsonable_encoder( - MCPToolManageService.delete_mcp_tool( - tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"] - ) - ) + MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"]) + return {"result": "success"} class ToolMCPAuthApi(Resource): @@ -739,7 +734,9 @@ class ToolMCPListAllApi(Resource): user = current_user tenant_id = user.current_tenant_id - return jsonable_encoder(MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)) + tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id) + + return [tool.to_dict() for tool in tools] class ToolMCPUpdateApi(Resource): @@ -762,12 +759,16 @@ class ToolMCPTokenApi(Resource): def get(self): parser = reqparse.RequestParser() parser.add_argument("provider_id", type=str, required=True, nullable=False, location="args") - parser.add_argument("server_url", type=str, required=True, nullable=False, location="args") parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="args") args = parser.parse_args() + provider = MCPToolManageService.get_mcp_provider_by_provider_id( + args["provider_id"], current_user.current_tenant_id + ) + if not provider: + raise ValueError("provider not found") return auth( OAuthClientProvider(args["provider_id"], current_user.current_tenant_id), - server_url=args["server_url"], + provider.server_url, authorization_code=args["authorization_code"], ) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 6998e4d29a..1bb41906f9 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -161,10 +161,14 @@ class BaseAgentRunner(AppRunner): if parameter.type == ToolParameter.ToolParameterType.SELECT: enum = [option.value for option in parameter.options] if parameter.options else [] - message_tool.parameters["properties"][parameter.name] = { - "type": parameter_type, - "description": parameter.llm_description or "", - } + message_tool.parameters["properties"][parameter.name] = ( + { + "type": parameter_type, + "description": parameter.llm_description or "", + } + if parameter.input_schema is None + else parameter.input_schema + ) if len(enum) > 0: message_tool.parameters["properties"][parameter.name]["enum"] = enum diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index 895dd0d0fc..a323104295 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -40,6 +40,15 @@ class PluginParameterType(enum.StrEnum): SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value +class MCPServerParameterType(enum.StrEnum): + """ + MCP server got complex parameter types + """ + + ARRAY = "array" + OBJECT = "object" + + class PluginParameterAutoGenerate(BaseModel): class Type(enum.StrEnum): PROMPT_INSTRUCTION = "prompt_instruction" diff --git a/api/core/tools/__base/tool.py b/api/core/tools/__base/tool.py index 35e16b5c8f..8aef1078fe 100644 --- a/api/core/tools/__base/tool.py +++ b/api/core/tools/__base/tool.py @@ -148,6 +148,8 @@ class Tool(ABC): tool_parameter.default = parameter.default tool_parameter.options = parameter.options tool_parameter.llm_description = parameter.llm_description + if parameter.input_schema: + tool_parameter.input_schema = parameter.input_schema break else: # add new parameter diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index efc913c3b9..025a91fae5 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -40,7 +40,7 @@ class ToolProviderApiEntity(BaseModel): labels: list[str] = Field(default_factory=list) # MCP server_url: Optional[str] = Field(default="", description="The server url of the tool") - updated_at: datetime = Field(default_factory=datetime.now) + updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp())) @field_validator("tools", mode="before") @classmethod @@ -56,8 +56,12 @@ class ToolProviderApiEntity(BaseModel): for parameter in tool.get("parameters"): if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value: parameter["type"] = "files" + if parameter.get("input_schema") is None: + parameter.pop("input_schema", None) # ------------- optional_fields = self.optional_field("server_url", self.server_url) + if self.type == ToolProviderType.MCP.value: + optional_fields.update(self.optional_field("updated_at", self.updated_at)) return { "id": self.id, "author": self.author, diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index c683dbd087..f97ccbe808 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_seriali from core.entities.provider_entities import ProviderConfig from core.plugin.entities.parameters import ( + MCPServerParameterType, PluginParameter, PluginParameterOption, PluginParameterType, @@ -242,6 +243,10 @@ class ToolParameter(PluginParameter): APP_SELECTOR = PluginParameterType.APP_SELECTOR.value MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value + # MCP object and array type parameters + ARRAY = MCPServerParameterType.ARRAY.value + OBJECT = MCPServerParameterType.OBJECT.value + # deprecated, should not use. SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value @@ -260,6 +265,8 @@ class ToolParameter(PluginParameter): human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user") form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm") llm_description: Optional[str] = None + # MCP object and array type parameters use this field to store the schema + input_schema: Optional[dict] = None @classmethod def get_simple_instance( diff --git a/api/services/tools/mcp_tools_mange_service.py b/api/services/tools/mcp_tools_mange_service.py index ceddd0394c..1434236f03 100644 --- a/api/services/tools/mcp_tools_mange_service.py +++ b/api/services/tools/mcp_tools_mange_service.py @@ -88,7 +88,7 @@ class MCPToolManageService: icon=mcp_provider.icon, author=mcp_provider.user.name if mcp_provider.user else "Anonymous", server_url=mcp_provider.server_url, - updated_at=mcp_provider.updated_at, + updated_at=int(mcp_provider.updated_at.timestamp()), description=I18nObject(en_US="", zh_Hans=""), label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name), ) @@ -119,7 +119,6 @@ class MCPToolManageService: icon: str, icon_type: str, icon_background: str, - encrypted_credentials: dict, ): mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) if mcp_provider is None: @@ -129,7 +128,6 @@ class MCPToolManageService: mcp_provider.icon = ( json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon ) - mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **encrypted_credentials}) db.session.commit() return {"result": "success"} diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 54f82110b5..54f65d02a7 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast from yarl import URL @@ -201,7 +201,7 @@ class ToolTransformService: tools=ToolTransformService.mcp_tool_to_user_tool( db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)] ), - updated_at=db_provider.updated_at, + updated_at=int(db_provider.updated_at.timestamp()), label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name), description=I18nObject(en_US="", zh_Hans=""), ) @@ -347,8 +347,11 @@ class ToolTransformService: :return: list of ToolParameter instances """ - def create_parameter(name: str, description: str, param_type: str, required: bool) -> ToolParameter: + def create_parameter( + name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None + ) -> ToolParameter: """Create a ToolParameter instance with given attributes""" + input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {} return ToolParameter( name=name, llm_description=description, @@ -357,36 +360,27 @@ class ToolTransformService: required=required, type=ToolParameter.ToolParameterType(param_type), human_description=I18nObject(en_US=description), + **input_schema_dict, ) - def process_array(name: str, description: str, items: dict, required: bool) -> list[ToolParameter]: - """Process array type properties""" - item_type = items.get("type", "string") - if item_type == "object" and "properties" in items: - return process_properties(items["properties"], items.get("required", []), f"{name}[0]") - - return [create_parameter(name, description, item_type, required)] - def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]: """Process properties recursively""" + TYPE_MAPPING = {"integer": "number"} + COMPLEX_TYPES = ["array", "object"] + parameters = [] for name, prop in props.items(): - current_name = f"{prefix}.{name}" if prefix else name current_description = prop.get("description", "") prop_type = prop.get("type", "string") if isinstance(prop_type, list): prop_type = prop_type[0] - if prop_type == "integer": - prop_type = "number" - if prop_type == "array": - parameters.extend( - process_array(current_name, current_description, prop.get("items", {}), name in required) - ) - elif prop_type == "object" and "properties" in prop: - parameters.extend(process_properties(prop["properties"], prop.get("required", []), current_name)) - else: - parameters.append(create_parameter(current_name, current_description, prop_type, name in required)) + if prop_type in TYPE_MAPPING: + prop_type = TYPE_MAPPING[prop_type] + input_schema = prop if prop_type in COMPLEX_TYPES else None + parameters.append( + create_parameter(name, current_description, prop_type, name in required, input_schema) + ) return parameters