mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 16:05:55 +08:00
feat: mcp tool add input schema
This commit is contained in:
parent
2e4dfbd60f
commit
1c84a27e7e
@ -655,7 +655,6 @@ class ToolProviderMCPApi(Resource):
|
|||||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return jsonable_encoder(
|
|
||||||
MCPToolManageService.update_mcp_provider(
|
MCPToolManageService.update_mcp_provider(
|
||||||
tenant_id=current_user.current_tenant_id,
|
tenant_id=current_user.current_tenant_id,
|
||||||
name=args["name"],
|
name=args["name"],
|
||||||
@ -664,9 +663,8 @@ class ToolProviderMCPApi(Resource):
|
|||||||
icon_type=args["icon_type"],
|
icon_type=args["icon_type"],
|
||||||
icon_background=args["icon_background"],
|
icon_background=args["icon_background"],
|
||||||
provider_id=args["provider_id"],
|
provider_id=args["provider_id"],
|
||||||
encrypted_credentials={},
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@ -675,11 +673,8 @@ class ToolProviderMCPApi(Resource):
|
|||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return jsonable_encoder(
|
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
|
||||||
MCPToolManageService.delete_mcp_tool(
|
return {"result": "success"}
|
||||||
tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ToolMCPAuthApi(Resource):
|
class ToolMCPAuthApi(Resource):
|
||||||
@ -739,7 +734,9 @@ class ToolMCPListAllApi(Resource):
|
|||||||
user = current_user
|
user = current_user
|
||||||
tenant_id = user.current_tenant_id
|
tenant_id = user.current_tenant_id
|
||||||
|
|
||||||
return jsonable_encoder(MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id))
|
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
|
||||||
|
|
||||||
|
return [tool.to_dict() for tool in tools]
|
||||||
|
|
||||||
|
|
||||||
class ToolMCPUpdateApi(Resource):
|
class ToolMCPUpdateApi(Resource):
|
||||||
@ -762,12 +759,16 @@ class ToolMCPTokenApi(Resource):
|
|||||||
def get(self):
|
def get(self):
|
||||||
parser = reqparse.RequestParser()
|
parser = reqparse.RequestParser()
|
||||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="args")
|
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="args")
|
||||||
parser.add_argument("server_url", type=str, required=True, nullable=False, location="args")
|
|
||||||
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="args")
|
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="args")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
provider = MCPToolManageService.get_mcp_provider_by_provider_id(
|
||||||
|
args["provider_id"], current_user.current_tenant_id
|
||||||
|
)
|
||||||
|
if not provider:
|
||||||
|
raise ValueError("provider not found")
|
||||||
return auth(
|
return auth(
|
||||||
OAuthClientProvider(args["provider_id"], current_user.current_tenant_id),
|
OAuthClientProvider(args["provider_id"], current_user.current_tenant_id),
|
||||||
server_url=args["server_url"],
|
provider.server_url,
|
||||||
authorization_code=args["authorization_code"],
|
authorization_code=args["authorization_code"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -161,10 +161,14 @@ class BaseAgentRunner(AppRunner):
|
|||||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||||
enum = [option.value for option in parameter.options] if parameter.options else []
|
enum = [option.value for option in parameter.options] if parameter.options else []
|
||||||
|
|
||||||
message_tool.parameters["properties"][parameter.name] = {
|
message_tool.parameters["properties"][parameter.name] = (
|
||||||
|
{
|
||||||
"type": parameter_type,
|
"type": parameter_type,
|
||||||
"description": parameter.llm_description or "",
|
"description": parameter.llm_description or "",
|
||||||
}
|
}
|
||||||
|
if parameter.input_schema is None
|
||||||
|
else parameter.input_schema
|
||||||
|
)
|
||||||
|
|
||||||
if len(enum) > 0:
|
if len(enum) > 0:
|
||||||
message_tool.parameters["properties"][parameter.name]["enum"] = enum
|
message_tool.parameters["properties"][parameter.name]["enum"] = enum
|
||||||
|
@ -40,6 +40,15 @@ class PluginParameterType(enum.StrEnum):
|
|||||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServerParameterType(enum.StrEnum):
|
||||||
|
"""
|
||||||
|
MCP server got complex parameter types
|
||||||
|
"""
|
||||||
|
|
||||||
|
ARRAY = "array"
|
||||||
|
OBJECT = "object"
|
||||||
|
|
||||||
|
|
||||||
class PluginParameterAutoGenerate(BaseModel):
|
class PluginParameterAutoGenerate(BaseModel):
|
||||||
class Type(enum.StrEnum):
|
class Type(enum.StrEnum):
|
||||||
PROMPT_INSTRUCTION = "prompt_instruction"
|
PROMPT_INSTRUCTION = "prompt_instruction"
|
||||||
|
@ -148,6 +148,8 @@ class Tool(ABC):
|
|||||||
tool_parameter.default = parameter.default
|
tool_parameter.default = parameter.default
|
||||||
tool_parameter.options = parameter.options
|
tool_parameter.options = parameter.options
|
||||||
tool_parameter.llm_description = parameter.llm_description
|
tool_parameter.llm_description = parameter.llm_description
|
||||||
|
if parameter.input_schema:
|
||||||
|
tool_parameter.input_schema = parameter.input_schema
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# add new parameter
|
# add new parameter
|
||||||
|
@ -40,7 +40,7 @@ class ToolProviderApiEntity(BaseModel):
|
|||||||
labels: list[str] = Field(default_factory=list)
|
labels: list[str] = Field(default_factory=list)
|
||||||
# MCP
|
# MCP
|
||||||
server_url: Optional[str] = Field(default="", description="The server url of the tool")
|
server_url: Optional[str] = Field(default="", description="The server url of the tool")
|
||||||
updated_at: datetime = Field(default_factory=datetime.now)
|
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||||
|
|
||||||
@field_validator("tools", mode="before")
|
@field_validator("tools", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -56,8 +56,12 @@ class ToolProviderApiEntity(BaseModel):
|
|||||||
for parameter in tool.get("parameters"):
|
for parameter in tool.get("parameters"):
|
||||||
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
|
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
|
||||||
parameter["type"] = "files"
|
parameter["type"] = "files"
|
||||||
|
if parameter.get("input_schema") is None:
|
||||||
|
parameter.pop("input_schema", None)
|
||||||
# -------------
|
# -------------
|
||||||
optional_fields = self.optional_field("server_url", self.server_url)
|
optional_fields = self.optional_field("server_url", self.server_url)
|
||||||
|
if self.type == ToolProviderType.MCP.value:
|
||||||
|
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||||
return {
|
return {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"author": self.author,
|
"author": self.author,
|
||||||
|
@ -8,6 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_seriali
|
|||||||
|
|
||||||
from core.entities.provider_entities import ProviderConfig
|
from core.entities.provider_entities import ProviderConfig
|
||||||
from core.plugin.entities.parameters import (
|
from core.plugin.entities.parameters import (
|
||||||
|
MCPServerParameterType,
|
||||||
PluginParameter,
|
PluginParameter,
|
||||||
PluginParameterOption,
|
PluginParameterOption,
|
||||||
PluginParameterType,
|
PluginParameterType,
|
||||||
@ -242,6 +243,10 @@ class ToolParameter(PluginParameter):
|
|||||||
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
|
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
|
||||||
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
|
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
|
||||||
|
|
||||||
|
# MCP object and array type parameters
|
||||||
|
ARRAY = MCPServerParameterType.ARRAY.value
|
||||||
|
OBJECT = MCPServerParameterType.OBJECT.value
|
||||||
|
|
||||||
# deprecated, should not use.
|
# deprecated, should not use.
|
||||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||||
|
|
||||||
@ -260,6 +265,8 @@ class ToolParameter(PluginParameter):
|
|||||||
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
|
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
|
||||||
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
||||||
llm_description: Optional[str] = None
|
llm_description: Optional[str] = None
|
||||||
|
# MCP object and array type parameters use this field to store the schema
|
||||||
|
input_schema: Optional[dict] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_simple_instance(
|
def get_simple_instance(
|
||||||
|
@ -88,7 +88,7 @@ class MCPToolManageService:
|
|||||||
icon=mcp_provider.icon,
|
icon=mcp_provider.icon,
|
||||||
author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
|
author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
|
||||||
server_url=mcp_provider.server_url,
|
server_url=mcp_provider.server_url,
|
||||||
updated_at=mcp_provider.updated_at,
|
updated_at=int(mcp_provider.updated_at.timestamp()),
|
||||||
description=I18nObject(en_US="", zh_Hans=""),
|
description=I18nObject(en_US="", zh_Hans=""),
|
||||||
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
|
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
|
||||||
)
|
)
|
||||||
@ -119,7 +119,6 @@ class MCPToolManageService:
|
|||||||
icon: str,
|
icon: str,
|
||||||
icon_type: str,
|
icon_type: str,
|
||||||
icon_background: str,
|
icon_background: str,
|
||||||
encrypted_credentials: dict,
|
|
||||||
):
|
):
|
||||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||||
if mcp_provider is None:
|
if mcp_provider is None:
|
||||||
@ -129,7 +128,6 @@ class MCPToolManageService:
|
|||||||
mcp_provider.icon = (
|
mcp_provider.icon = (
|
||||||
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
|
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
|
||||||
)
|
)
|
||||||
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **encrypted_credentials})
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return {"result": "success"}
|
return {"result": "success"}
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Union, cast
|
from typing import Any, Optional, Union, cast
|
||||||
|
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
@ -201,7 +201,7 @@ class ToolTransformService:
|
|||||||
tools=ToolTransformService.mcp_tool_to_user_tool(
|
tools=ToolTransformService.mcp_tool_to_user_tool(
|
||||||
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
||||||
),
|
),
|
||||||
updated_at=db_provider.updated_at,
|
updated_at=int(db_provider.updated_at.timestamp()),
|
||||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||||
description=I18nObject(en_US="", zh_Hans=""),
|
description=I18nObject(en_US="", zh_Hans=""),
|
||||||
)
|
)
|
||||||
@ -347,8 +347,11 @@ class ToolTransformService:
|
|||||||
:return: list of ToolParameter instances
|
:return: list of ToolParameter instances
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def create_parameter(name: str, description: str, param_type: str, required: bool) -> ToolParameter:
|
def create_parameter(
|
||||||
|
name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None
|
||||||
|
) -> ToolParameter:
|
||||||
"""Create a ToolParameter instance with given attributes"""
|
"""Create a ToolParameter instance with given attributes"""
|
||||||
|
input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
|
||||||
return ToolParameter(
|
return ToolParameter(
|
||||||
name=name,
|
name=name,
|
||||||
llm_description=description,
|
llm_description=description,
|
||||||
@ -357,36 +360,27 @@ class ToolTransformService:
|
|||||||
required=required,
|
required=required,
|
||||||
type=ToolParameter.ToolParameterType(param_type),
|
type=ToolParameter.ToolParameterType(param_type),
|
||||||
human_description=I18nObject(en_US=description),
|
human_description=I18nObject(en_US=description),
|
||||||
|
**input_schema_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_array(name: str, description: str, items: dict, required: bool) -> list[ToolParameter]:
|
|
||||||
"""Process array type properties"""
|
|
||||||
item_type = items.get("type", "string")
|
|
||||||
if item_type == "object" and "properties" in items:
|
|
||||||
return process_properties(items["properties"], items.get("required", []), f"{name}[0]")
|
|
||||||
|
|
||||||
return [create_parameter(name, description, item_type, required)]
|
|
||||||
|
|
||||||
def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
|
def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
|
||||||
"""Process properties recursively"""
|
"""Process properties recursively"""
|
||||||
|
TYPE_MAPPING = {"integer": "number"}
|
||||||
|
COMPLEX_TYPES = ["array", "object"]
|
||||||
|
|
||||||
parameters = []
|
parameters = []
|
||||||
for name, prop in props.items():
|
for name, prop in props.items():
|
||||||
current_name = f"{prefix}.{name}" if prefix else name
|
|
||||||
current_description = prop.get("description", "")
|
current_description = prop.get("description", "")
|
||||||
prop_type = prop.get("type", "string")
|
prop_type = prop.get("type", "string")
|
||||||
|
|
||||||
if isinstance(prop_type, list):
|
if isinstance(prop_type, list):
|
||||||
prop_type = prop_type[0]
|
prop_type = prop_type[0]
|
||||||
if prop_type == "integer":
|
if prop_type in TYPE_MAPPING:
|
||||||
prop_type = "number"
|
prop_type = TYPE_MAPPING[prop_type]
|
||||||
if prop_type == "array":
|
input_schema = prop if prop_type in COMPLEX_TYPES else None
|
||||||
parameters.extend(
|
parameters.append(
|
||||||
process_array(current_name, current_description, prop.get("items", {}), name in required)
|
create_parameter(name, current_description, prop_type, name in required, input_schema)
|
||||||
)
|
)
|
||||||
elif prop_type == "object" and "properties" in prop:
|
|
||||||
parameters.extend(process_properties(prop["properties"], prop.get("required", []), current_name))
|
|
||||||
else:
|
|
||||||
parameters.append(create_parameter(current_name, current_description, prop_type, name in required))
|
|
||||||
|
|
||||||
return parameters
|
return parameters
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user