diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index ad397603a4..3628992816 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -628,7 +628,7 @@ class ToolProviderMCPApi(Resource): parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("icon", type=str, required=True, nullable=False, location="json") parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") - parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json") + parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="") args = parser.parse_args() user = current_user return jsonable_encoder( @@ -652,7 +652,7 @@ class ToolProviderMCPApi(Resource): parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("icon", type=str, required=True, nullable=False, location="json") parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") - parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json") + parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return jsonable_encoder( @@ -704,8 +704,15 @@ class ToolMCPAuthApi(Resource): authed=False, authorization_code=args["authorization_code"], ): + MCPToolManageService.update_mcp_provider_credentials( + tenant_id=tenant_id, + provider_id=provider_id, + credentials={}, + authed=True, + ) return {"result": "success"} - except MCPAuthError as e: + + except MCPAuthError: auth_provider = OAuthClientProvider(provider_id, tenant_id) return auth(auth_provider, provider.server_url, args["authorization_code"]) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 11f245812e..d8f60c5c19 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -108,3 +108,85 @@ def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): return make_request("HEAD", url, max_retries=max_retries, **kwargs) + + +def create_ssrf_proxy_mcp_http_client( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, +) -> httpx.Client: + """Create an HTTPX client with SSRF proxy configuration for MCP connections. + + Args: + headers: Optional headers to include in the client + timeout: Optional timeout configuration + + Returns: + Configured httpx.Client with proxy settings + """ + client_kwargs = { + "verify": HTTP_REQUEST_NODE_SSL_VERIFY, + "headers": headers or {}, + "timeout": timeout, + } + + if dify_config.SSRF_PROXY_ALL_URL: + client_kwargs["proxy"] = dify_config.SSRF_PROXY_ALL_URL + return httpx.Client(**client_kwargs) + elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: + proxy_mounts = { + "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY), + "https://": httpx.HTTPTransport( + proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY + ), + } + client_kwargs["mounts"] = proxy_mounts + return httpx.Client(**client_kwargs) + else: + return httpx.Client(**client_kwargs) + + +def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): + """Connect to SSE endpoint with SSRF proxy protection. + + This function creates an SSE connection using the configured proxy settings + to prevent SSRF attacks when connecting to external endpoints. + + Args: + url: The SSE endpoint URL + max_retries: Maximum number of retry attempts + **kwargs: Additional arguments passed to the SSE connection + + Returns: + EventSource object for SSE streaming + """ + from httpx_sse import connect_sse + + # Extract client if provided, otherwise create one + client = kwargs.pop("client", None) + if client is None: + # Create client with SSRF proxy configuration + timeout = kwargs.pop( + "timeout", + httpx.Timeout( + timeout=dify_config.SSRF_DEFAULT_TIME_OUT, + connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT, + read=dify_config.SSRF_DEFAULT_READ_TIME_OUT, + write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT, + ), + ) + headers = kwargs.pop("headers", {}) + client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout) + client_provided = False + else: + client_provided = True + + # Extract method if provided, default to GET + method = kwargs.pop("method", "GET") + + try: + return connect_sse(client, method, url, **kwargs) + except Exception as e: + # If we created the client, we need to clean it up on error + if not client_provided: + client.close() + raise diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py index 3ca719c8a3..0ad3eb4976 100644 --- a/api/core/mcp/client/sse_client.py +++ b/api/core/mcp/client/sse_client.py @@ -8,15 +8,21 @@ from typing import Any from urllib.parse import urljoin, urlparse import httpx -from httpx_sse import connect_sse from sseclient import SSEClient +from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect from core.mcp import types +from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.types import SessionMessage -from core.mcp.utils import create_mcp_http_client, remove_request_params logger = logging.getLogger(__name__) +DEFAULT_QUEUE_READ_TIMEOUT = 3 + + +def remove_request_params(url: str) -> str: + return urljoin(url, urlparse(url).path) + @contextmanager def sse_client( @@ -40,9 +46,9 @@ def sse_client( with ThreadPoolExecutor() as executor: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - with create_mcp_http_client(headers=headers) as client: - with connect_sse( - client, "GET", url, timeout=httpx.Timeout(timeout, read=sse_read_timeout) + with create_ssrf_proxy_mcp_http_client(headers=headers) as client: + with ssrf_proxy_sse_connect( + url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client ) as event_source: event_source.response.raise_for_status() logger.debug("SSE connection established") @@ -94,7 +100,7 @@ def sse_client( try: while not cancel_event.is_set(): try: - message = write_queue.get(timeout=5) + message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) if message is None: break response = client.post( @@ -130,6 +136,10 @@ def sse_client( yield read_queue, write_queue finally: cancel_event.set() + except httpx.HTTPStatusError as exc: + if exc.response.status_code == 401: + raise MCPAuthError() + raise MCPConnectionError() except Exception as exc: logger.exception("Error connecting to SSE endpoint") raise exc diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py index 048123c4f2..1f46ee948f 100644 --- a/api/core/mcp/client/streamable_client.py +++ b/api/core/mcp/client/streamable_client.py @@ -17,8 +17,9 @@ from datetime import timedelta from typing import Any, cast import httpx -from httpx_sse import EventSource, ServerSentEvent, connect_sse +from httpx_sse import EventSource, ServerSentEvent +from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect from core.mcp.types import ( ClientMessageMetadata, ErrorData, @@ -30,7 +31,6 @@ from core.mcp.types import ( RequestId, SessionMessage, ) -from core.mcp.utils import create_mcp_http_client logger = logging.getLogger(__name__) @@ -50,6 +50,8 @@ ACCEPT = "Accept" JSON = "application/json" SSE = "text/event-stream" +DEFAULT_QUEUE_READ_TIMEOUT = 3 + class StreamableHTTPError(Exception): """Base exception for StreamableHTTP transport errors.""" @@ -184,12 +186,13 @@ class StreamableHTTPTransport: headers = self._update_headers_with_session(self.request_headers) - with connect_sse( - client, - "GET", + with ssrf_proxy_sse_connect( self.url, + 2, headers=headers, timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds), + client=client, + method="GET", ) as event_source: event_source.response.raise_for_status() logger.debug("GET SSE connection established") @@ -215,12 +218,13 @@ class StreamableHTTPTransport: if isinstance(ctx.session_message.message.root, JSONRPCRequest): original_request_id = ctx.session_message.message.root.id - with connect_sse( - ctx.client, - "GET", + with ssrf_proxy_sse_connect( self.url, + 2, headers=headers, timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds), + client=ctx.client, + method="GET", ) as event_source: event_source.response.raise_for_status() logger.debug("Resumption GET SSE connection established") @@ -304,7 +308,6 @@ class StreamableHTTPTransport: resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), ) except Exception as e: - logger.exception("Error reading SSE stream:") ctx.server_to_client_queue.put(e) def _handle_unexpected_content_type( @@ -346,7 +349,7 @@ class StreamableHTTPTransport: while not self.stop_event.is_set(): try: # Read message from client queue with timeout to check stop_event periodically - session_message = client_to_server_queue.get(timeout=5) + session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) if session_message is None: break @@ -444,7 +447,7 @@ def streamablehttp_client( try: logger.info(f"Connecting to StreamableHTTP endpoint: {url}") - with create_mcp_http_client( + with create_ssrf_proxy_mcp_http_client( headers=transport.request_headers, timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds), ) as client: diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index 61c6457f2b..56aa16f873 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -1,6 +1,8 @@ import logging +from collections.abc import Callable from contextlib import ExitStack from typing import Optional, cast +from urllib.parse import urlparse from core.mcp.client.sse_client import sse_client from core.mcp.client.streamable_client import streamablehttp_client @@ -59,42 +61,53 @@ class MCPClient: first_try: bool = True, ): """Initialize the client with fallback to SSE if streamable connection fails""" - connection_methods = [("streamablehttp_client", streamablehttp_client), ("sse_client", sse_client)] + connection_methods = {"mcp": streamablehttp_client, "sse": sse_client} + + parsed_url = urlparse(self.server_url) + path = parsed_url.path + method_name = path.rstrip("/").split("/")[-1] if path else "" + try: + client_factory = connection_methods[method_name] + self.connect_server(client_factory, method_name) + except KeyError: + try: + self.connect_server(streamablehttp_client, "sse") + except MCPConnectionError: + self.connect_server(sse_client, "mcp") + + def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True): from core.mcp.auth.auth_flow import auth - for method_name, client_factory in connection_methods: - try: - headers = ( - {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} - if self.authed and self.token - else {} - ) - self._streams_context = client_factory(url=self.server_url, headers=headers) - if method_name == "streamablehttp_client": - read_stream, write_stream, _ = self._streams_context.__enter__() - streams = (read_stream, write_stream) - else: # sse_client - streams = self._streams_context.__enter__() + try: + headers = ( + {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} + if self.authed and self.token + else {} + ) + self._streams_context = client_factory(url=self.server_url, headers=headers) + if method_name == "mcp": + read_stream, write_stream, _ = self._streams_context.__enter__() + streams = (read_stream, write_stream) + else: # sse_client + streams = self._streams_context.__enter__() - self._session_context = ClientSession(*streams) - self._session = self._session_context.__enter__() - session = cast(ClientSession, self._session) - session.initialize() - return + self._session_context = ClientSession(*streams) + self._session = self._session_context.__enter__() + session = cast(ClientSession, self._session) + session.initialize() + return - except MCPAuthError: - if not self.authed: - raise - - auth(self.provider, self.server_url, self.authorization_code, self.scope) - if first_try: - return self._initialize(first_try=False) - - except MCPConnectionError: - if method_name == "streamablehttp_client": - continue + except MCPAuthError: + if not self.authed: raise + auth(self.provider, self.server_url, self.authorization_code, self.scope) + if first_try: + return self.connect_server(client_factory, method_name, first_try=False) + + except MCPConnectionError: + raise + def list_tools(self) -> list[Tool]: """Connect to an MCP server running with SSE transport""" # List available tools to verify connection diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 6cf6df4789..978f40be5b 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -40,7 +40,7 @@ SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotif ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification) -DEFAULT_RESPONSE_READ_TIMEOUT = 5 +DEFAULT_RESPONSE_READ_TIMEOUT = 1 class RequestResponder(Generic[ReceiveRequestT, SendResultT]): @@ -210,7 +210,7 @@ class BaseSession( future.result(timeout=remaining) except Exception as e: - print(f"Error waiting for task: {e}") + logging.exception(f"Error waiting for task: {e}") def send_request( self, @@ -247,8 +247,12 @@ class BaseSession( timeout = request_read_timeout_seconds.total_seconds() elif self._session_read_timeout_seconds is not None: timeout = self._session_read_timeout_seconds.total_seconds() - - response_or_error = response_queue.get(timeout=timeout) + while not self._stop_event.is_set(): + try: + response_or_error = response_queue.get(timeout=timeout) + break + except queue.Empty: + continue if response_or_error is None: raise MCPConnectionError( @@ -315,7 +319,7 @@ class BaseSession( while not self._stop_event.is_set(): try: # Attempt to receive a message (this would be blocking in a synchronous context) - message = self._read_stream.get(timeout=5) + message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT) if message is None: break if isinstance(message, HTTPStatusError): diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py deleted file mode 100644 index c6e9dd21ac..0000000000 --- a/api/core/mcp/utils.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Any -from urllib.parse import urljoin, urlparse - -import httpx - - -def create_mcp_http_client( - headers: dict[str, str] | None = None, - timeout: httpx.Timeout | None = None, -) -> httpx.Client: - kwargs: dict[str, Any] = { - "follow_redirects": True, - } - - # Handle timeout - if timeout is None: - kwargs["timeout"] = httpx.Timeout(30.0) - else: - kwargs["timeout"] = timeout - - # Handle headers - if headers is not None: - kwargs["headers"] = headers - return httpx.Client(**kwargs) - - -def remove_request_params(url: str) -> str: - return urljoin(url, urlparse(url).path)