From beef679798c76a60c3bb8f31919681e103785219 Mon Sep 17 00:00:00 2001 From: Novice Date: Fri, 30 May 2025 18:26:28 +0800 Subject: [PATCH] feat: change the mcp server url --- api/controllers/mcp/__init__.py | 8 ++++ api/controllers/mcp/mcp.py | 66 +++++++++++++++++++++++++++++++ api/controllers/web/completion.py | 59 +-------------------------- api/core/mcp/server/handler.py | 16 ++++---- api/extensions/ext_blueprints.py | 2 + api/extensions/ext_login.py | 17 +++++++- 6 files changed, 101 insertions(+), 67 deletions(-) create mode 100644 api/controllers/mcp/__init__.py create mode 100644 api/controllers/mcp/mcp.py diff --git a/api/controllers/mcp/__init__.py b/api/controllers/mcp/__init__.py new file mode 100644 index 0000000000..1b3e0a5621 --- /dev/null +++ b/api/controllers/mcp/__init__.py @@ -0,0 +1,8 @@ +from flask import Blueprint + +from libs.external_api import ExternalApi + +bp = Blueprint("mcp", __name__, url_prefix="/mcp") +api = ExternalApi(bp) + +from . import mcp diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py new file mode 100644 index 0000000000..07aeec6fb8 --- /dev/null +++ b/api/controllers/mcp/mcp.py @@ -0,0 +1,66 @@ +from amqp import NotFound +from flask_restful import Resource, reqparse +from pydantic import ValidationError + +from controllers.mcp import api +from controllers.web.error import ( + AppUnavailableError, +) +from core.app.app_config.entities import VariableEntity +from core.mcp.server.handler import MCPServerReuqestHandler +from core.mcp.types import ClientRequest +from extensions.ext_database import db +from libs import helper +from models.model import App, AppMCPServer, AppMode + + +class MCPAppApi(Resource): + def post(self, server_code): + def int_or_str(value): + if isinstance(value, int): + return value + elif isinstance(value, str): + return int(value) + else: + raise ValueError("Invalid id") + + parser = reqparse.RequestParser() + parser.add_argument("jsonrpc", type=str, required=True, location="json") + parser.add_argument("method", type=str, required=True, location="json") + parser.add_argument("params", type=dict, required=True, location="json") + parser.add_argument("id", type=int_or_str, required=True, location="json") + args = parser.parse_args() + server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first() + if not server: + raise NotFound("Server Not Found") + app = db.session.query(App).filter(App.id == server.app_id).first() + if not app: + raise NotFound("App Not Found") + if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + workflow = app.workflow + if workflow is None: + raise AppUnavailableError() + + features_dict = workflow.features_dict + user_input_form = workflow.user_input_form(to_old_structure=True) + else: + app_model_config = app.app_model_config + if app_model_config is None: + raise AppUnavailableError() + + features_dict = app_model_config.to_dict() + + user_input_form = features_dict.get("user_input_form", []) + try: + user_input_form = [VariableEntity.model_validate(list(item.values())[0]) for item in user_input_form] + except ValidationError as e: + raise ValueError(f"Invalid user_input_form: {str(e)}") + try: + request = ClientRequest.model_validate(args) + except ValidationError as e: + raise ValueError(f"Invalid MCP request: {str(e)}") + mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form) + return helper.compact_generate_response(mcp_server_handler.handle()) + + +api.add_resource(MCPAppApi, "/server//mcp") diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index b893ec04d9..fd3b9aa804 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,7 +1,6 @@ import logging -from flask_restful import Resource, reqparse -from pydantic import ValidationError +from flask_restful import reqparse from werkzeug.exceptions import InternalServerError, NotFound import services @@ -18,7 +17,6 @@ from controllers.web.error import ( ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.wraps import WebApiResource -from core.app.app_config.entities import VariableEntity from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ( @@ -26,13 +24,10 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from core.mcp.server.handler import MCPServerReuqestHandler -from core.mcp.types import ClientRequest from core.model_runtime.errors.invoke import InvokeError -from extensions.ext_database import db from libs import helper from libs.helper import uuid_value -from models.model import App, AppMCPServer, AppMode +from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -154,57 +149,7 @@ class ChatStopApi(WebApiResource): return {"result": "success"}, 200 -class ChatMCPApi(Resource): - def post(self, server_code): - def int_or_str(value): - if isinstance(value, int): - return value - elif isinstance(value, str): - return int(value) - else: - raise ValueError("Invalid id") - - parser = reqparse.RequestParser() - parser.add_argument("jsonrpc", type=str, required=True, location="json") - parser.add_argument("method", type=str, required=True, location="json") - parser.add_argument("params", type=dict, required=True, location="json") - parser.add_argument("id", type=int_or_str, required=True, location="json") - args = parser.parse_args() - server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first() - if not server: - raise NotFound("Server Not Found") - app = db.session.query(App).filter(App.id == server.app_id).first() - if not app: - raise NotFound("App Not Found") - if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: - workflow = app.workflow - if workflow is None: - raise AppUnavailableError() - - features_dict = workflow.features_dict - user_input_form = workflow.user_input_form(to_old_structure=True) - else: - app_model_config = app.app_model_config - if app_model_config is None: - raise AppUnavailableError() - - features_dict = app_model_config.to_dict() - - user_input_form = features_dict.get("user_input_form", []) - try: - user_input_form = [VariableEntity.model_validate(item) for item in user_input_form] - except ValidationError as e: - raise ValueError(f"Invalid user_input_form: {str(e)}") - try: - request = ClientRequest.model_validate(args) - except ValidationError as e: - raise ValueError(f"Invalid MCP request: {str(e)}") - mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form) - return helper.compact_generate_response(mcp_server_handler.handle()) - - api.add_resource(CompletionApi, "/completion-messages") api.add_resource(CompletionStopApi, "/completion-messages//stop") api.add_resource(ChatApi, "/chat-messages") -api.add_resource(ChatMCPApi, "/server//mcp") api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/core/mcp/server/handler.py b/api/core/mcp/server/handler.py index 668ac77948..e218a6ba6e 100644 --- a/api/core/mcp/server/handler.py +++ b/api/core/mcp/server/handler.py @@ -22,9 +22,9 @@ class MCPServerReuqestHandler: def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity]): self.app = app self.request = request - self.mcp_server: AppMCPServer = self.app.mcp_server - if not self.mcp_server: + if not self.app.mcp_server: raise ValueError("MCP server not found") + self.mcp_server: AppMCPServer = self.app.mcp_server self.end_user = self.retrieve_end_user() self.user_input_form = user_input_form @@ -47,13 +47,9 @@ class MCPServerReuqestHandler: "required": required, }, }, - "required": "query", + "required": ["query", "inputs"], } - @property - def output_parameters(self): - return self.app.output_schema - @property def capabilities(self): return types.ServerCapabilities( @@ -160,6 +156,7 @@ class MCPServerReuqestHandler: parameters = {} required = [] for item in user_input_form: + parameters[item.variable] = {} if item.type in ( VariableEntityType.FILE, VariableEntityType.FILE_LIST, @@ -168,12 +165,13 @@ class MCPServerReuqestHandler: continue if item.required: required.append(item.variable) - parameters[item.variable]["description"] = self.mcp_server.parameters_dict[item.label]["description"] + description = self.mcp_server.parameters_dict[item.label] + parameters[item.variable]["description"] = description if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): parameters[item.variable]["type"] = "string" elif item.type == VariableEntityType.SELECT: parameters[item.variable]["type"] = "string" parameters[item.variable]["enum"] = item.options elif item.type == VariableEntityType.NUMBER: - parameters[item.variable]["type"] = "number" + parameters[item.variable]["type"] = "float" return parameters, required diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py index 316be12f5c..a4d013ffc0 100644 --- a/api/extensions/ext_blueprints.py +++ b/api/extensions/ext_blueprints.py @@ -10,6 +10,7 @@ def init_app(app: DifyApp): from controllers.console import bp as console_app_bp from controllers.files import bp as files_bp from controllers.inner_api import bp as inner_api_bp + from controllers.mcp import bp as mcp_bp from controllers.service_api import bp as service_api_bp from controllers.web import bp as web_bp @@ -46,3 +47,4 @@ def init_app(app: DifyApp): app.register_blueprint(files_bp) app.register_blueprint(inner_api_bp) + app.register_blueprint(mcp_bp) diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 06f42494ec..9629f0773f 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -10,7 +10,7 @@ from dify_app import DifyApp from extensions.ext_database import db from libs.passport import PassportService from models.account import Account, Tenant, TenantAccountJoin -from models.model import EndUser +from models.model import AppMCPServer, EndUser from services.account_service import AccountService login_manager = flask_login.LoginManager() @@ -71,6 +71,21 @@ def load_user_from_request(request_from_flask_login): if not end_user: raise NotFound("End user not found.") return end_user + elif request.blueprint == "mcp": + server_code = request.view_args.get("server_code") if request.view_args else None + if not server_code: + raise Unauthorized("Invalid Authorization token.") + app_mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first() + if not app_mcp_server: + raise NotFound("App MCP server not found.") + end_user = ( + db.session.query(EndUser) + .filter(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp") + .first() + ) + if not end_user: + raise NotFound("End user not found.") + return end_user @user_logged_in.connect