diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index a7a276edb4..aada292a86 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -68,16 +68,31 @@ class AppMCPServerController(Resource): parser.add_argument("id", type=str, required=True, location="json") parser.add_argument("description", type=str, required=True, location="json") parser.add_argument("parameters", type=dict, required=True, location="json") - parser.add_argument("status", type=str, required=True, location="json") args = parser.parse_args() server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first() if not server: raise Forbidden() server.description = args["description"] server.parameters = json.dumps(args["parameters"], ensure_ascii=False) - server.status = AppMCPServerStatus(args["status"]) + db.session.commit() + return server + + +class AppMCPServerRefreshController(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(app_server_fields) + def get(self, server_id): + if not current_user.is_editor: + raise Forbidden() + server = db.session.query(AppMCPServer).filter(AppMCPServer.id == server_id).first() + if not server: + raise Forbidden() + server.server_code = AppMCPServer.generate_server_code(16) db.session.commit() return server api.add_resource(AppMCPServerController, "/apps//server") +api.add_resource(AppMCPServerRefreshController, "/apps//server/refresh") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index ad397603a4..3628992816 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -628,7 +628,7 @@ class ToolProviderMCPApi(Resource): parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("icon", type=str, required=True, 
nullable=False, location="json") parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") - parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json") + parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") args = parser.parse_args() user = current_user return jsonable_encoder( @@ -652,7 +652,7 @@ class ToolProviderMCPApi(Resource): parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("icon", type=str, required=True, nullable=False, location="json") parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") - parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json") + parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return jsonable_encoder( @@ -704,8 +704,15 @@ class ToolMCPAuthApi(Resource): authed=False, authorization_code=args["authorization_code"], ): + MCPToolManageService.update_mcp_provider_credentials( + tenant_id=tenant_id, + provider_id=provider_id, + credentials={}, + authed=True, + ) return {"result": "success"} - except MCPAuthError as e: + + except MCPAuthError: auth_provider = OAuthClientProvider(provider_id, tenant_id) return auth(auth_provider, provider.server_url, args["authorization_code"]) diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index c8645d5ebe..b893ec04d9 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -18,6 +18,7 @@ from controllers.web.error import ( ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.wraps import WebApiResource +from core.app.app_config.entities import VariableEntity from 
core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ( @@ -175,11 +176,30 @@ class ChatMCPApi(Resource): app = db.session.query(App).filter(App.id == server.app_id).first() if not app: raise NotFound("App Not Found") + if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: + workflow = app.workflow + if workflow is None: + raise AppUnavailableError() + + features_dict = workflow.features_dict + user_input_form = workflow.user_input_form(to_old_structure=True) + else: + app_model_config = app.app_model_config + if app_model_config is None: + raise AppUnavailableError() + + features_dict = app_model_config.to_dict() + + user_input_form = features_dict.get("user_input_form", []) + try: + user_input_form = [VariableEntity.model_validate(item) for item in user_input_form] + except ValidationError as e: + raise ValueError(f"Invalid user_input_form: {str(e)}") try: request = ClientRequest.model_validate(args) except ValidationError as e: raise ValueError(f"Invalid MCP request: {str(e)}") - mcp_server_handler = MCPServerReuqestHandler(app, request) + mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form) return helper.compact_generate_response(mcp_server_handler.handle()) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 11f245812e..51ddb343b3 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -108,3 +108,86 @@ def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): return make_request("HEAD", url, max_retries=max_retries, **kwargs) + + +def create_ssrf_proxy_mcp_http_client( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, +) -> httpx.Client: + """Create an HTTPX client with SSRF proxy configuration for MCP connections. 
+ + Args: + headers: Optional headers to include in the client + timeout: Optional timeout configuration + + Returns: + Configured httpx.Client with proxy settings + """ + client_kwargs = { + "verify": HTTP_REQUEST_NODE_SSL_VERIFY, + "headers": headers or {}, + "timeout": timeout, + "follow_redirects": True, # Enable redirect following for MCP connections + } + + if dify_config.SSRF_PROXY_ALL_URL: + client_kwargs["proxy"] = dify_config.SSRF_PROXY_ALL_URL + return httpx.Client(**client_kwargs) + elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: + proxy_mounts = { + "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY), + "https://": httpx.HTTPTransport( + proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY + ), + } + client_kwargs["mounts"] = proxy_mounts + return httpx.Client(**client_kwargs) + else: + return httpx.Client(**client_kwargs) + + +def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + """Connect to SSE endpoint with SSRF proxy protection. + + This function creates an SSE connection using the configured proxy settings + to prevent SSRF attacks when connecting to external endpoints. 
+ + Args: + url: The SSE endpoint URL + max_retries: Maximum number of retry attempts + **kwargs: Additional arguments passed to the SSE connection + + Returns: + EventSource object for SSE streaming + """ + from httpx_sse import connect_sse + + # Extract client if provided, otherwise create one + client = kwargs.pop("client", None) + if client is None: + # Create client with SSRF proxy configuration + timeout = kwargs.pop( + "timeout", + httpx.Timeout( + timeout=dify_config.SSRF_DEFAULT_TIME_OUT, + connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT, + read=dify_config.SSRF_DEFAULT_READ_TIME_OUT, + write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT, + ), + ) + headers = kwargs.pop("headers", {}) + client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout) + client_provided = False + else: + client_provided = True + + # Extract method if provided, default to GET + method = kwargs.pop("method", "GET") + + try: + return connect_sse(client, method, url, **kwargs) + except Exception as e: + # If we created the client, we need to clean it up on error + if not client_provided: + client.close() + raise diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index e40c0f47b6..b3e74ad11e 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -59,7 +59,6 @@ def start_authorization( metadata: Optional[OAuthMetadata], client_information: OAuthClientInformation, redirect_url: str, - scope: Optional[str] = None, ) -> tuple[str, str]: """Begins the authorization flow.""" response_type = "code" @@ -85,11 +84,9 @@ def start_authorization( "code_challenge": code_challenge, "code_challenge_method": code_challenge_method, "redirect_uri": redirect_url, + "state": "/tools?provider_id=" + client_information.client_id, } - if scope: - params["scope"] = scope - authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}" return authorization_url, code_verifier @@ -187,7 +184,6 @@ def auth( provider: 
OAuthClientProvider, server_url: str, authorization_code: Optional[str] = None, - scope: Optional[str] = None, ) -> dict[str, str]: """Orchestrates the full auth flow with a server.""" metadata = discover_oauth_metadata(server_url) @@ -233,7 +229,6 @@ def auth( metadata, client_information, provider.redirect_url, - scope or provider.client_metadata.scope, ) provider.save_code_verifier(code_verifier) diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py index e84124772f..3d14724845 100644 --- a/api/core/mcp/auth/auth_provider.py +++ b/api/core/mcp/auth/auth_provider.py @@ -1,6 +1,6 @@ from typing import Optional -from configs.app_config import DifyConfig +from configs import dify_config from core.mcp.types import ( OAuthClientInformation, OAuthClientInformationFull, @@ -11,8 +11,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService LATEST_PROTOCOL_VERSION = "1.0" -dify_config = DifyConfig() - class OAuthClientProvider: provider_id: str @@ -25,7 +23,7 @@ class OAuthClientProvider: @property def redirect_url(self) -> str: """The URL to redirect the user agent to after authorization.""" - return dify_config.CONSOLE_WEB_URL + return dify_config.CONSOLE_WEB_URL + "/tools" @property def client_metadata(self) -> OAuthClientMetadata: @@ -37,7 +35,6 @@ class OAuthClientProvider: response_types=["code"], client_name="Dify", client_uri="https://github.com/langgenius/dify", - scope="read write", ) def client_information(self) -> Optional[OAuthClientInformation]: @@ -91,7 +88,3 @@ class OAuthClientProvider: if not mcp_provider: return "" return mcp_provider.credentials.get("code_verifier", "") - - -class UnauthorizedError(Exception): - pass diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index 3ca719c8a3..955a695c29 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -1,6 +1,5 @@ import logging import queue -import threading from collections.abc import 
Generator from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager @@ -8,15 +7,21 @@ from typing import Any from urllib.parse import urljoin, urlparse import httpx -from httpx_sse import connect_sse from sseclient import SSEClient +from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect from core.mcp import types +from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.types import SessionMessage -from core.mcp.utils import create_mcp_http_client, remove_request_params logger = logging.getLogger(__name__) +DEFAULT_QUEUE_READ_TIMEOUT = 3 + + +def remove_request_params(url: str) -> str: + return urljoin(url, urlparse(url).path) + @contextmanager def sse_client( @@ -36,65 +41,60 @@ def sse_client( read_queue = queue.Queue() write_queue = queue.Queue() status_queue = queue.Queue() - cancel_event = threading.Event() + with ThreadPoolExecutor() as executor: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - with create_mcp_http_client(headers=headers) as client: - with connect_sse( - client, "GET", url, timeout=httpx.Timeout(timeout, read=sse_read_timeout) + with create_ssrf_proxy_mcp_http_client(headers=headers) as client: + with ssrf_proxy_sse_connect( + url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client ) as event_source: event_source.response.raise_for_status() - logger.debug("SSE connection established") def sse_reader(status_queue: queue.Queue): try: - while not cancel_event.is_set(): - for sse in event_source.iter_sse(): - if cancel_event.is_set(): - break - match sse.event: - case "endpoint": - endpoint_url = urljoin(url, sse.data) - logger.info(f"Received endpoint URL: {endpoint_url}") - url_parsed = urlparse(url) - endpoint_parsed = urlparse(endpoint_url) - - if ( - url_parsed.netloc != endpoint_parsed.netloc - or url_parsed.scheme != endpoint_parsed.scheme - ): - error_msg = ( - f"Endpoint origin does not match connection 
origin: {endpoint_url}" - ) - logger.error(error_msg) - raise ValueError(error_msg) - status_queue.put(("ready", endpoint_url)) - case "message": - try: - message = types.JSONRPCMessage.model_validate_json(sse.data) - logger.debug(f"Received server message: {message}") - except Exception as exc: - logger.exception("Error parsing server message") - read_queue.put(exc) - continue - session_message = SessionMessage(message) - read_queue.put(session_message) - case _: - logger.warning(f"Unknown SSE event: {sse.event}") + for sse in event_source.iter_sse(): + match sse.event: + case "endpoint": + endpoint_url = urljoin(url, sse.data) + logger.info(f"Received endpoint URL: {endpoint_url}") + url_parsed = urlparse(url) + endpoint_parsed = urlparse(endpoint_url) + if ( + url_parsed.netloc != endpoint_parsed.netloc + or url_parsed.scheme != endpoint_parsed.scheme + ): + error_msg = ( + f"Endpoint origin does not match connection origin: {endpoint_url}" + ) + logger.error(error_msg) + status_queue.put(("error", ValueError(error_msg))) + status_queue.put(("ready", endpoint_url)) + case "message": + try: + message = types.JSONRPCMessage.model_validate_json(sse.data) + logger.debug(f"Received server message: {message}") + except Exception as exc: + logger.exception("Error parsing server message") + read_queue.put(exc) + continue + session_message = SessionMessage(message) + read_queue.put(session_message) + case _: + logger.warning(f"Unknown SSE event: {sse.event}") + except httpx.ReadError as exc: + logger.debug(f"SSE reader shutting down normally: {exc}") except Exception as exc: - if not cancel_event.is_set(): - logger.exception("Error reading SSE messages") - read_queue.put(exc) + read_queue.put(exc) finally: read_queue.put(None) def post_writer(endpoint_url: str): try: - while not cancel_event.is_set(): + while True: try: - message = write_queue.get(timeout=5) + message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) if message is None: break response = client.post( 
@@ -107,14 +107,13 @@ def sse_client( ) response.raise_for_status() logger.debug(f"Client message sent successfully: {response.status_code}") - if cancel_event.is_set(): - break except queue.Empty: - if cancel_event.is_set(): - break continue - except Exception: + except httpx.ReadError as exc: + logger.debug(f"Post writer shutting down normally: {exc}") + except Exception as exc: logger.exception("Error writing messages") + write_queue.put(exc) finally: write_queue.put(None) @@ -125,11 +124,17 @@ def sse_client( raise ValueError("failed to get endpoint URL") + # An "error" status carries the exception itself; re-raise it before the generic not-ready guard. + if status == "error": + raise endpoint_url if status != "ready": raise ValueError("failed to get endpoint URL") executor.submit(post_writer, endpoint_url) - try: - yield read_queue, write_queue - finally: - cancel_event.set() + + yield read_queue, write_queue + + except httpx.HTTPStatusError as exc: + if exc.response.status_code == 401: + raise MCPAuthError() + raise MCPConnectionError() except Exception as exc: logger.exception("Error connecting to SSE endpoint") raise exc diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py index 048123c4f2..2e5b8cfcac 100644 --- a/api/core/mcp/client/streamable_client.py +++ b/api/core/mcp/client/streamable_client.py @@ -8,7 +8,6 @@ and session management.
import logging import queue -import threading from collections.abc import Callable, Generator from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager @@ -17,8 +16,9 @@ from datetime import timedelta from typing import Any, cast import httpx -from httpx_sse import EventSource, ServerSentEvent, connect_sse +from httpx_sse import EventSource, ServerSentEvent +from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect from core.mcp.types import ( ClientMessageMetadata, ErrorData, @@ -30,7 +30,6 @@ from core.mcp.types import ( RequestId, SessionMessage, ) -from core.mcp.utils import create_mcp_http_client logger = logging.getLogger(__name__) @@ -50,6 +49,8 @@ ACCEPT = "Accept" JSON = "application/json" SSE = "text/event-stream" +DEFAULT_QUEUE_READ_TIMEOUT = 3 + class StreamableHTTPError(Exception): """Base exception for StreamableHTTP transport errors.""" @@ -104,11 +105,6 @@ class StreamableHTTPTransport: CONTENT_TYPE: JSON, **self.headers, } - self.stop_event = threading.Event() - - def stop(self): - """Signal to stop all operations.""" - self.stop_event.set() def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]: """Update headers with session ID if available.""" @@ -168,6 +164,9 @@ class StreamableHTTPTransport: # Put exception in queue that goes to client server_to_client_queue.put(exc) return False + elif sse.event == "ping": + logger.debug("Received ping event") + return False else: logger.warning(f"Unknown SSE event: {sse.event}") return False @@ -184,19 +183,18 @@ class StreamableHTTPTransport: headers = self._update_headers_with_session(self.request_headers) - with connect_sse( - client, - "GET", + with ssrf_proxy_sse_connect( self.url, + 2, headers=headers, timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds), + client=client, + method="GET", ) as event_source: event_source.response.raise_for_status() logger.debug("GET SSE connection 
established") for sse in event_source.iter_sse(): - if self.stop_event.is_set(): - break self._handle_sse_event(sse, server_to_client_queue) except Exception as exc: @@ -215,19 +213,18 @@ class StreamableHTTPTransport: if isinstance(ctx.session_message.message.root, JSONRPCRequest): original_request_id = ctx.session_message.message.root.id - with connect_sse( - ctx.client, - "GET", + with ssrf_proxy_sse_connect( self.url, + 2, headers=headers, timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds), + client=ctx.client, + method="GET", ) as event_source: event_source.response.raise_for_status() logger.debug("Resumption GET SSE connection established") for sse in event_source.iter_sse(): - if self.stop_event.is_set(): - break is_complete = self._handle_sse_event( sse, ctx.server_to_client_queue, @@ -296,15 +293,14 @@ class StreamableHTTPTransport: try: event_source = EventSource(response) for sse in event_source.iter_sse(): - if self.stop_event.is_set(): - break - self._handle_sse_event( + is_complete = self._handle_sse_event( sse, ctx.server_to_client_queue, resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), ) + if is_complete: + break except Exception as e: - logger.exception("Error reading SSE stream:") ctx.server_to_client_queue.put(e) def _handle_unexpected_content_type( @@ -343,10 +339,10 @@ class StreamableHTTPTransport: This method processes messages from the client_to_server_queue and sends them to the server. Responses are written to the server_to_client_queue. 
""" - while not self.stop_event.is_set(): + while True: try: # Read message from client queue with timeout to check stop_event periodically - session_message = client_to_server_queue.get(timeout=5) + session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) if session_message is None: break @@ -379,10 +375,8 @@ class StreamableHTTPTransport: else: self._handle_post_request(ctx) except queue.Empty: - # Timeout - continue loop to check stop_event continue except Exception as exc: - # Send exception to client server_to_client_queue.put(exc) def terminate_session(self, client: httpx.Client) -> None: @@ -444,7 +438,7 @@ def streamablehttp_client( try: logger.info(f"Connecting to StreamableHTTP endpoint: {url}") - with create_mcp_http_client( + with create_ssrf_proxy_mcp_http_client( headers=transport.request_headers, timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds), ) as client: @@ -475,9 +469,6 @@ def streamablehttp_client( # Signal threads to stop client_to_server_queue.put(None) finally: - # Clean up - transport.stop() - # Clear any remaining items and add None sentinel to unblock any waiting threads try: while not client_to_server_queue.empty(): diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index 61c6457f2b..1c460b0721 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -1,6 +1,8 @@ import logging +from collections.abc import Callable from contextlib import ExitStack from typing import Optional, cast +from urllib.parse import urlparse from core.mcp.client.sse_client import sse_client from core.mcp.client.streamable_client import streamablehttp_client @@ -19,7 +21,6 @@ class MCPClient: tenant_id: str, authed: bool = True, authorization_code: Optional[str] = None, - scope: Optional[str] = None, ): # Initialize info self.provider_id = provider_id @@ -30,7 +31,6 @@ class MCPClient: # Authentication info self.authed = authed self.authorization_code = 
authorization_code - self.scope = scope if authed: from core.mcp.auth.auth_provider import OAuthClientProvider @@ -47,7 +47,7 @@ class MCPClient: self._initialized = False def __enter__(self): - self._initialize(first_try=True) + self._initialize() self._initialized = True return self @@ -56,44 +56,54 @@ class MCPClient: def _initialize( self, - first_try: bool = True, ): - """Initialize the client with fallback to SSE if streamable connection fails""" + """Pick the transport from the URL's last path segment ("mcp" or "sse"); otherwise try SSE, then streamable HTTP.""" - connection_methods = [("streamablehttp_client", streamablehttp_client), ("sse_client", sse_client)] + connection_methods = {"mcp": streamablehttp_client, "sse": sse_client} + + parsed_url = urlparse(self.server_url) + path = parsed_url.path + method_name = path.rstrip("/").split("/")[-1] if path else "" + client_factory = connection_methods.get(method_name) + if client_factory is not None: + self.connect_server(client_factory, method_name) + else: + try: + self.connect_server(sse_client, "sse") + except MCPConnectionError: + self.connect_server(streamablehttp_client, "mcp") + + def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True): from core.mcp.auth.auth_flow import auth - for method_name, client_factory in connection_methods: - try: - headers = ( - {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} - if self.authed and self.token - else {} - ) - self._streams_context = client_factory(url=self.server_url, headers=headers) - if method_name == "streamablehttp_client": - read_stream, write_stream, _ = self._streams_context.__enter__() - streams = (read_stream, write_stream) - else: # sse_client - streams = self._streams_context.__enter__() + try: + headers = ( + {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} + if self.authed and self.token + else {} + ) + self._streams_context = client_factory(url=self.server_url, headers=headers) + if method_name == "mcp": + read_stream, write_stream, _ = self._streams_context.__enter__() +
streams = (read_stream, write_stream) + else: # sse_client + streams = self._streams_context.__enter__() - self._session_context = ClientSession(*streams) - self._session = self._session_context.__enter__() - session = cast(ClientSession, self._session) - session.initialize() - return + self._session_context = ClientSession(*streams) + self._session = self._session_context.__enter__() + session = cast(ClientSession, self._session) + session.initialize() + return - except MCPAuthError: - if not self.authed: - raise - - auth(self.provider, self.server_url, self.authorization_code, self.scope) - if first_try: - return self._initialize(first_try=False) - - except MCPConnectionError: - if method_name == "streamablehttp_client": - continue + except MCPAuthError: + if not self.authed: raise + auth(self.provider, self.server_url, self.authorization_code) + self.token = self.provider.tokens() + if first_try: + return self.connect_server(client_factory, method_name, first_try=False) + + except MCPConnectionError: + raise def list_tools(self) -> list[Tool]: """Connect to an MCP server running with SSE transport""" @@ -121,5 +131,6 @@ class MCPClient: self._session = None self._initialized = False self.exit_stack.close() - except Exception: + except Exception as e: logging.exception("Error during cleanup") + raise ValueError(f"Error during cleanup: {e}") diff --git a/api/core/mcp/server/handler.py b/api/core/mcp/server/handler.py index 579e25862b..668ac77948 100644 --- a/api/core/mcp/server/handler.py +++ b/api/core/mcp/server/handler.py @@ -2,30 +2,31 @@ import json from collections.abc import Mapping from typing import cast -from configs.app_config import DifyConfig +from configs import dify_config from controllers.web.passport import generate_session_id +from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.entities.app_invoke_entities import InvokeFrom from core.mcp import types from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, 
METHOD_NOT_FOUND from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db -from models.model import App, EndUser +from models.model import App, AppMCPServer, EndUser from services.app_generate_service import AppGenerateService """ Apply to MCP HTTP streamable server with stateless http """ -dify_config = DifyConfig() class MCPServerReuqestHandler: - def __init__(self, app: App, request: types.ClientRequest): + def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity]): self.app = app self.request = request - if not self.app.mcp_server: + self.mcp_server: AppMCPServer = self.app.mcp_server + if not self.mcp_server: raise ValueError("MCP server not found") - self.mcp_server = self.app.mcp_server self.end_user = self.retrieve_end_user() + self.user_input_form = user_input_form @property def request_type(self): @@ -33,6 +34,7 @@ class MCPServerReuqestHandler: @property def parameter_schema(self): + parameters, required = self._convert_input_form_to_parameters(self.user_input_form) return { "type": "object", "properties": { @@ -41,10 +43,11 @@ class MCPServerReuqestHandler: "type": "object", "description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. 
If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501 "default": {}, - # TODO: add input parameters + "properties": parameters, + "required": required, }, }, "required": ["query"], } @property @@ -152,3 +155,27 @@ class MCPServerReuqestHandler: .filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") .first() ) + + def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]): + """Build the JSON-schema properties and required-name list for `inputs` from the app's input form.""" + parameters = {} + required = [] + for item in user_input_form: + if item.type in ( + VariableEntityType.FILE, + VariableEntityType.FILE_LIST, + VariableEntityType.EXTERNAL_DATA_TOOL, + ): + continue + if item.required: + required.append(item.variable) + # Create the per-variable entry first; the type/enum keys below are added to it. + parameters[item.variable] = {"description": self.mcp_server.parameters_dict[item.label]["description"]} + if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): + parameters[item.variable]["type"] = "string" + elif item.type == VariableEntityType.SELECT: + parameters[item.variable]["type"] = "string" + parameters[item.variable]["enum"] = item.options + elif item.type == VariableEntityType.NUMBER: + parameters[item.variable]["type"] = "number" + return parameters, required diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 6cf6df4789..445de9e2a3 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -1,7 +1,5 @@ import logging import queue -import threading -import time from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor from contextlib import ExitStack @@ -40,7 +38,7 @@ SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotif ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
-DEFAULT_RESPONSE_READ_TIMEOUT = 5 +DEFAULT_RESPONSE_READ_TIMEOUT = 1 class RequestResponder(Generic[ReceiveRequestT, SendResultT]): @@ -80,13 +78,10 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): self._completed = False self._on_complete = on_complete self._entered = False # Track if we're in a context manager - self._cancel_event = threading.Event() def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]": """Enter the context manager, enabling request cancellation tracking.""" self._entered = True - self._cancel_event = threading.Event() - self._cancel_event.clear() return self def __exit__( @@ -101,9 +96,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): self._on_complete(self) finally: self._entered = False - if not self._cancel_event: - raise RuntimeError("No active cancel scope") - self._cancel_event.set() def respond(self, response: SendResultT | ErrorData) -> None: """Send a response for this request. @@ -117,17 +109,15 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): raise RuntimeError("RequestResponder must be used as a context manager") assert not self._completed, "Request already responded to" - if not self.cancelled: - self._completed = True + self._completed = True - self._session._send_response(request_id=self.request_id, response=response) + self._session._send_response(request_id=self.request_id, response=response) def cancel(self) -> None: """Cancel this request and mark it as completed.""" if not self._entered: raise RuntimeError("RequestResponder must be used as a context manager") - self._cancel_event.set() self._completed = True # Mark as completed so it's removed from in_flight # Send an error response to indicate cancellation self._session._send_response( @@ -135,14 +125,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): response=ErrorData(code=0, message="Request cancelled", data=None), ) - @property - def in_flight(self) -> bool: - return not 
self._completed and not self.cancelled - - @property - def cancelled(self) -> bool: - return self._cancel_event.is_set() - class BaseSession( Generic[ @@ -184,11 +166,9 @@ class BaseSession( self._in_flight = {} self._exit_stack = ExitStack() self._futures = [] - self._request_id_lock = threading.Lock() def __enter__(self) -> Self: self._executor = ThreadPoolExecutor() - self._stop_event = threading.Event() self._receiver_future = self._executor.submit(self._receive_loop) return self @@ -196,21 +176,8 @@ class BaseSession( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: self._exit_stack.close() - self._stop_event.set() - self._wait_for_futures(timeout=5) - - def _wait_for_futures(self, timeout=None): - end_time = time.time() + timeout if timeout else None - - for future in list(self._futures): - try: - remaining = end_time - time.time() if end_time else None - if remaining is not None and remaining <= 0: - break - - future.result(timeout=remaining) - except Exception as e: - print(f"Error waiting for task: {e}") + self._read_stream.put(None) + self._write_stream.put(None) def send_request( self, @@ -247,8 +214,12 @@ class BaseSession( timeout = request_read_timeout_seconds.total_seconds() elif self._session_read_timeout_seconds is not None: timeout = self._session_read_timeout_seconds.total_seconds() - - response_or_error = response_queue.get(timeout=timeout) + while True: + try: + response_or_error = response_queue.get(timeout=timeout) + break + except queue.Empty: + continue if response_or_error is None: raise MCPConnectionError( @@ -312,10 +283,10 @@ class BaseSession( Main message processing loop. In a real synchronous implementation, this would likely run in a separate thread. 
""" - while not self._stop_event.is_set(): + while True: try: # Attempt to receive a message (this would be blocking in a synchronous context) - message = self._read_stream.get(timeout=5) + message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT) if message is None: break if isinstance(message, HTTPStatusError): @@ -374,12 +345,9 @@ class BaseSession( else: self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}")) except queue.Empty: - if self._stop_event.is_set(): - break continue except Exception as e: logging.exception("Error in message processing loop") - self._stop_event.set() def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: """ diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index 6805f4a039..7c8827a77f 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -3,12 +3,11 @@ from typing import Any, Protocol from pydantic import AnyUrl, TypeAdapter -from configs.app_config import DifyConfig +from configs import dify_config from core.mcp import types from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext from core.mcp.session.base_session import BaseSession, RequestResponder -dify_config = DifyConfig() DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION) diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py index 6f2107c318..b8e2494f37 100644 --- a/api/core/mcp/types.py +++ b/api/core/mcp/types.py @@ -1177,7 +1177,6 @@ class SessionMessage: class OAuthClientMetadata(BaseModel): client_name: str redirect_uris: list[str] - scope: str grant_types: Optional[list[str]] = None response_types: Optional[list[str]] = None token_endpoint_auth_method: Optional[str] = None diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py deleted file mode 100644 index c6e9dd21ac..0000000000 --- a/api/core/mcp/utils.py +++ /dev/null @@ 
-1,28 +0,0 @@ -from typing import Any -from urllib.parse import urljoin, urlparse - -import httpx - - -def create_mcp_http_client( - headers: dict[str, str] | None = None, - timeout: httpx.Timeout | None = None, -) -> httpx.Client: - kwargs: dict[str, Any] = { - "follow_redirects": True, - } - - # Handle timeout - if timeout is None: - kwargs["timeout"] = httpx.Timeout(30.0) - else: - kwargs["timeout"] = timeout - - # Handle headers - if headers is not None: - kwargs["headers"] = headers - return httpx.Client(**kwargs) - - -def remove_request_params(url: str) -> str: - return urljoin(url, urlparse(url).path) diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 1087ecb712..efc913c3b9 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Literal, Optional +from typing import Any, Literal, Optional from pydantic import BaseModel, Field, field_validator @@ -57,7 +57,7 @@ class ToolProviderApiEntity(BaseModel): if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value: parameter["type"] = "files" # ------------- - + optional_fields = self.optional_field("server_url", self.server_url) return { "id": self.id, "author": self.author, @@ -73,4 +73,9 @@ class ToolProviderApiEntity(BaseModel): "allow_delete": self.allow_delete, "tools": tools, "labels": self.labels, + **optional_fields, } + + def optional_field(self, key: str, value: Any) -> dict: + """Return dict with key-value if value is truthy, empty dict otherwise.""" + return {key: value} if value else {} diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 2248430743..a2c5e159c6 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -1,6 +1,7 @@ from collections.abc import Generator from typing import Any, Optional +from core.mcp.error import MCPAuthError, MCPConnectionError from 
core.mcp.mcp_client import MCPClient from core.mcp.types import ImageContent, TextContent from core.plugin.utils.converter import convert_parameters_to_plugin_format @@ -37,9 +38,14 @@ class MCPTool(Tool): app_id: Optional[str] = None, message_id: Optional[str] = None, ) -> Generator[ToolInvokeMessage, None, None]: - with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client: - tool_parameters = convert_parameters_to_plugin_format(tool_parameters) - result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) + try: + with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client: + tool_parameters = convert_parameters_to_plugin_format(tool_parameters) + result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) + except MCPAuthError as e: + raise ValueError("Please auth the tool first") + except MCPConnectionError as e: + raise ValueError(f"Failed to connect to MCP server: {e}") for content in result.content: if isinstance(content, TextContent): yield self.create_text_message(content.text) diff --git a/api/models/model.py b/api/models/model.py index 43ebc41b45..4b412c6105 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1471,6 +1471,10 @@ class AppMCPServer(Base): return result + @property + def parameters_dict(self) -> dict[str, Any]: + return json.loads(self.parameters) + class Site(Base): __tablename__ = "sites" diff --git a/api/services/tools/mcp_tools_mange_service.py b/api/services/tools/mcp_tools_mange_service.py index 1bf4e8537e..a8d2a7425f 100644 --- a/api/services/tools/mcp_tools_mange_service.py +++ b/api/services/tools/mcp_tools_mange_service.py @@ -1,5 +1,6 @@ import json +from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.mcp_client import MCPClient from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject @@ 
-58,8 +59,13 @@ class MCPToolManageService: mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) if mcp_provider is None: raise ValueError("MCP tool not found") - with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client: - tools = mcp_client.list_tools() + try: + with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client: + tools = mcp_client.list_tools() + except MCPAuthError as e: + raise ValueError("Please auth the tool first") + except MCPConnectionError as e: + raise ValueError(f"Failed to connect to MCP server: {e}") mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) mcp_provider.authed = True db.session.commit()