From 41bbcb9540eb84564a8494cd273c52eeaec611c9 Mon Sep 17 00:00:00 2001 From: Novice Date: Tue, 27 May 2025 13:14:51 +0800 Subject: [PATCH] feat: upgrade streamable http client --- api/controllers/console/app/mcp_server.py | 19 +++- api/controllers/web/completion.py | 22 ++++- api/core/helper/ssrf_proxy.py | 1 + api/core/mcp/auth/auth_flow.py | 7 +- api/core/mcp/auth/auth_provider.py | 11 +-- api/core/mcp/client/sse_client.py | 92 +++++++++---------- api/core/mcp/client/streamable_client.py | 26 ++---- api/core/mcp/mcp_client.py | 16 ++-- api/core/mcp/server/handler.py | 41 +++++++-- api/core/mcp/session/base_session.py | 48 ++-------- api/core/mcp/session/client_session.py | 3 +- api/core/mcp/types.py | 1 - api/core/tools/entities/api_entities.py | 9 +- api/core/tools/mcp_tool/tool.py | 12 ++- api/models/model.py | 4 + api/services/tools/mcp_tools_mange_service.py | 10 +- 16 files changed, 167 insertions(+), 155 deletions(-) diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index a7a276edb4..aada292a86 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -68,16 +68,31 @@ class AppMCPServerController(Resource): parser.add_argument("id", type=str, required=True, location="json") parser.add_argument("description", type=str, required=True, location="json") parser.add_argument("parameters", type=dict, required=True, location="json") - parser.add_argument("status", type=str, required=True, location="json") args = parser.parse_args() server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first() if not server: raise Forbidden() server.description = args["description"] server.parameters = json.dumps(args["parameters"], ensure_ascii=False) - server.status = AppMCPServerStatus(args["status"]) + db.session.commit() + return server + + +class AppMCPServerRefreshController(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(app_server_fields) + def get(self, server_id): + if not current_user.is_editor: + raise Forbidden() + server = db.session.query(AppMCPServer).filter(AppMCPServer.id == server_id).first() + if not server: + raise Forbidden() + server.server_code = AppMCPServer.generate_server_code(16) db.session.commit() return server api.add_resource(AppMCPServerController, "/apps//server") +api.add_resource(AppMCPServerRefreshController, "/apps//server/refresh") diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index c8645d5ebe..b893ec04d9 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -18,6 +18,7 @@ from controllers.web.error import ( ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.wraps import WebApiResource +from core.app.app_config.entities import VariableEntity from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ( @@ -175,11 +176,30 @@ class ChatMCPApi(Resource): app = db.session.query(App).filter(App.id == server.app_id).first() if not app: raise NotFound("App Not Found") + if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + workflow = app.workflow + if workflow is None: + raise AppUnavailableError() + + features_dict = workflow.features_dict + user_input_form = workflow.user_input_form(to_old_structure=True) + else: + app_model_config = app.app_model_config + if app_model_config is None: + raise AppUnavailableError() + + features_dict = app_model_config.to_dict() + + user_input_form = features_dict.get("user_input_form", []) + try: + user_input_form = [VariableEntity.model_validate(item) for item in user_input_form] + except ValidationError as e: + raise ValueError(f"Invalid user_input_form: {str(e)}") try: request = ClientRequest.model_validate(args) except ValidationError as e: raise ValueError(f"Invalid MCP request: {str(e)}") - mcp_server_handler = MCPServerReuqestHandler(app, request) + mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form) return helper.compact_generate_response(mcp_server_handler.handle()) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index d8f60c5c19..51ddb343b3 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -127,6 +127,7 @@ def create_ssrf_proxy_mcp_http_client( "verify": HTTP_REQUEST_NODE_SSL_VERIFY, "headers": headers or {}, "timeout": timeout, + "follow_redirects": True, # Enable redirect following for MCP connections } if dify_config.SSRF_PROXY_ALL_URL: diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index e40c0f47b6..b3e74ad11e 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -59,7 +59,6 @@ def start_authorization( metadata: Optional[OAuthMetadata], client_information: OAuthClientInformation, redirect_url: str, - scope: Optional[str] = None, ) -> tuple[str, str]: """Begins the authorization flow.""" response_type = "code" @@ -85,11 +84,9 @@ def start_authorization( "code_challenge": code_challenge, "code_challenge_method": code_challenge_method, "redirect_uri": redirect_url, + "state": "/tools?provider_id=" + client_information.client_id, } - if scope: - params["scope"] = scope - authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}" return authorization_url, code_verifier @@ -187,7 +184,6 @@ def auth( provider: OAuthClientProvider, server_url: str, authorization_code: Optional[str] = None, - scope: Optional[str] = None, ) -> dict[str, str]: """Orchestrates the full auth flow with a server.""" metadata = discover_oauth_metadata(server_url) @@ -233,7 +229,6 @@ def auth( metadata, client_information, provider.redirect_url, - scope or provider.client_metadata.scope, ) provider.save_code_verifier(code_verifier) diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py index e84124772f..3d14724845 100644 --- a/api/core/mcp/auth/auth_provider.py +++ b/api/core/mcp/auth/auth_provider.py @@ -1,6 +1,6 @@ from typing import Optional -from configs.app_config import DifyConfig +from configs import dify_config from core.mcp.types import ( OAuthClientInformation, OAuthClientInformationFull, @@ -11,8 +11,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService LATEST_PROTOCOL_VERSION = "1.0" -dify_config = DifyConfig() - class OAuthClientProvider: provider_id: str @@ -25,7 +23,7 @@ class OAuthClientProvider: @property def redirect_url(self) -> str: """The URL to redirect the user agent to after authorization.""" - return dify_config.CONSOLE_WEB_URL + return dify_config.CONSOLE_WEB_URL + "/tools" @property def client_metadata(self) -> OAuthClientMetadata: @@ -37,7 +35,6 @@ class OAuthClientProvider: response_types=["code"], client_name="Dify", client_uri="https://github.com/langgenius/dify", - scope="read write", ) def client_information(self) -> Optional[OAuthClientInformation]: @@ -91,7 +88,3 @@ class OAuthClientProvider: if not mcp_provider: return "" return mcp_provider.credentials.get("code_verifier", "") - - -class UnauthorizedError(Exception): - pass diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index 0ad3eb4976..955a695c29 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -1,6 +1,5 @@ import logging import queue -import threading from collections.abc import Generator from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager @@ -42,7 +41,7 @@ def sse_client( read_queue = queue.Queue() write_queue = queue.Queue() status_queue = queue.Queue() - cancel_event = threading.Event() + with ThreadPoolExecutor() as executor: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") @@ -51,54 +50,49 @@ def sse_client( url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client ) as event_source: event_source.response.raise_for_status() - logger.debug("SSE connection established") def sse_reader(status_queue: queue.Queue): try: - while not cancel_event.is_set(): - for sse in event_source.iter_sse(): - if cancel_event.is_set(): - break - match sse.event: - case "endpoint": - endpoint_url = urljoin(url, sse.data) - logger.info(f"Received endpoint URL: {endpoint_url}") - url_parsed = urlparse(url) - endpoint_parsed = urlparse(endpoint_url) - - if ( - url_parsed.netloc != endpoint_parsed.netloc - or url_parsed.scheme != endpoint_parsed.scheme - ): - error_msg = ( - f"Endpoint origin does not match connection origin: {endpoint_url}" - ) - logger.error(error_msg) - raise ValueError(error_msg) - status_queue.put(("ready", endpoint_url)) - case "message": - try: - message = types.JSONRPCMessage.model_validate_json(sse.data) - logger.debug(f"Received server message: {message}") - except Exception as exc: - logger.exception("Error parsing server message") - read_queue.put(exc) - continue - session_message = SessionMessage(message) - read_queue.put(session_message) - case _: - logger.warning(f"Unknown SSE event: {sse.event}") + for sse in event_source.iter_sse(): + match sse.event: + case "endpoint": + endpoint_url = urljoin(url, sse.data) + logger.info(f"Received endpoint URL: {endpoint_url}") + url_parsed = urlparse(url) + endpoint_parsed = urlparse(endpoint_url) + if ( + url_parsed.netloc != endpoint_parsed.netloc + or url_parsed.scheme != endpoint_parsed.scheme + ): + error_msg = ( + f"Endpoint origin does not match connection origin: {endpoint_url}" + ) + logger.error(error_msg) + status_queue.put(("error", ValueError(error_msg))) + status_queue.put(("ready", endpoint_url)) + case "message": + try: + message = types.JSONRPCMessage.model_validate_json(sse.data) + logger.debug(f"Received server message: {message}") + except Exception as exc: + logger.exception("Error parsing server message") + read_queue.put(exc) + continue + session_message = SessionMessage(message) + read_queue.put(session_message) + case _: + logger.warning(f"Unknown SSE event: {sse.event}") + except httpx.ReadError as exc: + logger.debug(f"SSE reader shutting down normally: {exc}") except Exception as exc: - if not cancel_event.is_set(): - logger.exception("Error reading SSE messages") - read_queue.put(exc) + read_queue.put(exc) finally: read_queue.put(None) def post_writer(endpoint_url: str): try: - while not cancel_event.is_set(): + while True: try: message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) if message is None: @@ -113,14 +107,13 @@ def sse_client( ) response.raise_for_status() logger.debug(f"Client message sent successfully: {response.status_code}") - if cancel_event.is_set(): - break except queue.Empty: - if cancel_event.is_set(): - break continue - except Exception: + except httpx.ReadError as exc: + logger.debug(f"SSE reader shutting down normally: {exc}") + except Exception as exc: logger.exception("Error writing messages") + write_queue.put(exc) finally: write_queue.put(None) @@ -131,11 +124,12 @@ def sse_client( raise ValueError("failed to get endpoint URL") if status != "ready": raise ValueError("failed to get endpoint URL") + if status == "error": + raise endpoint_url executor.submit(post_writer, endpoint_url) - try: - yield read_queue, write_queue - finally: - cancel_event.set() + + yield read_queue, write_queue + except httpx.HTTPStatusError as exc: if exc.response.status_code == 401: raise MCPAuthError() diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py index 1f46ee948f..2e5b8cfcac 100644 --- a/api/core/mcp/client/streamable_client.py +++ b/api/core/mcp/client/streamable_client.py @@ -8,7 +8,6 @@ and session management. import logging import queue -import threading from collections.abc import Callable, Generator from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager @@ -106,11 +105,6 @@ class StreamableHTTPTransport: CONTENT_TYPE: JSON, **self.headers, } - self.stop_event = threading.Event() - - def stop(self): - """Signal to stop all operations.""" - self.stop_event.set() def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]: """Update headers with session ID if available.""" @@ -170,6 +164,9 @@ class StreamableHTTPTransport: # Put exception in queue that goes to client server_to_client_queue.put(exc) return False + elif sse.event == "ping": + logger.debug("Received ping event") + return False else: logger.warning(f"Unknown SSE event: {sse.event}") return False @@ -198,8 +195,6 @@ class StreamableHTTPTransport: logger.debug("GET SSE connection established") for sse in event_source.iter_sse(): - if self.stop_event.is_set(): - break self._handle_sse_event(sse, server_to_client_queue) except Exception as exc: @@ -230,8 +225,6 @@ class StreamableHTTPTransport: logger.debug("Resumption GET SSE connection established") for sse in event_source.iter_sse(): - if self.stop_event.is_set(): - break is_complete = self._handle_sse_event( sse, ctx.server_to_client_queue, @@ -300,13 +293,13 @@ class StreamableHTTPTransport: try: event_source = EventSource(response) for sse in event_source.iter_sse(): - if self.stop_event.is_set(): - break - self._handle_sse_event( + is_complete = self._handle_sse_event( sse, ctx.server_to_client_queue, resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), ) + if is_complete: + break except Exception as e: ctx.server_to_client_queue.put(e) @@ -346,7 +339,7 @@ class StreamableHTTPTransport: This method processes messages from the client_to_server_queue and sends them to the server. Responses are written to the server_to_client_queue. """ - while not self.stop_event.is_set(): + while True: try: # Read message from client queue with timeout to check stop_event periodically session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) @@ -382,10 +375,8 @@ class StreamableHTTPTransport: else: self._handle_post_request(ctx) except queue.Empty: - # Timeout - continue loop to check stop_event continue except Exception as exc: - # Send exception to client server_to_client_queue.put(exc) def terminate_session(self, client: httpx.Client) -> None: @@ -478,9 +469,6 @@ def streamablehttp_client( # Signal threads to stop client_to_server_queue.put(None) finally: - # Clean up - transport.stop() - # Clear any remaining items and add None sentinel to unblock any waiting threads try: while not client_to_server_queue.empty(): diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index 56aa16f873..1c460b0721 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -21,7 +21,6 @@ class MCPClient: tenant_id: str, authed: bool = True, authorization_code: Optional[str] = None, - scope: Optional[str] = None, ): # Initialize info self.provider_id = provider_id @@ -32,7 +31,6 @@ class MCPClient: # Authentication info self.authed = authed self.authorization_code = authorization_code - self.scope = scope if authed: from core.mcp.auth.auth_provider import OAuthClientProvider @@ -49,7 +47,7 @@ class MCPClient: self._initialized = False def __enter__(self): - self._initialize(first_try=True) + self._initialize() self._initialized = True return self @@ -58,7 +56,6 @@ class MCPClient: def _initialize( self, - first_try: bool = True, ): """Initialize the client with fallback to SSE if streamable connection fails""" connection_methods = {"mcp": streamablehttp_client, "sse": sse_client} @@ -71,9 +68,9 @@ class MCPClient: self.connect_server(client_factory, method_name) except KeyError: try: - self.connect_server(streamablehttp_client, "sse") + self.connect_server(sse_client, "sse") except MCPConnectionError: - self.connect_server(sse_client, "mcp") + self.connect_server(streamablehttp_client, "mcp") def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True): from core.mcp.auth.auth_flow import auth @@ -100,8 +97,8 @@ class MCPClient: except MCPAuthError: if not self.authed: raise - - auth(self.provider, self.server_url, self.authorization_code, self.scope) + auth(self.provider, self.server_url, self.authorization_code) + self.token = self.provider.tokens() if first_try: return self.connect_server(client_factory, method_name, first_try=False) @@ -134,5 +131,6 @@ class MCPClient: self._session = None self._initialized = False self.exit_stack.close() - except Exception: + except Exception as e: logging.exception("Error during cleanup") + raise ValueError(f"Error during cleanup: {e}") diff --git a/api/core/mcp/server/handler.py b/api/core/mcp/server/handler.py index 579e25862b..668ac77948 100644 --- a/api/core/mcp/server/handler.py +++ b/api/core/mcp/server/handler.py @@ -2,30 +2,31 @@ import json from collections.abc import Mapping from typing import cast -from configs.app_config import DifyConfig +from configs import dify_config from controllers.web.passport import generate_session_id +from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.entities.app_invoke_entities import InvokeFrom from core.mcp import types from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db -from models.model import App, EndUser +from models.model import App, AppMCPServer, EndUser from services.app_generate_service import AppGenerateService """ Apply to MCP HTTP streamable server with stateless http """ -dify_config = DifyConfig() class MCPServerReuqestHandler: - def __init__(self, app: App, request: types.ClientRequest): + def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity]): self.app = app self.request = request - if not self.app.mcp_server: + self.mcp_server: AppMCPServer = self.app.mcp_server + if not self.mcp_server: raise ValueError("MCP server not found") - self.mcp_server = self.app.mcp_server self.end_user = self.retrieve_end_user() + self.user_input_form = user_input_form @property def request_type(self): @@ -33,6 +34,7 @@ class MCPServerReuqestHandler: @property def parameter_schema(self): + parameters, required = self._convert_input_form_to_parameters(self.user_input_form) return { "type": "object", "properties": { @@ -41,10 +43,11 @@ class MCPServerReuqestHandler: "type": "object", "description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501 "default": {}, - # TODO: add input parameters + "properties": parameters, + "required": required, }, }, - "required": ["query"], + "required": "query", } @property @@ -152,3 +155,25 @@ class MCPServerReuqestHandler: .filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") .first() ) + + def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]): + parameters = {} + required = [] + for item in user_input_form: + if item.type in ( + VariableEntityType.FILE, + VariableEntityType.FILE_LIST, + VariableEntityType.EXTERNAL_DATA_TOOL, + ): + continue + if item.required: + required.append(item.variable) + parameters[item.variable]["description"] = self.mcp_server.parameters_dict[item.label]["description"] + if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): + parameters[item.variable]["type"] = "string" + elif item.type == VariableEntityType.SELECT: + parameters[item.variable]["type"] = "string" + parameters[item.variable]["enum"] = item.options + elif item.type == VariableEntityType.NUMBER: + parameters[item.variable]["type"] = "number" + return parameters, required diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 978f40be5b..445de9e2a3 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -1,7 +1,5 @@ import logging import queue -import threading -import time from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor from contextlib import ExitStack @@ -80,13 +78,10 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): self._completed = False self._on_complete = on_complete self._entered = False # Track if we're in a context manager - self._cancel_event = threading.Event() def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]": """Enter the context manager, enabling request cancellation tracking.""" self._entered = True - self._cancel_event = threading.Event() - self._cancel_event.clear() return self def __exit__( @@ -101,9 +96,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): self._on_complete(self) finally: self._entered = False - if not self._cancel_event: - raise RuntimeError("No active cancel scope") - self._cancel_event.set() def respond(self, response: SendResultT | ErrorData) -> None: """Send a response for this request. @@ -117,17 +109,15 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): raise RuntimeError("RequestResponder must be used as a context manager") assert not self._completed, "Request already responded to" - if not self.cancelled: - self._completed = True + self._completed = True - self._session._send_response(request_id=self.request_id, response=response) + self._session._send_response(request_id=self.request_id, response=response) def cancel(self) -> None: """Cancel this request and mark it as completed.""" if not self._entered: raise RuntimeError("RequestResponder must be used as a context manager") - self._cancel_event.set() self._completed = True # Mark as completed so it's removed from in_flight # Send an error response to indicate cancellation self._session._send_response( @@ -135,14 +125,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): response=ErrorData(code=0, message="Request cancelled", data=None), ) - @property - def in_flight(self) -> bool: - return not self._completed and not self.cancelled - - @property - def cancelled(self) -> bool: - return self._cancel_event.is_set() - class BaseSession( Generic[ @@ -184,11 +166,9 @@ class BaseSession( self._in_flight = {} self._exit_stack = ExitStack() self._futures = [] - self._request_id_lock = threading.Lock() def __enter__(self) -> Self: self._executor = ThreadPoolExecutor() - self._stop_event = threading.Event() self._receiver_future = self._executor.submit(self._receive_loop) return self @@ -196,21 +176,8 @@ class BaseSession( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: self._exit_stack.close() - self._stop_event.set() - self._wait_for_futures(timeout=5) - - def _wait_for_futures(self, timeout=None): - end_time = time.time() + timeout if timeout else None - - for future in list(self._futures): - try: - remaining = end_time - time.time() if end_time else None - if remaining is not None and remaining <= 0: - break - - future.result(timeout=remaining) - except Exception as e: - logging.exception(f"Error waiting for task: {e}") + self._read_stream.put(None) + self._write_stream.put(None) def send_request( self, @@ -247,7 +214,7 @@ class BaseSession( timeout = request_read_timeout_seconds.total_seconds() elif self._session_read_timeout_seconds is not None: timeout = self._session_read_timeout_seconds.total_seconds() - while not self._stop_event.is_set(): + while True: try: response_or_error = response_queue.get(timeout=timeout) break @@ -316,7 +283,7 @@ class BaseSession( Main message processing loop. In a real synchronous implementation, this would likely run in a separate thread. """ - while not self._stop_event.is_set(): + while True: try: # Attempt to receive a message (this would be blocking in a synchronous context) message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT) @@ -378,12 +345,9 @@ class BaseSession( else: self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}")) except queue.Empty: - if self._stop_event.is_set(): - break continue except Exception as e: logging.exception("Error in message processing loop") - self._stop_event.set() def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: """ diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index 6805f4a039..7c8827a77f 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -3,12 +3,11 @@ from typing import Any, Protocol from pydantic import AnyUrl, TypeAdapter -from configs.app_config import DifyConfig +from configs import dify_config from core.mcp import types from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext from core.mcp.session.base_session import BaseSession, RequestResponder -dify_config = DifyConfig() DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION) diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 6f2107c318..b8e2494f37 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -1177,7 +1177,6 @@ class SessionMessage: class OAuthClientMetadata(BaseModel): client_name: str redirect_uris: list[str] - scope: str grant_types: Optional[list[str]] = None response_types: Optional[list[str]] = None token_endpoint_auth_method: Optional[str] = None diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 1087ecb712..efc913c3b9 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Literal, Optional +from typing import Any, Literal, Optional from pydantic import BaseModel, Field, field_validator @@ -57,7 +57,7 @@ class ToolProviderApiEntity(BaseModel): if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value: parameter["type"] = "files" # ------------- - + optional_fields = self.optional_field("server_url", self.server_url) return { "id": self.id, "author": self.author, @@ -73,4 +73,9 @@ class ToolProviderApiEntity(BaseModel): "allow_delete": self.allow_delete, "tools": tools, "labels": self.labels, + **optional_fields, } + + def optional_field(self, key: str, value: Any) -> dict: + """Return dict with key-value if value is truthy, empty dict otherwise.""" + return {key: value} if value else {} diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 2248430743..a2c5e159c6 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -1,6 +1,7 @@ from collections.abc import Generator from typing import Any, Optional +from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.mcp_client import MCPClient from core.mcp.types import ImageContent, TextContent from core.plugin.utils.converter import convert_parameters_to_plugin_format @@ -37,9 +38,14 @@ class MCPTool(Tool): app_id: Optional[str] = None, message_id: Optional[str] = None, ) -> Generator[ToolInvokeMessage, None, None]: - with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client: - tool_parameters = convert_parameters_to_plugin_format(tool_parameters) - result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) + try: + with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client: + tool_parameters = convert_parameters_to_plugin_format(tool_parameters) + result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) + except MCPAuthError as e: + raise ValueError("Please auth the tool first") + except MCPConnectionError as e: + raise ValueError(f"Failed to connect to MCP server: {e}") for content in result.content: if isinstance(content, TextContent): yield self.create_text_message(content.text) diff --git a/api/models/model.py b/api/models/model.py index c8df44bfec..516ea5afc2 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1461,6 +1461,10 @@ class AppMCPServer(Base): return result + @property + def parameters_dict(self) -> dict[str, Any]: + return json.loads(self.parameters) + class Site(Base): __tablename__ = "sites" diff --git a/api/services/tools/mcp_tools_mange_service.py b/api/services/tools/mcp_tools_mange_service.py index 1bf4e8537e..a8d2a7425f 100644 --- a/api/services/tools/mcp_tools_mange_service.py +++ b/api/services/tools/mcp_tools_mange_service.py @@ -1,5 +1,6 @@ import json +from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.mcp_client import MCPClient from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject @@ -58,8 +59,13 @@ class MCPToolManageService: mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) if mcp_provider is None: raise ValueError("MCP tool not found") - with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client: - tools = mcp_client.list_tools() + try: + with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client: + tools = mcp_client.list_tools() + except MCPAuthError as e: + raise ValueError("Please auth the tool first") + except MCPConnectionError as e: + raise ValueError(f"Failed to connect to MCP server: {e}") mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) mcp_provider.authed = True db.session.commit()