This commit is contained in:
Novice 2025-05-23 18:12:47 +08:00
parent 8464ad0b35
commit 1fd4839eca
7 changed files with 174 additions and 83 deletions

View File

@ -628,7 +628,7 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
args = parser.parse_args()
user = current_user
return jsonable_encoder(
@ -652,7 +652,7 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return jsonable_encoder(
@ -704,8 +704,15 @@ class ToolMCPAuthApi(Resource):
authed=False,
authorization_code=args["authorization_code"],
):
MCPToolManageService.update_mcp_provider_credentials(
tenant_id=tenant_id,
provider_id=provider_id,
credentials={},
authed=True,
)
return {"result": "success"}
except MCPAuthError as e:
except MCPAuthError:
auth_provider = OAuthClientProvider(provider_id, tenant_id)
return auth(auth_provider, provider.server_url, args["authorization_code"])

View File

@ -108,3 +108,85 @@ def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
return make_request("HEAD", url, max_retries=max_retries, **kwargs)
def create_ssrf_proxy_mcp_http_client(
    headers: dict[str, str] | None = None,
    timeout: httpx.Timeout | None = None,
) -> httpx.Client:
    """Create an HTTPX client with SSRF proxy configuration for MCP connections.

    Args:
        headers: Optional headers to include in the client.
        timeout: Optional timeout configuration. When omitted, the SSRF
            default timeouts from ``dify_config`` are used.

    Returns:
        Configured httpx.Client with proxy settings.
    """
    if timeout is None:
        # httpx interprets an explicit ``timeout=None`` as "never time out",
        # so forwarding the unset argument would let any request made with
        # this client block indefinitely. Fall back to the SSRF defaults.
        timeout = httpx.Timeout(
            timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
            connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
            read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
            write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
        )

    client_kwargs = {
        "verify": HTTP_REQUEST_NODE_SSL_VERIFY,
        "headers": headers or {},
        "timeout": timeout,
    }

    if dify_config.SSRF_PROXY_ALL_URL:
        # A single proxy for every scheme takes precedence.
        client_kwargs["proxy"] = dify_config.SSRF_PROXY_ALL_URL
    elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
        # Per-scheme proxies are only used when both are configured.
        client_kwargs["mounts"] = {
            "http://": httpx.HTTPTransport(
                proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
            ),
            "https://": httpx.HTTPTransport(
                proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
            ),
        }

    return httpx.Client(**client_kwargs)
def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
    """Connect to an SSE endpoint with SSRF proxy protection.

    Args:
        url: The SSE endpoint URL.
        max_retries: Accepted so the signature matches the other request
            helpers in this module; this function does not retry.
        **kwargs: ``client`` (an existing ``httpx.Client`` to reuse) and
            ``method`` (defaults to ``"GET"``) are consumed here. When no
            client is supplied, ``timeout`` and ``headers`` are also consumed
            to build one. All remaining arguments are forwarded to
            ``httpx_sse.connect_sse``.

    Returns:
        A context manager that yields the event source for the stream. Use it
        as ``with ssrf_proxy_sse_connect(...) as event_source:``.
    """
    from contextlib import contextmanager

    from httpx_sse import connect_sse

    client = kwargs.pop("client", None)
    method = kwargs.pop("method", "GET")

    if client is not None:
        # The caller owns this client and is responsible for closing it.
        return connect_sse(client, method, url, **kwargs)

    # No client supplied: build one with the SSRF proxy configuration.
    timeout = kwargs.pop(
        "timeout",
        httpx.Timeout(
            timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
            connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
            read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
            write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
        ),
    )
    headers = kwargs.pop("headers", {})
    owned_client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)

    @contextmanager
    def _connect_with_owned_client():
        # connect_sse only opens the stream when its context is entered, so
        # the cleanup has to wrap the ``with`` itself: a try/except around the
        # bare call would never observe a connection error, and the client
        # would never be closed after a successful stream either.
        try:
            with connect_sse(owned_client, method, url, **kwargs) as event_source:
                yield event_source
        finally:
            owned_client.close()

    return _connect_with_owned_client()

View File

