mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 04:05:53 +08:00
feat: agent node add mcp tools
This commit is contained in:
parent
41bbcb9540
commit
c7cb3770a4
@ -53,7 +53,6 @@ class AppMCPServerController(Resource):
|
|||||||
)
|
)
|
||||||
db.session.add(server)
|
db.session.add(server)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return server
|
return server
|
||||||
|
|
||||||
@setup_required
|
@setup_required
|
||||||
@ -68,12 +67,17 @@ class AppMCPServerController(Resource):
|
|||||||
parser.add_argument("id", type=str, required=True, location="json")
|
parser.add_argument("id", type=str, required=True, location="json")
|
||||||
parser.add_argument("description", type=str, required=True, location="json")
|
parser.add_argument("description", type=str, required=True, location="json")
|
||||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||||
|
parser.add_argument("status", type=str, required=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
|
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
|
||||||
if not server:
|
if not server:
|
||||||
raise Forbidden()
|
raise Forbidden()
|
||||||
server.description = args["description"]
|
server.description = args["description"]
|
||||||
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
||||||
|
if args["status"]:
|
||||||
|
if args["status"] not in [status.value for status in AppMCPServerStatus]:
|
||||||
|
raise ValueError("Invalid status")
|
||||||
|
server.status = args["status"]
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return server
|
return server
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ class RequestInvokeTool(BaseModel):
|
|||||||
Request to invoke a tool
|
Request to invoke a tool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tool_type: Literal["builtin", "workflow", "api"]
|
tool_type: Literal["builtin", "workflow", "api", "mcp"]
|
||||||
provider: str
|
provider: str
|
||||||
tool: str
|
tool: str
|
||||||
tool_parameters: dict
|
tool_parameters: dict
|
||||||
|
@ -53,7 +53,7 @@ class MCPToolProviderController(ToolProviderController):
|
|||||||
author=db_provider.user.name if db_provider.user else "Anonymous",
|
author=db_provider.user.name if db_provider.user else "Anonymous",
|
||||||
name=remote_mcp_tool.name,
|
name=remote_mcp_tool.name,
|
||||||
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
|
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
|
||||||
provider=db_provider.name,
|
provider=db_provider.id,
|
||||||
icon=db_provider.icon,
|
icon=db_provider.icon,
|
||||||
),
|
),
|
||||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
|
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
|
||||||
|
@ -746,7 +746,7 @@ class ToolManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if provider is None:
|
if provider is None:
|
||||||
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
|
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||||
|
|
||||||
controller = MCPToolProviderController._from_db(provider)
|
controller = MCPToolProviderController._from_db(provider)
|
||||||
|
|
||||||
|
@ -1,8 +1,21 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
from flask_restful import fields
|
from flask_restful import fields
|
||||||
|
|
||||||
from fields.workflow_fields import workflow_partial_fields
|
from fields.workflow_fields import workflow_partial_fields
|
||||||
from libs.helper import AppIconUrlField, TimestampField
|
from libs.helper import AppIconUrlField, TimestampField
|
||||||
|
|
||||||
|
|
||||||
|
class JsonStringField(fields.Raw):
|
||||||
|
def format(self, value):
|
||||||
|
if isinstance(value, str):
|
||||||
|
try:
|
||||||
|
return json.loads(value)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
return value
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
app_detail_kernel_fields = {
|
app_detail_kernel_fields = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
"name": fields.String,
|
"name": fields.String,
|
||||||
@ -220,7 +233,7 @@ app_server_fields = {
|
|||||||
"server_code": fields.String,
|
"server_code": fields.String,
|
||||||
"description": fields.String,
|
"description": fields.String,
|
||||||
"status": fields.String,
|
"status": fields.String,
|
||||||
"parameters": fields.Raw,
|
"parameters": JsonStringField,
|
||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
"updated_at": TimestampField,
|
"updated_at": TimestampField,
|
||||||
}
|
}
|
||||||
|
@ -203,6 +203,7 @@ class MCPToolProvider(Base):
|
|||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
|
db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
|
||||||
db.UniqueConstraint("name", "tenant_id", name="unique_mcp_tool_provider"),
|
db.UniqueConstraint("name", "tenant_id", name="unique_mcp_tool_provider"),
|
||||||
|
db.UniqueConstraint("server_url", name="unique_mcp_tool_provider_server_url"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
|
from sqlalchemy import or_
|
||||||
|
|
||||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||||
from core.mcp.mcp_client import MCPClient
|
from core.mcp.mcp_client import MCPClient
|
||||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||||
@ -30,12 +32,21 @@ class MCPToolManageService:
|
|||||||
def create_mcp_provider(
|
def create_mcp_provider(
|
||||||
tenant_id: str, name: str, server_url: str, user_id: str, icon: str, icon_type: str, icon_background: str
|
tenant_id: str, name: str, server_url: str, user_id: str, icon: str, icon_type: str, icon_background: str
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if (
|
existing_provider = (
|
||||||
db.session.query(MCPToolProvider)
|
db.session.query(MCPToolProvider)
|
||||||
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.name == name)
|
.filter(
|
||||||
|
MCPToolProvider.tenant_id == tenant_id,
|
||||||
|
or_(MCPToolProvider.name == name, MCPToolProvider.server_url == server_url),
|
||||||
|
MCPToolProvider.tenant_id == tenant_id,
|
||||||
|
)
|
||||||
.first()
|
.first()
|
||||||
):
|
)
|
||||||
raise ValueError(f"MCP tool {name} already exists")
|
if existing_provider:
|
||||||
|
if existing_provider.name == name:
|
||||||
|
raise ValueError(f"MCP tool {name} already exists")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"MCP tool {server_url} already exists")
|
||||||
|
|
||||||
mcp_tool = MCPToolProvider(
|
mcp_tool = MCPToolProvider(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
name=name,
|
name=name,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user