This commit is contained in:
Novice 2025-05-23 18:12:47 +08:00
parent 8464ad0b35
commit 1fd4839eca
7 changed files with 174 additions and 83 deletions

View File

@ -628,7 +628,7 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json") parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json") parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
args = parser.parse_args() args = parser.parse_args()
user = current_user user = current_user
return jsonable_encoder( return jsonable_encoder(
@ -652,7 +652,7 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json") parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json") parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
return jsonable_encoder( return jsonable_encoder(
@ -704,8 +704,15 @@ class ToolMCPAuthApi(Resource):
authed=False, authed=False,
authorization_code=args["authorization_code"], authorization_code=args["authorization_code"],
): ):
MCPToolManageService.update_mcp_provider_credentials(
tenant_id=tenant_id,
provider_id=provider_id,
credentials={},
authed=True,
)
return {"result": "success"} return {"result": "success"}
except MCPAuthError as e:
except MCPAuthError:
auth_provider = OAuthClientProvider(provider_id, tenant_id) auth_provider = OAuthClientProvider(provider_id, tenant_id)
return auth(auth_provider, provider.server_url, args["authorization_code"]) return auth(auth_provider, provider.server_url, args["authorization_code"])

View File

@ -108,3 +108,85 @@ def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
return make_request("HEAD", url, max_retries=max_retries, **kwargs) return make_request("HEAD", url, max_retries=max_retries, **kwargs)
def create_ssrf_proxy_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
) -> httpx.Client:
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
Args:
headers: Optional headers to include in the client
timeout: Optional timeout configuration
Returns:
Configured httpx.Client with proxy settings
"""
client_kwargs = {
"verify": HTTP_REQUEST_NODE_SSL_VERIFY,
"headers": headers or {},
"timeout": timeout,
}
if dify_config.SSRF_PROXY_ALL_URL:
client_kwargs["proxy"] = dify_config.SSRF_PROXY_ALL_URL
return httpx.Client(**client_kwargs)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
"https://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
),
}
client_kwargs["mounts"] = proxy_mounts
return httpx.Client(**client_kwargs)
else:
return httpx.Client(**client_kwargs)
def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
"""Connect to SSE endpoint with SSRF proxy protection.
This function creates an SSE connection using the configured proxy settings
to prevent SSRF attacks when connecting to external endpoints.
Args:
url: The SSE endpoint URL
max_retries: Maximum number of retry attempts
**kwargs: Additional arguments passed to the SSE connection
Returns:
EventSource object for SSE streaming
"""
from httpx_sse import connect_sse
# Extract client if provided, otherwise create one
client = kwargs.pop("client", None)
if client is None:
# Create client with SSRF proxy configuration
timeout = kwargs.pop(
"timeout",
httpx.Timeout(
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
),
)
headers = kwargs.pop("headers", {})
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
client_provided = False
else:
client_provided = True
# Extract method if provided, default to GET
method = kwargs.pop("method", "GET")
try:
return connect_sse(client, method, url, **kwargs)
except Exception as e:
# If we created the client, we need to clean it up on error
if not client_provided:
client.close()
raise

View File

@ -8,15 +8,21 @@ from typing import Any
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
import httpx import httpx
from httpx_sse import connect_sse
from sseclient import SSEClient from sseclient import SSEClient
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
from core.mcp import types from core.mcp import types
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.types import SessionMessage from core.mcp.types import SessionMessage
from core.mcp.utils import create_mcp_http_client, remove_request_params
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_QUEUE_READ_TIMEOUT = 3
def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)
@contextmanager @contextmanager
def sse_client( def sse_client(
@ -40,9 +46,9 @@ def sse_client(
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
try: try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
with create_mcp_http_client(headers=headers) as client: with create_ssrf_proxy_mcp_http_client(headers=headers) as client:
with connect_sse( with ssrf_proxy_sse_connect(
client, "GET", url, timeout=httpx.Timeout(timeout, read=sse_read_timeout) url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source: ) as event_source:
event_source.response.raise_for_status() event_source.response.raise_for_status()
logger.debug("SSE connection established") logger.debug("SSE connection established")
@ -94,7 +100,7 @@ def sse_client(
try: try:
while not cancel_event.is_set(): while not cancel_event.is_set():
try: try:
message = write_queue.get(timeout=5) message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if message is None: if message is None:
break break
response = client.post( response = client.post(
@ -130,6 +136,10 @@ def sse_client(
yield read_queue, write_queue yield read_queue, write_queue
finally: finally:
cancel_event.set() cancel_event.set()
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception as exc: except Exception as exc:
logger.exception("Error connecting to SSE endpoint") logger.exception("Error connecting to SSE endpoint")
raise exc raise exc

View File

@ -17,8 +17,9 @@ from datetime import timedelta
from typing import Any, cast from typing import Any, cast
import httpx import httpx
from httpx_sse import EventSource, ServerSentEvent, connect_sse from httpx_sse import EventSource, ServerSentEvent
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
from core.mcp.types import ( from core.mcp.types import (
ClientMessageMetadata, ClientMessageMetadata,
ErrorData, ErrorData,
@ -30,7 +31,6 @@ from core.mcp.types import (
RequestId, RequestId,
SessionMessage, SessionMessage,
) )
from core.mcp.utils import create_mcp_http_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -50,6 +50,8 @@ ACCEPT = "Accept"
JSON = "application/json" JSON = "application/json"
SSE = "text/event-stream" SSE = "text/event-stream"
DEFAULT_QUEUE_READ_TIMEOUT = 3
class StreamableHTTPError(Exception): class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors.""" """Base exception for StreamableHTTP transport errors."""
@ -184,12 +186,13 @@ class StreamableHTTPTransport:
headers = self._update_headers_with_session(self.request_headers) headers = self._update_headers_with_session(self.request_headers)
with connect_sse( with ssrf_proxy_sse_connect(
client,
"GET",
self.url, self.url,
2,
headers=headers, headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds), timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
client=client,
method="GET",
) as event_source: ) as event_source:
event_source.response.raise_for_status() event_source.response.raise_for_status()
logger.debug("GET SSE connection established") logger.debug("GET SSE connection established")
@ -215,12 +218,13 @@ class StreamableHTTPTransport:
if isinstance(ctx.session_message.message.root, JSONRPCRequest): if isinstance(ctx.session_message.message.root, JSONRPCRequest):
original_request_id = ctx.session_message.message.root.id original_request_id = ctx.session_message.message.root.id
with connect_sse( with ssrf_proxy_sse_connect(
ctx.client,
"GET",
self.url, self.url,
2,
headers=headers, headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds), timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
client=ctx.client,
method="GET",
) as event_source: ) as event_source:
event_source.response.raise_for_status() event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established") logger.debug("Resumption GET SSE connection established")
@ -304,7 +308,6 @@ class StreamableHTTPTransport:
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
) )
except Exception as e: except Exception as e:
logger.exception("Error reading SSE stream:")
ctx.server_to_client_queue.put(e) ctx.server_to_client_queue.put(e)
def _handle_unexpected_content_type( def _handle_unexpected_content_type(
@ -346,7 +349,7 @@ class StreamableHTTPTransport:
while not self.stop_event.is_set(): while not self.stop_event.is_set():
try: try:
# Read message from client queue with timeout to check stop_event periodically # Read message from client queue with timeout to check stop_event periodically
session_message = client_to_server_queue.get(timeout=5) session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if session_message is None: if session_message is None:
break break
@ -444,7 +447,7 @@ def streamablehttp_client(
try: try:
logger.info(f"Connecting to StreamableHTTP endpoint: {url}") logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
with create_mcp_http_client( with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers, headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds), timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
) as client: ) as client:

View File

@ -1,6 +1,8 @@
import logging import logging
from collections.abc import Callable
from contextlib import ExitStack from contextlib import ExitStack
from typing import Optional, cast from typing import Optional, cast
from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client from core.mcp.client.sse_client import sse_client
from core.mcp.client.streamable_client import streamablehttp_client from core.mcp.client.streamable_client import streamablehttp_client
@ -59,10 +61,23 @@ class MCPClient:
first_try: bool = True, first_try: bool = True,
): ):
"""Initialize the client with fallback to SSE if streamable connection fails""" """Initialize the client with fallback to SSE if streamable connection fails"""
connection_methods = [("streamablehttp_client", streamablehttp_client), ("sse_client", sse_client)] connection_methods = {"mcp": streamablehttp_client, "sse": sse_client}
parsed_url = urlparse(self.server_url)
path = parsed_url.path
method_name = path.rstrip("/").split("/")[-1] if path else ""
try:
client_factory = connection_methods[method_name]
self.connect_server(client_factory, method_name)
except KeyError:
try:
self.connect_server(streamablehttp_client, "sse")
except MCPConnectionError:
self.connect_server(sse_client, "mcp")
def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True):
from core.mcp.auth.auth_flow import auth from core.mcp.auth.auth_flow import auth
for method_name, client_factory in connection_methods:
try: try:
headers = ( headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
@ -70,7 +85,7 @@ class MCPClient:
else {} else {}
) )
self._streams_context = client_factory(url=self.server_url, headers=headers) self._streams_context = client_factory(url=self.server_url, headers=headers)
if method_name == "streamablehttp_client": if method_name == "mcp":
read_stream, write_stream, _ = self._streams_context.__enter__() read_stream, write_stream, _ = self._streams_context.__enter__()
streams = (read_stream, write_stream) streams = (read_stream, write_stream)
else: # sse_client else: # sse_client
@ -88,11 +103,9 @@ class MCPClient:
auth(self.provider, self.server_url, self.authorization_code, self.scope) auth(self.provider, self.server_url, self.authorization_code, self.scope)
if first_try: if first_try:
return self._initialize(first_try=False) return self.connect_server(client_factory, method_name, first_try=False)
except MCPConnectionError: except MCPConnectionError:
if method_name == "streamablehttp_client":
continue
raise raise
def list_tools(self) -> list[Tool]: def list_tools(self) -> list[Tool]:

