mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 17:35:55 +08:00
temp
This commit is contained in:
parent
8464ad0b35
commit
1fd4839eca
@ -628,7 +628,7 @@ class ToolProviderMCPApi(Resource):
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
||||
args = parser.parse_args()
|
||||
user = current_user
|
||||
return jsonable_encoder(
|
||||
@ -652,7 +652,7 @@ class ToolProviderMCPApi(Resource):
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
return jsonable_encoder(
|
||||
@ -704,8 +704,15 @@ class ToolMCPAuthApi(Resource):
|
||||
authed=False,
|
||||
authorization_code=args["authorization_code"],
|
||||
):
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
credentials={},
|
||||
authed=True,
|
||||
)
|
||||
return {"result": "success"}
|
||||
except MCPAuthError as e:
|
||||
|
||||
except MCPAuthError:
|
||||
auth_provider = OAuthClientProvider(provider_id, tenant_id)
|
||||
return auth(auth_provider, provider.server_url, args["authorization_code"])
|
||||
|
||||
|
@ -108,3 +108,85 @@ def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
|
||||
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
return make_request("HEAD", url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def create_ssrf_proxy_mcp_http_client(
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
) -> httpx.Client:
|
||||
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
|
||||
|
||||
Args:
|
||||
headers: Optional headers to include in the client
|
||||
timeout: Optional timeout configuration
|
||||
|
||||
Returns:
|
||||
Configured httpx.Client with proxy settings
|
||||
"""
|
||||
client_kwargs = {
|
||||
"verify": HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
"headers": headers or {},
|
||||
"timeout": timeout,
|
||||
}
|
||||
|
||||
if dify_config.SSRF_PROXY_ALL_URL:
|
||||
client_kwargs["proxy"] = dify_config.SSRF_PROXY_ALL_URL
|
||||
return httpx.Client(**client_kwargs)
|
||||
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
proxy_mounts = {
|
||||
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
|
||||
"https://": httpx.HTTPTransport(
|
||||
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
),
|
||||
}
|
||||
client_kwargs["mounts"] = proxy_mounts
|
||||
return httpx.Client(**client_kwargs)
|
||||
else:
|
||||
return httpx.Client(**client_kwargs)
|
||||
|
||||
|
||||
def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
"""Connect to SSE endpoint with SSRF proxy protection.
|
||||
|
||||
This function creates an SSE connection using the configured proxy settings
|
||||
to prevent SSRF attacks when connecting to external endpoints.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL
|
||||
max_retries: Maximum number of retry attempts
|
||||
**kwargs: Additional arguments passed to the SSE connection
|
||||
|
||||
Returns:
|
||||
EventSource object for SSE streaming
|
||||
"""
|
||||
from httpx_sse import connect_sse
|
||||
|
||||
# Extract client if provided, otherwise create one
|
||||
client = kwargs.pop("client", None)
|
||||
if client is None:
|
||||
# Create client with SSRF proxy configuration
|
||||
timeout = kwargs.pop(
|
||||
"timeout",
|
||||
httpx.Timeout(
|
||||
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
|
||||
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
|
||||
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
|
||||
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
|
||||
),
|
||||
)
|
||||
headers = kwargs.pop("headers", {})
|
||||
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
|
||||
client_provided = False
|
||||
else:
|
||||
client_provided = True
|
||||
|
||||
# Extract method if provided, default to GET
|
||||
method = kwargs.pop("method", "GET")
|
||||
|
||||
try:
|
||||
return connect_sse(client, method, url, **kwargs)
|
||||
except Exception as e:
|
||||
# If we created the client, we need to clean it up on error
|
||||
if not client_provided:
|
||||
client.close()
|
||||
raise
|
||||
|
@ -8,15 +8,21 @@ from typing import Any
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from httpx_sse import connect_sse
|
||||
from sseclient import SSEClient
|
||||
|
||||
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
from core.mcp import types
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.types import SessionMessage
|
||||
from core.mcp.utils import create_mcp_http_client, remove_request_params
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
return urljoin(url, urlparse(url).path)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def sse_client(
|
||||
@ -40,9 +46,9 @@ def sse_client(
|
||||
with ThreadPoolExecutor() as executor:
|
||||
try:
|
||||
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
|
||||
with create_mcp_http_client(headers=headers) as client:
|
||||
with connect_sse(
|
||||
client, "GET", url, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
|
||||
with create_ssrf_proxy_mcp_http_client(headers=headers) as client:
|
||||
with ssrf_proxy_sse_connect(
|
||||
url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("SSE connection established")
|
||||
@ -94,7 +100,7 @@ def sse_client(
|
||||
try:
|
||||
while not cancel_event.is_set():
|
||||
try:
|
||||
message = write_queue.get(timeout=5)
|
||||
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if message is None:
|
||||
break
|
||||
response = client.post(
|
||||
@ -130,6 +136,10 @@ def sse_client(
|
||||
yield read_queue, write_queue
|
||||
finally:
|
||||
cancel_event.set()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError()
|
||||
raise MCPConnectionError()
|
||||
except Exception as exc:
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
raise exc
|
||||
|
@ -17,8 +17,9 @@ from datetime import timedelta
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
from httpx_sse import EventSource, ServerSentEvent, connect_sse
|
||||
from httpx_sse import EventSource, ServerSentEvent
|
||||
|
||||
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
from core.mcp.types import (
|
||||
ClientMessageMetadata,
|
||||
ErrorData,
|
||||
@ -30,7 +31,6 @@ from core.mcp.types import (
|
||||
RequestId,
|
||||
SessionMessage,
|
||||
)
|
||||
from core.mcp.utils import create_mcp_http_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -50,6 +50,8 @@ ACCEPT = "Accept"
|
||||
JSON = "application/json"
|
||||
SSE = "text/event-stream"
|
||||
|
||||
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||
|
||||
|
||||
class StreamableHTTPError(Exception):
|
||||
"""Base exception for StreamableHTTP transport errors."""
|
||||
@ -184,12 +186,13 @@ class StreamableHTTPTransport:
|
||||
|
||||
headers = self._update_headers_with_session(self.request_headers)
|
||||
|
||||
with connect_sse(
|
||||
client,
|
||||
"GET",
|
||||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
2,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
|
||||
client=client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("GET SSE connection established")
|
||||
@ -215,12 +218,13 @@ class StreamableHTTPTransport:
|
||||
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
|
||||
original_request_id = ctx.session_message.message.root.id
|
||||
|
||||
with connect_sse(
|
||||
ctx.client,
|
||||
"GET",
|
||||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
2,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
|
||||
client=ctx.client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("Resumption GET SSE connection established")
|
||||
@ -304,7 +308,6 @@ class StreamableHTTPTransport:
|
||||
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error reading SSE stream:")
|
||||
ctx.server_to_client_queue.put(e)
|
||||
|
||||
def _handle_unexpected_content_type(
|
||||
@ -346,7 +349,7 @@ class StreamableHTTPTransport:
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
# Read message from client queue with timeout to check stop_event periodically
|
||||
session_message = client_to_server_queue.get(timeout=5)
|
||||
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if session_message is None:
|
||||
break
|
||||
|
||||
@ -444,7 +447,7 @@ def streamablehttp_client(
|
||||
try:
|
||||
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
|
||||
|
||||
with create_mcp_http_client(
|
||||
with create_ssrf_proxy_mcp_http_client(
|
||||
headers=transport.request_headers,
|
||||
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
|
||||
) as client:
|
||||
|
@ -1,6 +1,8 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from contextlib import ExitStack
|
||||
from typing import Optional, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from core.mcp.client.sse_client import sse_client
|
||||
from core.mcp.client.streamable_client import streamablehttp_client
|
||||
@ -59,42 +61,53 @@ class MCPClient:
|
||||
first_try: bool = True,
|
||||
):
|
||||
"""Initialize the client with fallback to SSE if streamable connection fails"""
|
||||
connection_methods = [("streamablehttp_client", streamablehttp_client), ("sse_client", sse_client)]
|
||||
connection_methods = {"mcp": streamablehttp_client, "sse": sse_client}
|
||||
|
||||
parsed_url = urlparse(self.server_url)
|
||||
path = parsed_url.path
|
||||
method_name = path.rstrip("/").split("/")[-1] if path else ""
|
||||
try:
|
||||
client_factory = connection_methods[method_name]
|
||||
self.connect_server(client_factory, method_name)
|
||||
except KeyError:
|
||||
try:
|
||||
self.connect_server(streamablehttp_client, "sse")
|
||||
except MCPConnectionError:
|
||||
self.connect_server(sse_client, "mcp")
|
||||
|
||||
def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True):
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
|
||||
for method_name, client_factory in connection_methods:
|
||||
try:
|
||||
headers = (
|
||||
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
|
||||
if self.authed and self.token
|
||||
else {}
|
||||
)
|
||||
self._streams_context = client_factory(url=self.server_url, headers=headers)
|
||||
if method_name == "streamablehttp_client":
|
||||
read_stream, write_stream, _ = self._streams_context.__enter__()
|
||||
streams = (read_stream, write_stream)
|
||||
else: # sse_client
|
||||
streams = self._streams_context.__enter__()
|
||||
try:
|
||||
headers = (
|
||||
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
|
||||
if self.authed and self.token
|
||||
else {}
|
||||
)
|
||||
self._streams_context = client_factory(url=self.server_url, headers=headers)
|
||||
if method_name == "mcp":
|
||||
read_stream, write_stream, _ = self._streams_context.__enter__()
|
||||
streams = (read_stream, write_stream)
|
||||
else: # sse_client
|
||||
streams = self._streams_context.__enter__()
|
||||
|
||||
self._session_context = ClientSession(*streams)
|
||||
self._session = self._session_context.__enter__()
|
||||
session = cast(ClientSession, self._session)
|
||||
session.initialize()
|
||||
return
|
||||
self._session_context = ClientSession(*streams)
|
||||
self._session = self._session_context.__enter__()
|
||||
session = cast(ClientSession, self._session)
|
||||
session.initialize()
|
||||
return
|
||||
|
||||
except MCPAuthError:
|
||||
if not self.authed:
|
||||
raise
|
||||
|
||||
auth(self.provider, self.server_url, self.authorization_code, self.scope)
|
||||
if first_try:
|
||||
return self._initialize(first_try=False)
|
||||
|
||||
except MCPConnectionError:
|
||||
if method_name == "streamablehttp_client":
|
||||
continue
|
||||
except MCPAuthError:
|
||||
if not self.authed:
|
||||
raise
|
||||
|
||||
auth(self.provider, self.server_url, self.authorization_code, self.scope)
|
||||
if first_try:
|
||||
return self.connect_server(client_factory, method_name, first_try=False)
|
||||
|
||||
except MCPConnectionError:
|
||||
raise
|
||||
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""Connect to an MCP server running with SSE transport"""
|
||||
# List available tools to verify connection
|
||||
|
@ -40,7 +40,7 @@ SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotif
|
||||
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
|
||||
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
|
||||
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
|
||||
DEFAULT_RESPONSE_READ_TIMEOUT = 5
|
||||
DEFAULT_RESPONSE_READ_TIMEOUT = 1
|
||||
|
||||
|
||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
@ -210,7 +210,7 @@ class BaseSession(
|
||||
|
||||
future.result(timeout=remaining)
|
||||
except Exception as e:
|
||||
print(f"Error waiting for task: {e}")
|
||||
logging.exception(f"Error waiting for task: {e}")
|
||||
|
||||
def send_request(
|
||||
self,
|
||||
@ -247,8 +247,12 @@ class BaseSession(
|
||||
timeout = request_read_timeout_seconds.total_seconds()
|
||||
elif self._session_read_timeout_seconds is not None:
|
||||
timeout = self._session_read_timeout_seconds.total_seconds()
|
||||
|
||||
response_or_error = response_queue.get(timeout=timeout)
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
response_or_error = response_queue.get(timeout=timeout)
|
||||
break
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
if response_or_error is None:
|
||||
raise MCPConnectionError(
|
||||
@ -315,7 +319,7 @@ class BaseSession(
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
# Attempt to receive a message (this would be blocking in a synchronous context)
|
||||
message = self._read_stream.get(timeout=5)
|
||||
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, HTTPStatusError):
|
||||
|
@ -1,28 +0,0 @@
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
def create_mcp_http_client(
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
) -> httpx.Client:
|
||||
kwargs: dict[str, Any] = {
|
||||
"follow_redirects": True,
|
||||
}
|
||||
|
||||
# Handle timeout
|
||||
if timeout is None:
|
||||
kwargs["timeout"] = httpx.Timeout(30.0)
|
||||
else:
|
||||
kwargs["timeout"] = timeout
|
||||
|
||||
# Handle headers
|
||||
if headers is not None:
|
||||
kwargs["headers"] = headers
|
||||
return httpx.Client(**kwargs)
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
return urljoin(url, urlparse(url).path)
|
Loading…
x
Reference in New Issue
Block a user