Merge branch 'feat/mcp' into deploy/dev

This commit is contained in:
Novice 2025-05-27 13:15:25 +08:00
commit e237bc09b8
18 changed files with 335 additions and 232 deletions

View File

@ -68,16 +68,31 @@ class AppMCPServerController(Resource):
parser.add_argument("id", type=str, required=True, location="json") parser.add_argument("id", type=str, required=True, location="json")
parser.add_argument("description", type=str, required=True, location="json") parser.add_argument("description", type=str, required=True, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json") parser.add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("status", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first() server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
if not server: if not server:
raise Forbidden() raise Forbidden()
server.description = args["description"] server.description = args["description"]
server.parameters = json.dumps(args["parameters"], ensure_ascii=False) server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
server.status = AppMCPServerStatus(args["status"]) db.session.commit()
return server
class AppMCPServerRefreshController(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_server_fields)
def get(self, server_id):
if not current_user.is_editor:
raise Forbidden()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == server_id).first()
if not server:
raise Forbidden()
server.server_code = AppMCPServer.generate_server_code(16)
db.session.commit() db.session.commit()
return server return server
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server") api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
api.add_resource(AppMCPServerRefreshController, "/apps/<uuid:server_id>/server/refresh")

View File

@ -628,7 +628,7 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json") parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json") parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
args = parser.parse_args() args = parser.parse_args()
user = current_user user = current_user
return jsonable_encoder( return jsonable_encoder(
@ -652,7 +652,7 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json") parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json") parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json") parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json") parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
return jsonable_encoder( return jsonable_encoder(
@ -704,8 +704,15 @@ class ToolMCPAuthApi(Resource):
authed=False, authed=False,
authorization_code=args["authorization_code"], authorization_code=args["authorization_code"],
): ):
MCPToolManageService.update_mcp_provider_credentials(
tenant_id=tenant_id,
provider_id=provider_id,
credentials={},
authed=True,
)
return {"result": "success"} return {"result": "success"}
except MCPAuthError as e:
except MCPAuthError:
auth_provider = OAuthClientProvider(provider_id, tenant_id) auth_provider = OAuthClientProvider(provider_id, tenant_id)
return auth(auth_provider, provider.server_url, args["authorization_code"]) return auth(auth_provider, provider.server_url, args["authorization_code"])

View File