View File

@ -40,7 +40,7 @@ SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotif
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification) ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
DEFAULT_RESPONSE_READ_TIMEOUT = 5 DEFAULT_RESPONSE_READ_TIMEOUT = 1
class RequestResponder(Generic[ReceiveRequestT, SendResultT]): class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
@ -210,7 +210,7 @@ class BaseSession(
future.result(timeout=remaining) future.result(timeout=remaining)
except Exception as e: except Exception as e:
print(f"Error waiting for task: {e}") logging.exception(f"Error waiting for task: {e}")
def send_request( def send_request(
self, self,
@ -247,8 +247,12 @@ class BaseSession(
timeout = request_read_timeout_seconds.total_seconds() timeout = request_read_timeout_seconds.total_seconds()
elif self._session_read_timeout_seconds is not None: elif self._session_read_timeout_seconds is not None:
timeout = self._session_read_timeout_seconds.total_seconds() timeout = self._session_read_timeout_seconds.total_seconds()
while not self._stop_event.is_set():
try:
response_or_error = response_queue.get(timeout=timeout) response_or_error = response_queue.get(timeout=timeout)
break
except queue.Empty:
continue
if response_or_error is None: if response_or_error is None:
raise MCPConnectionError( raise MCPConnectionError(
@ -315,7 +319,7 @@ class BaseSession(
while not self._stop_event.is_set(): while not self._stop_event.is_set():
try: try:
# Attempt to receive a message (this would be blocking in a synchronous context) # Attempt to receive a message (this would be blocking in a synchronous context)
message = self._read_stream.get(timeout=5) message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
if message is None: if message is None:
break break
if isinstance(message, HTTPStatusError): if isinstance(message, HTTPStatusError):

View File

@ -1,28 +0,0 @@
from typing import Any
from urllib.parse import urljoin, urlparse
import httpx
def create_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
) -> httpx.Client:
kwargs: dict[str, Any] = {
"follow_redirects": True,
}
# Handle timeout
if timeout is None:
kwargs["timeout"] = httpx.Timeout(30.0)
else:
kwargs["timeout"] = timeout
# Handle headers
if headers is not None:
kwargs["headers"] = headers
return httpx.Client(**kwargs)
def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)