@ -8,15 +8,21 @@ from typing import Any
from urllib.parse import urljoin, urlparse
import httpx
from httpx_sse import connect_sse
from sseclient import SSEClient
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
from core.mcp import types
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.types import SessionMessage
from core.mcp.utils import create_mcp_http_client, remove_request_params
logger = logging.getLogger(__name__)
DEFAULT_QUEUE_READ_TIMEOUT = 3
def remove_request_params(url: str) -> str:
    """Return *url* without its path parameters, query string and fragment.

    Used to keep request parameters (which may carry credentials) out of log
    output. The URL is rebuilt from its parsed components instead of joining
    the path back onto the original URL: joining an empty path returns the
    original URL unchanged (query included), and a path starting with ``//``
    would be treated as a new host.
    """
    return urlparse(url)._replace(params="", query="", fragment="").geturl()
@contextmanager
def sse_client(
@ -40,9 +46,9 @@ def sse_client(
with ThreadPoolExecutor() as executor:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
with create_mcp_http_client(headers=headers) as client:
with connect_sse(
client, "GET", url, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
with create_ssrf_proxy_mcp_http_client(headers=headers) as client:
with ssrf_proxy_sse_connect(
url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")
@ -94,7 +100,7 @@ def sse_client(
try:
while not cancel_event.is_set():
try:
message = write_queue.get(timeout=5)
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if message is None:
break
response = client.post(
@ -130,6 +136,10 @@ def sse_client(
yield read_queue, write_queue
finally:
cancel_event.set()
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception as exc:
logger.exception("Error connecting to SSE endpoint")
raise exc

View File

@ -17,8 +17,9 @@ from datetime import timedelta
from typing import Any, cast
import httpx
from httpx_sse import EventSource, ServerSentEvent, connect_sse
from httpx_sse import EventSource, ServerSentEvent
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
from core.mcp.types import (
ClientMessageMetadata,
ErrorData,
@ -30,7 +31,6 @@ from core.mcp.types import (
RequestId,
SessionMessage,
)
from core.mcp.utils import create_mcp_http_client
logger = logging.getLogger(__name__)
@ -50,6 +50,8 @@ ACCEPT = "Accept"
JSON = "application/json"
SSE = "text/event-stream"
DEFAULT_QUEUE_READ_TIMEOUT = 3
class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors."""
@ -184,12 +186,13 @@ class StreamableHTTPTransport:
headers = self._update_headers_with_session(self.request_headers)
with connect_sse(
client,
"GET",
with ssrf_proxy_sse_connect(
self.url,
2,
headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
client=client,
method="GET",
) as event_source:
event_source.response.raise_for_status()
logger.debug("GET SSE connection established")
@ -215,12 +218,13 @@ class StreamableHTTPTransport:
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
original_request_id = ctx.session_message.message.root.id
with connect_sse(
ctx.client,
"GET",
with ssrf_proxy_sse_connect(
self.url,
2,
headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
client=ctx.client,
method="GET",
) as event_source:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")
@ -304,7 +308,6 @@ class StreamableHTTPTransport:
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
)
except Exception as e:
logger.exception("Error reading SSE stream:")
ctx.server_to_client_queue.put(e)
def _handle_unexpected_content_type(
@ -346,7 +349,7 @@ class StreamableHTTPTransport:
while not self.stop_event.is_set():
try:
# Read message from client queue with timeout to check stop_event periodically
session_message = client_to_server_queue.get(timeout=5)
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if session_message is None:
break
@ -444,7 +447,7 @@ def streamablehttp_client(
try:
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
with create_mcp_http_client(
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
) as client:

View File

@ -1,6 +1,8 @@
import logging
from collections.abc import Callable
from contextlib import ExitStack
from typing import Optional, cast
from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client
from core.mcp.client.streamable_client import streamablehttp_client
@ -59,42 +61,53 @@ class MCPClient:
first_try: bool = True,
):
"""Initialize the client with fallback to SSE if streamable connection fails"""
connection_methods = [("streamablehttp_client", streamablehttp_client), ("sse_client", sse_client)]
connection_methods = {"mcp": streamablehttp_client, "sse": sse_client}
parsed_url = urlparse(self.server_url)
path = parsed_url.path
method_name = path.rstrip("/").split("/")[-1] if path else ""
try:
client_factory = connection_methods[method_name]
self.connect_server(client_factory, method_name)
except KeyError:
try:
self.connect_server(streamablehttp_client, "sse")
except MCPConnectionError:
self.connect_server(sse_client, "mcp")
def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True):
from core.mcp.auth.auth_flow import auth
for method_name, client_factory in connection_methods:
try:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else {}
)
self._streams_context = client_factory(url=self.server_url, headers=headers)
if method_name == "streamablehttp_client":
read_stream, write_stream, _ = self._streams_context.__enter__()
streams = (read_stream, write_stream)
else: # sse_client
streams = self._streams_context.__enter__()
try:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else {}
)
self._streams_context = client_factory(url=self.server_url, headers=headers)
if method_name == "mcp":
read_stream, write_stream, _ = self._streams_context.__enter__()
streams = (read_stream, write_stream)
else: # sse_client
streams = self._streams_context.__enter__()
self._session_context = ClientSession(*streams)
self._session = self._session_context.__enter__()
session = cast(ClientSession, self._session)
session.initialize()
return
self._session_context = ClientSession(*streams)
self._session = self._session_context.__enter__()
session = cast(ClientSession, self._session)
session.initialize()
return
except MCPAuthError:
if not self.authed:
raise
auth(self.provider, self.server_url, self.authorization_code, self.scope)
if first_try:
return self._initialize(first_try=False)
except MCPConnectionError:
if method_name == "streamablehttp_client":
continue
except MCPAuthError:
if not self.authed:
raise
auth(self.provider, self.server_url, self.authorization_code, self.scope)
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)
except MCPConnectionError:
raise
def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport"""
# List available tools to verify connection

View File

@ -40,7 +40,7 @@ SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotif
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
DEFAULT_RESPONSE_READ_TIMEOUT = 5
DEFAULT_RESPONSE_READ_TIMEOUT = 1
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
@ -210,7 +210,7 @@ class BaseSession(
future.result(timeout=remaining)
except Exception as e:
print(f"Error waiting for task: {e}")
logging.exception(f"Error waiting for task: {e}")
def send_request(
self,
@ -247,8 +247,12 @@ class BaseSession(
timeout = request_read_timeout_seconds.total_seconds()
elif self._session_read_timeout_seconds is not None:
timeout = self._session_read_timeout_seconds.total_seconds()
response_or_error = response_queue.get(timeout=timeout)
while not self._stop_event.is_set():
try:
response_or_error = response_queue.get(timeout=timeout)
break
except queue.Empty:
continue
if response_or_error is None:
raise MCPConnectionError(
@ -315,7 +319,7 @@ class BaseSession(
while not self._stop_event.is_set():
try:
# Attempt to receive a message (this would be blocking in a synchronous context)
message = self._read_stream.get(timeout=5)
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
if message is None:
break
if isinstance(message, HTTPStatusError):

View File

@ -1,28 +0,0 @@
from typing import Any
from urllib.parse import urljoin, urlparse
import httpx
def create_mcp_http_client(
    headers: dict[str, str] | None = None,
    timeout: httpx.Timeout | None = None,
) -> httpx.Client:
    """Build an ``httpx.Client`` that follows redirects.

    Args:
        headers: Default headers for the client; left unset when ``None``.
        timeout: Timeout configuration; a 30 second timeout is used when
            ``None``.

    Returns:
        The configured ``httpx.Client``.
    """
    effective_timeout = httpx.Timeout(30.0) if timeout is None else timeout
    options: dict[str, Any] = {
        "follow_redirects": True,
        "timeout": effective_timeout,
    }
    # Only pass headers through when the caller supplied them.
    if headers is not None:
        options["headers"] = headers
    return httpx.Client(**options)
def remove_request_params(url: str) -> str:
    """Return *url* without its path parameters, query string and fragment.

    The URL is rebuilt from its parsed components instead of joining the path
    back onto the original URL: joining an empty path returns the original
    URL unchanged (query included), and a path starting with ``//`` would be
    treated as a new host.
    """
    return urlparse(url)._replace(params="", query="", fragment="").geturl()