@ -18,6 +18,7 @@ from controllers.web.error import (
) )
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.app.app_config.entities import VariableEntity
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ( from core.errors.error import (
@ -175,11 +176,30 @@ class ChatMCPApi(Resource):
app = db.session.query(App).filter(App.id == server.app_id).first() app = db.session.query(App).filter(App.id == server.app_id).first()
if not app: if not app:
raise NotFound("App Not Found") raise NotFound("App Not Found")
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app.workflow
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
try:
user_input_form = [VariableEntity.model_validate(item) for item in user_input_form]
except ValidationError as e:
raise ValueError(f"Invalid user_input_form: {str(e)}")
try: try:
request = ClientRequest.model_validate(args) request = ClientRequest.model_validate(args)
except ValidationError as e: except ValidationError as e:
raise ValueError(f"Invalid MCP request: {str(e)}") raise ValueError(f"Invalid MCP request: {str(e)}")
mcp_server_handler = MCPServerReuqestHandler(app, request) mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form)
return helper.compact_generate_response(mcp_server_handler.handle()) return helper.compact_generate_response(mcp_server_handler.handle())

View File

@ -108,3 +108,86 @@ def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
return make_request("HEAD", url, max_retries=max_retries, **kwargs) return make_request("HEAD", url, max_retries=max_retries, **kwargs)
def create_ssrf_proxy_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
) -> httpx.Client:
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
Args:
headers: Optional headers to include in the client
timeout: Optional timeout configuration
Returns:
Configured httpx.Client with proxy settings
"""
client_kwargs = {
"verify": HTTP_REQUEST_NODE_SSL_VERIFY,
"headers": headers or {},
"timeout": timeout,
"follow_redirects": True, # Enable redirect following for MCP connections
}
if dify_config.SSRF_PROXY_ALL_URL:
client_kwargs["proxy"] = dify_config.SSRF_PROXY_ALL_URL
return httpx.Client(**client_kwargs)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
"https://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
),
}
client_kwargs["mounts"] = proxy_mounts
return httpx.Client(**client_kwargs)
else:
return httpx.Client(**client_kwargs)
def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
"""Connect to SSE endpoint with SSRF proxy protection.
This function creates an SSE connection using the configured proxy settings
to prevent SSRF attacks when connecting to external endpoints.
Args:
url: The SSE endpoint URL
max_retries: Maximum number of retry attempts
**kwargs: Additional arguments passed to the SSE connection
Returns:
EventSource object for SSE streaming
"""
from httpx_sse import connect_sse
# Extract client if provided, otherwise create one
client = kwargs.pop("client", None)
if client is None:
# Create client with SSRF proxy configuration
timeout = kwargs.pop(
"timeout",
httpx.Timeout(
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
),
)
headers = kwargs.pop("headers", {})
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
client_provided = False
else:
client_provided = True
# Extract method if provided, default to GET
method = kwargs.pop("method", "GET")
try:
return connect_sse(client, method, url, **kwargs)
except Exception as e:
# If we created the client, we need to clean it up on error
if not client_provided:
client.close()
raise

View File

@ -59,7 +59,6 @@ def start_authorization(
metadata: Optional[OAuthMetadata], metadata: Optional[OAuthMetadata],
client_information: OAuthClientInformation, client_information: OAuthClientInformation,
redirect_url: str, redirect_url: str,
scope: Optional[str] = None,
) -> tuple[str, str]: ) -> tuple[str, str]:
"""Begins the authorization flow.""" """Begins the authorization flow."""
response_type = "code" response_type = "code"
@ -85,11 +84,9 @@ def start_authorization(
"code_challenge": code_challenge, "code_challenge": code_challenge,
"code_challenge_method": code_challenge_method, "code_challenge_method": code_challenge_method,
"redirect_uri": redirect_url, "redirect_uri": redirect_url,
"state": "/tools?provider_id=" + client_information.client_id,
} }
if scope:
params["scope"] = scope
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}" authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
return authorization_url, code_verifier return authorization_url, code_verifier
@ -187,7 +184,6 @@ def auth(
provider: OAuthClientProvider, provider: OAuthClientProvider,
server_url: str, server_url: str,
authorization_code: Optional[str] = None, authorization_code: Optional[str] = None,
scope: Optional[str] = None,
) -> dict[str, str]: ) -> dict[str, str]:
"""Orchestrates the full auth flow with a server.""" """Orchestrates the full auth flow with a server."""
metadata = discover_oauth_metadata(server_url) metadata = discover_oauth_metadata(server_url)
@ -233,7 +229,6 @@ def auth(
metadata, metadata,
client_information, client_information,
provider.redirect_url, provider.redirect_url,
scope or provider.client_metadata.scope,
) )
provider.save_code_verifier(code_verifier) provider.save_code_verifier(code_verifier)

View File

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from configs.app_config import DifyConfig from configs import dify_config
from core.mcp.types import ( from core.mcp.types import (
OAuthClientInformation, OAuthClientInformation,
OAuthClientInformationFull, OAuthClientInformationFull,
@ -11,8 +11,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService
LATEST_PROTOCOL_VERSION = "1.0" LATEST_PROTOCOL_VERSION = "1.0"
dify_config = DifyConfig()
class OAuthClientProvider: class OAuthClientProvider:
provider_id: str provider_id: str
@ -25,7 +23,7 @@ class OAuthClientProvider:
@property @property
def redirect_url(self) -> str: def redirect_url(self) -> str:
"""The URL to redirect the user agent to after authorization.""" """The URL to redirect the user agent to after authorization."""
return dify_config.CONSOLE_WEB_URL return dify_config.CONSOLE_WEB_URL + "/tools"
@property @property
def client_metadata(self) -> OAuthClientMetadata: def client_metadata(self) -> OAuthClientMetadata:
@ -37,7 +35,6 @@ class OAuthClientProvider:
response_types=["code"], response_types=["code"],
client_name="Dify", client_name="Dify",
client_uri="https://github.com/langgenius/dify", client_uri="https://github.com/langgenius/dify",
scope="read write",
) )
def client_information(self) -> Optional[OAuthClientInformation]: def client_information(self) -> Optional[OAuthClientInformation]:
@ -91,7 +88,3 @@ class OAuthClientProvider:
if not mcp_provider: if not mcp_provider:
return "" return ""
return mcp_provider.credentials.get("code_verifier", "") return mcp_provider.credentials.get("code_verifier", "")
class UnauthorizedError(Exception):
pass

View File

@ -1,6 +1,5 @@
import logging import logging
import queue import queue
import threading
from collections.abc import Generator from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
@ -8,15 +7,21 @@ from typing import Any
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
import httpx import httpx
from httpx_sse import connect_sse
from sseclient import SSEClient from sseclient import SSEClient
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
from core.mcp import types from core.mcp import types
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.types import SessionMessage from core.mcp.types import SessionMessage
from core.mcp.utils import create_mcp_http_client, remove_request_params
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_QUEUE_READ_TIMEOUT = 3
def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)
@contextmanager @contextmanager
def sse_client( def sse_client(
@ -36,23 +41,19 @@ def sse_client(
read_queue = queue.Queue() read_queue = queue.Queue()
write_queue = queue.Queue() write_queue = queue.Queue()
status_queue = queue.Queue() status_queue = queue.Queue()
cancel_event = threading.Event()
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
try: try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
with create_mcp_http_client(headers=headers) as client: with create_ssrf_proxy_mcp_http_client(headers=headers) as client:
with connect_sse( with ssrf_proxy_sse_connect(
client, "GET", url, timeout=httpx.Timeout(timeout, read=sse_read_timeout) url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source: ) as event_source:
event_source.response.raise_for_status() event_source.response.raise_for_status()
logger.debug("SSE connection established")
def sse_reader(status_queue: queue.Queue): def sse_reader(status_queue: queue.Queue):
try: try:
while not cancel_event.is_set():
for sse in event_source.iter_sse(): for sse in event_source.iter_sse():
if cancel_event.is_set():
break
match sse.event: match sse.event:
case "endpoint": case "endpoint":
endpoint_url = urljoin(url, sse.data) endpoint_url = urljoin(url, sse.data)
@ -68,7 +69,7 @@ def sse_client(
f"Endpoint origin does not match connection origin: {endpoint_url}" f"Endpoint origin does not match connection origin: {endpoint_url}"
) )
logger.error(error_msg) logger.error(error_msg)
raise ValueError(error_msg) status_queue.put(("error", ValueError(error_msg)))
status_queue.put(("ready", endpoint_url)) status_queue.put(("ready", endpoint_url))
case "message": case "message":
try: try:
@ -82,19 +83,18 @@ def sse_client(
read_queue.put(session_message) read_queue.put(session_message)
case _: case _:
logger.warning(f"Unknown SSE event: {sse.event}") logger.warning(f"Unknown SSE event: {sse.event}")
except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc: except Exception as exc:
if not cancel_event.is_set():
logger.exception("Error reading SSE messages")
read_queue.put(exc) read_queue.put(exc)
finally: finally:
read_queue.put(None) read_queue.put(None)
def post_writer(endpoint_url: str): def post_writer(endpoint_url: str):
try: try:
while not cancel_event.is_set(): while True:
try: try:
message = write_queue.get(timeout=5) message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if message is None: if message is None:
break break
response = client.post( response = client.post(
@ -107,14 +107,13 @@ def sse_client(
) )
response.raise_for_status() response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}") logger.debug(f"Client message sent successfully: {response.status_code}")
if cancel_event.is_set():
break
except queue.Empty: except queue.Empty:
if cancel_event.is_set():
break
continue continue
except Exception: except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc:
logger.exception("Error writing messages") logger.exception("Error writing messages")
write_queue.put(exc)
finally: finally:
write_queue.put(None) write_queue.put(None)
@ -125,11 +124,16 @@ def sse_client(
raise ValueError("failed to get endpoint URL") raise ValueError("failed to get endpoint URL")
if status != "ready": if status != "ready":
raise ValueError("failed to get endpoint URL") raise ValueError("failed to get endpoint URL")
if status == "error":
raise endpoint_url
executor.submit(post_writer, endpoint_url) executor.submit(post_writer, endpoint_url)
try:
yield read_queue, write_queue yield read_queue, write_queue
finally:
cancel_event.set() except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception as exc: except Exception as exc:
logger.exception("Error connecting to SSE endpoint") logger.exception("Error connecting to SSE endpoint")
raise exc raise exc

View File

@ -8,7 +8,6 @@ and session management.
import logging import logging
import queue import queue
import threading
from collections.abc import Callable, Generator from collections.abc import Callable, Generator
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
@ -17,8 +16,9 @@ from datetime import timedelta
from typing import Any, cast from typing import Any, cast
import httpx import httpx
from httpx_sse import EventSource, ServerSentEvent, connect_sse from httpx_sse import EventSource, ServerSentEvent
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
from core.mcp.types import ( from core.mcp.types import (
ClientMessageMetadata, ClientMessageMetadata,
ErrorData, ErrorData,
@ -30,7 +30,6 @@ from core.mcp.types import (
RequestId, RequestId,
SessionMessage, SessionMessage,
) )
from core.mcp.utils import create_mcp_http_client
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -50,6 +49,8 @@ ACCEPT = "Accept"
JSON = "application/json" JSON = "application/json"
SSE = "text/event-stream" SSE = "text/event-stream"
DEFAULT_QUEUE_READ_TIMEOUT = 3
class StreamableHTTPError(Exception): class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors.""" """Base exception for StreamableHTTP transport errors."""
@ -104,11 +105,6 @@ class StreamableHTTPTransport:
CONTENT_TYPE: JSON, CONTENT_TYPE: JSON,
**self.headers, **self.headers,
} }
self.stop_event = threading.Event()
def stop(self):
"""Signal to stop all operations."""
self.stop_event.set()
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]: def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID if available.""" """Update headers with session ID if available."""
@ -168,6 +164,9 @@ class StreamableHTTPTransport:
# Put exception in queue that goes to client # Put exception in queue that goes to client
server_to_client_queue.put(exc) server_to_client_queue.put(exc)
return False return False
elif sse.event == "ping":
logger.debug("Received ping event")
return False
else: else:
logger.warning(f"Unknown SSE event: {sse.event}") logger.warning(f"Unknown SSE event: {sse.event}")
return False return False
@ -184,19 +183,18 @@ class StreamableHTTPTransport:
headers = self._update_headers_with_session(self.request_headers) headers = self._update_headers_with_session(self.request_headers)
with connect_sse( with ssrf_proxy_sse_connect(
client,
"GET",
self.url, self.url,
2,
headers=headers, headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds), timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
client=client,
method="GET",
) as event_source: ) as event_source:
event_source.response.raise_for_status() event_source.response.raise_for_status()
logger.debug("GET SSE connection established") logger.debug("GET SSE connection established")
for sse in event_source.iter_sse(): for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
self._handle_sse_event(sse, server_to_client_queue) self._handle_sse_event(sse, server_to_client_queue)
except Exception as exc: except Exception as exc:
@ -215,19 +213,18 @@ class StreamableHTTPTransport:
if isinstance(ctx.session_message.message.root, JSONRPCRequest): if isinstance(ctx.session_message.message.root, JSONRPCRequest):
original_request_id = ctx.session_message.message.root.id original_request_id = ctx.session_message.message.root.id
with connect_sse( with ssrf_proxy_sse_connect(
ctx.client,
"GET",
self.url, self.url,
2,
headers=headers, headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds), timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
client=ctx.client,
method="GET",
) as event_source: ) as event_source:
event_source.response.raise_for_status() event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established") logger.debug("Resumption GET SSE connection established")
for sse in event_source.iter_sse(): for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
is_complete = self._handle_sse_event( is_complete = self._handle_sse_event(
sse, sse,
ctx.server_to_client_queue, ctx.server_to_client_queue,
@ -296,15 +293,14 @@ class StreamableHTTPTransport:
try: try:
event_source = EventSource(response) event_source = EventSource(response)
for sse in event_source.iter_sse(): for sse in event_source.iter_sse():
if self.stop_event.is_set(): is_complete = self._handle_sse_event(
break
self._handle_sse_event(
sse, sse,
ctx.server_to_client_queue, ctx.server_to_client_queue,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
) )
if is_complete:
break
except Exception as e: except Exception as e:
logger.exception("Error reading SSE stream:")
ctx.server_to_client_queue.put(e) ctx.server_to_client_queue.put(e)
def _handle_unexpected_content_type( def _handle_unexpected_content_type(
@ -343,10 +339,10 @@ class StreamableHTTPTransport:
This method processes messages from the client_to_server_queue and sends them to the server. This method processes messages from the client_to_server_queue and sends them to the server.
Responses are written to the server_to_client_queue. Responses are written to the server_to_client_queue.
""" """
while not self.stop_event.is_set(): while True:
try: try:
# Read message from client queue with timeout to check stop_event periodically # Read message from client queue with timeout to check stop_event periodically
session_message = client_to_server_queue.get(timeout=5) session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if session_message is None: if session_message is None:
break break
@ -379,10 +375,8 @@ class StreamableHTTPTransport:
else: else:
self._handle_post_request(ctx) self._handle_post_request(ctx)
except queue.Empty: except queue.Empty:
# Timeout - continue loop to check stop_event
continue continue
except Exception as exc: except Exception as exc:
# Send exception to client
server_to_client_queue.put(exc) server_to_client_queue.put(exc)
def terminate_session(self, client: httpx.Client) -> None: def terminate_session(self, client: httpx.Client) -> None:
@ -444,7 +438,7 @@ def streamablehttp_client(
try: try:
logger.info(f"Connecting to StreamableHTTP endpoint: {url}") logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
with create_mcp_http_client( with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers, headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds), timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
) as client: ) as client:
@ -475,9 +469,6 @@ def streamablehttp_client(
# Signal threads to stop # Signal threads to stop
client_to_server_queue.put(None) client_to_server_queue.put(None)
finally: finally:
# Clean up
transport.stop()
# Clear any remaining items and add None sentinel to unblock any waiting threads # Clear any remaining items and add None sentinel to unblock any waiting threads
try: try:
while not client_to_server_queue.empty(): while not client_to_server_queue.empty():

View File

@ -1,6 +1,8 @@
import logging import logging
from collections.abc import Callable
from contextlib import ExitStack from contextlib import ExitStack
from typing import Optional, cast from typing import Optional, cast
from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client from core.mcp.client.sse_client import sse_client
from core.mcp.client.streamable_client import streamablehttp_client from core.mcp.client.streamable_client import streamablehttp_client
@ -19,7 +21,6 @@ class MCPClient:
tenant_id: str, tenant_id: str,
authed: bool = True, authed: bool = True,
authorization_code: Optional[str] = None, authorization_code: Optional[str] = None,
scope: Optional[str] = None,
): ):
# Initialize info # Initialize info
self.provider_id = provider_id self.provider_id = provider_id
@ -30,7 +31,6 @@ class MCPClient:
# Authentication info # Authentication info
self.authed = authed self.authed = authed
self.authorization_code = authorization_code self.authorization_code = authorization_code
self.scope = scope
if authed: if authed:
from core.mcp.auth.auth_provider import OAuthClientProvider from core.mcp.auth.auth_provider import OAuthClientProvider
@ -47,7 +47,7 @@ class MCPClient:
self._initialized = False self._initialized = False
def __enter__(self): def __enter__(self):
self._initialize(first_try=True) self._initialize()
self._initialized = True self._initialized = True
return self return self
@ -56,13 +56,25 @@ class MCPClient:
def _initialize( def _initialize(
self, self,
first_try: bool = True,
): ):
"""Initialize the client with fallback to SSE if streamable connection fails""" """Initialize the client with fallback to SSE if streamable connection fails"""
connection_methods = [("streamablehttp_client", streamablehttp_client), ("sse_client", sse_client)] connection_methods = {"mcp": streamablehttp_client, "sse": sse_client}
parsed_url = urlparse(self.server_url)
path = parsed_url.path
method_name = path.rstrip("/").split("/")[-1] if path else ""
try:
client_factory = connection_methods[method_name]
self.connect_server(client_factory, method_name)
except KeyError:
try:
self.connect_server(sse_client, "sse")
except MCPConnectionError:
self.connect_server(streamablehttp_client, "mcp")
def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True):
from core.mcp.auth.auth_flow import auth from core.mcp.auth.auth_flow import auth
for method_name, client_factory in connection_methods:
try: try:
headers = ( headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"} {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
@ -70,7 +82,7 @@ class MCPClient:
else {} else {}
) )
self._streams_context = client_factory(url=self.server_url, headers=headers) self._streams_context = client_factory(url=self.server_url, headers=headers)
if method_name == "streamablehttp_client": if method_name == "mcp":
read_stream, write_stream, _ = self._streams_context.__enter__() read_stream, write_stream, _ = self._streams_context.__enter__()
streams = (read_stream, write_stream) streams = (read_stream, write_stream)
else: # sse_client else: # sse_client
@ -85,14 +97,12 @@ class MCPClient:
except MCPAuthError: except MCPAuthError:
if not self.authed: if not self.authed:
raise raise
auth(self.provider, self.server_url, self.authorization_code)
auth(self.provider, self.server_url, self.authorization_code, self.scope) self.token = self.provider.tokens()
if first_try: if first_try:
return self._initialize(first_try=False) return self.connect_server(client_factory, method_name, first_try=False)
except MCPConnectionError: except MCPConnectionError:
if method_name == "streamablehttp_client":
continue
raise raise
def list_tools(self) -> list[Tool]: def list_tools(self) -> list[Tool]:
@ -121,5 +131,6 @@ class MCPClient:
self._session = None self._session = None
self._initialized = False self._initialized = False
self.exit_stack.close() self.exit_stack.close()
except Exception: except Exception as e:
logging.exception("Error during cleanup") logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")

View File

@ -2,30 +2,31 @@ import json
from collections.abc import Mapping from collections.abc import Mapping
from typing import cast from typing import cast
from configs.app_config import DifyConfig from configs import dify_config
from controllers.web.passport import generate_session_id from controllers.web.passport import generate_session_id
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.mcp import types from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, EndUser from models.model import App, AppMCPServer, EndUser
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
""" """
Apply to MCP HTTP streamable server with stateless http Apply to MCP HTTP streamable server with stateless http
""" """
dify_config = DifyConfig()
class MCPServerReuqestHandler: class MCPServerReuqestHandler:
def __init__(self, app: App, request: types.ClientRequest): def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity]):
self.app = app self.app = app
self.request = request self.request = request
if not self.app.mcp_server: self.mcp_server: AppMCPServer = self.app.mcp_server
if not self.mcp_server:
raise ValueError("MCP server not found") raise ValueError("MCP server not found")
self.mcp_server = self.app.mcp_server
self.end_user = self.retrieve_end_user() self.end_user = self.retrieve_end_user()
self.user_input_form = user_input_form
@property @property
def request_type(self): def request_type(self):
@ -33,6 +34,7 @@ class MCPServerReuqestHandler:
@property @property
def parameter_schema(self): def parameter_schema(self):
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
return { return {
"type": "object", "type": "object",
"properties": { "properties": {
@ -41,10 +43,11 @@ class MCPServerReuqestHandler:
"type": "object", "type": "object",
"description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501 "description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501
"default": {}, "default": {},
# TODO: add input parameters "properties": parameters,
"required": required,
}, },
}, },
"required": ["query"], "required": "query",
} }
@property @property
@ -152,3 +155,25 @@ class MCPServerReuqestHandler:
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") .filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first() .first()
) )
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
parameters = {}
required = []
for item in user_input_form:
if item.type in (
VariableEntityType.FILE,
VariableEntityType.FILE_LIST,
VariableEntityType.EXTERNAL_DATA_TOOL,
):
continue
if item.required:
required.append(item.variable)
parameters[item.variable]["description"] = self.mcp_server.parameters_dict[item.label]["description"]
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"
elif item.type == VariableEntityType.SELECT:
parameters[item.variable]["type"] = "string"
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "number"
return parameters, required

View File

@ -1,7 +1,5 @@
import logging import logging
import queue import queue
import threading
import time
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack from contextlib import ExitStack
@ -40,7 +38,7 @@ SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotif
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification) ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
DEFAULT_RESPONSE_READ_TIMEOUT = 5 DEFAULT_RESPONSE_READ_TIMEOUT = 1
class RequestResponder(Generic[ReceiveRequestT, SendResultT]): class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
@ -80,13 +78,10 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._completed = False self._completed = False
self._on_complete = on_complete self._on_complete = on_complete
self._entered = False # Track if we're in a context manager self._entered = False # Track if we're in a context manager
self._cancel_event = threading.Event()
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]": def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
"""Enter the context manager, enabling request cancellation tracking.""" """Enter the context manager, enabling request cancellation tracking."""
self._entered = True self._entered = True
self._cancel_event = threading.Event()
self._cancel_event.clear()
return self return self
def __exit__( def __exit__(
@ -101,9 +96,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._on_complete(self) self._on_complete(self)
finally: finally:
self._entered = False self._entered = False
if not self._cancel_event:
raise RuntimeError("No active cancel scope")
self._cancel_event.set()
def respond(self, response: SendResultT | ErrorData) -> None: def respond(self, response: SendResultT | ErrorData) -> None:
"""Send a response for this request. """Send a response for this request.
@ -117,7 +109,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
raise RuntimeError("RequestResponder must be used as a context manager") raise RuntimeError("RequestResponder must be used as a context manager")
assert not self._completed, "Request already responded to" assert not self._completed, "Request already responded to"
if not self.cancelled:
self._completed = True self._completed = True
self._session._send_response(request_id=self.request_id, response=response) self._session._send_response(request_id=self.request_id, response=response)
@ -127,7 +118,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
if not self._entered: if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager") raise RuntimeError("RequestResponder must be used as a context manager")
self._cancel_event.set()
self._completed = True # Mark as completed so it's removed from in_flight self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation # Send an error response to indicate cancellation
self._session._send_response( self._session._send_response(
@ -135,14 +125,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
response=ErrorData(code=0, message="Request cancelled", data=None), response=ErrorData(code=0, message="Request cancelled", data=None),
) )
@property
def in_flight(self) -> bool:
return not self._completed and not self.cancelled
@property
def cancelled(self) -> bool:
return self._cancel_event.is_set()
class BaseSession( class BaseSession(
Generic[ Generic[
@ -184,11 +166,9 @@ class BaseSession(
self._in_flight = {} self._in_flight = {}
self._exit_stack = ExitStack() self._exit_stack = ExitStack()
self._futures = [] self._futures = []
self._request_id_lock = threading.Lock()
def __enter__(self) -> Self: def __enter__(self) -> Self:
self._executor = ThreadPoolExecutor() self._executor = ThreadPoolExecutor()
self._stop_event = threading.Event()
self._receiver_future = self._executor.submit(self._receive_loop) self._receiver_future = self._executor.submit(self._receive_loop)
return self return self
@ -196,21 +176,8 @@ class BaseSession(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None: ) -> None:
self._exit_stack.close() self._exit_stack.close()
self._stop_event.set() self._read_stream.put(None)
self._wait_for_futures(timeout=5) self._write_stream.put(None)
def _wait_for_futures(self, timeout=None):
end_time = time.time() + timeout if timeout else None
for future in list(self._futures):
try:
remaining = end_time - time.time() if end_time else None
if remaining is not None and remaining <= 0:
break
future.result(timeout=remaining)
except Exception as e:
print(f"Error waiting for task: {e}")
def send_request( def send_request(
self, self,
@ -247,8 +214,12 @@ class BaseSession(
timeout = request_read_timeout_seconds.total_seconds() timeout = request_read_timeout_seconds.total_seconds()
elif self._session_read_timeout_seconds is not None: elif self._session_read_timeout_seconds is not None:
timeout = self._session_read_timeout_seconds.total_seconds() timeout = self._session_read_timeout_seconds.total_seconds()
while True:
try:
response_or_error = response_queue.get(timeout=timeout) response_or_error = response_queue.get(timeout=timeout)
break
except queue.Empty:
continue
if response_or_error is None: if response_or_error is None:
raise MCPConnectionError( raise MCPConnectionError(
@ -312,10 +283,10 @@ class BaseSession(
Main message processing loop. Main message processing loop.
In a real synchronous implementation, this would likely run in a separate thread. In a real synchronous implementation, this would likely run in a separate thread.
""" """
while not self._stop_event.is_set(): while True:
try: try:
# Attempt to receive a message (this would be blocking in a synchronous context) # Attempt to receive a message (this would be blocking in a synchronous context)
message = self._read_stream.get(timeout=5) message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
if message is None: if message is None:
break break
if isinstance(message, HTTPStatusError): if isinstance(message, HTTPStatusError):
@ -374,12 +345,9 @@ class BaseSession(
else: else:
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}")) self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
except queue.Empty: except queue.Empty:
if self._stop_event.is_set():
break
continue continue
except Exception as e: except Exception as e:
logging.exception("Error in message processing loop") logging.exception("Error in message processing loop")
self._stop_event.set()
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
""" """

View File

@ -3,12 +3,11 @@ from typing import Any, Protocol
from pydantic import AnyUrl, TypeAdapter from pydantic import AnyUrl, TypeAdapter
from configs.app_config import DifyConfig from configs import dify_config
from core.mcp import types from core.mcp import types
from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
from core.mcp.session.base_session import BaseSession, RequestResponder from core.mcp.session.base_session import BaseSession, RequestResponder
dify_config = DifyConfig()
DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION) DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION)

View File

@ -1177,7 +1177,6 @@ class SessionMessage:
class OAuthClientMetadata(BaseModel): class OAuthClientMetadata(BaseModel):
client_name: str client_name: str
redirect_uris: list[str] redirect_uris: list[str]
scope: str
grant_types: Optional[list[str]] = None grant_types: Optional[list[str]] = None
response_types: Optional[list[str]] = None response_types: Optional[list[str]] = None
token_endpoint_auth_method: Optional[str] = None token_endpoint_auth_method: Optional[str] = None

View File

@ -1,28 +0,0 @@
from typing import Any
from urllib.parse import urljoin, urlparse
import httpx
def create_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
) -> httpx.Client:
kwargs: dict[str, Any] = {
"follow_redirects": True,
}
# Handle timeout
if timeout is None:
kwargs["timeout"] = httpx.Timeout(30.0)
else:
kwargs["timeout"] = timeout
# Handle headers
if headers is not None:
kwargs["headers"] = headers
return httpx.Client(**kwargs)
def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)

View File

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Literal, Optional from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -57,7 +57,7 @@ class ToolProviderApiEntity(BaseModel):
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value: if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
parameter["type"] = "files" parameter["type"] = "files"
# ------------- # -------------
optional_fields = self.optional_field("server_url", self.server_url)
return { return {
"id": self.id, "id": self.id,
"author": self.author, "author": self.author,
@ -73,4 +73,9 @@ class ToolProviderApiEntity(BaseModel):
"allow_delete": self.allow_delete, "allow_delete": self.allow_delete,
"tools": tools, "tools": tools,
"labels": self.labels, "labels": self.labels,
**optional_fields,
} }
def optional_field(self, key: str, value: Any) -> dict:
"""Return dict with key-value if value is truthy, empty dict otherwise."""
return {key: value} if value else {}

View File

@ -1,6 +1,7 @@
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional from typing import Any, Optional
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient from core.mcp.mcp_client import MCPClient
from core.mcp.types import ImageContent, TextContent from core.mcp.types import ImageContent, TextContent
from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.plugin.utils.converter import convert_parameters_to_plugin_format
@ -37,9 +38,14 @@ class MCPTool(Tool):
app_id: Optional[str] = None, app_id: Optional[str] = None,
message_id: Optional[str] = None, message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]: ) -> Generator[ToolInvokeMessage, None, None]:
try:
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client: with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
tool_parameters = convert_parameters_to_plugin_format(tool_parameters) tool_parameters = convert_parameters_to_plugin_format(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e:
raise ValueError("Please auth the tool first")
except MCPConnectionError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
for content in result.content: for content in result.content:
if isinstance(content, TextContent): if isinstance(content, TextContent):
yield self.create_text_message(content.text) yield self.create_text_message(content.text)

View File

@ -1471,6 +1471,10 @@ class AppMCPServer(Base):
return result return result
@property
def parameters_dict(self) -> dict[str, Any]:
return json.loads(self.parameters)
class Site(Base): class Site(Base):
__tablename__ = "sites" __tablename__ = "sites"

View File

@ -1,5 +1,6 @@
import json import json
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient from core.mcp.mcp_client import MCPClient
from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -58,8 +59,13 @@ class MCPToolManageService:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None: if mcp_provider is None:
raise ValueError("MCP tool not found") raise ValueError("MCP tool not found")
try:
with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client: with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
tools = mcp_client.list_tools() tools = mcp_client.list_tools()
except MCPAuthError as e:
raise ValueError("Please auth the tool first")
except MCPConnectionError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
mcp_provider.authed = True mcp_provider.authed = True
db.session.commit() db.session.commit()