Merge branch 'feat/mcp' into deploy/dev

This commit is contained in:
Novice 2025-05-27 13:15:25 +08:00
commit e237bc09b8
18 changed files with 335 additions and 232 deletions

View File

@ -68,16 +68,31 @@ class AppMCPServerController(Resource):
parser.add_argument("id", type=str, required=True, location="json")
parser.add_argument("description", type=str, required=True, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("status", type=str, required=True, location="json")
args = parser.parse_args()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
if not server:
raise Forbidden()
server.description = args["description"]
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
server.status = AppMCPServerStatus(args["status"])
db.session.commit()
return server
class AppMCPServerRefreshController(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_server_fields)
def get(self, server_id):
if not current_user.is_editor:
raise Forbidden()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == server_id).first()
if not server:
raise Forbidden()
server.server_code = AppMCPServer.generate_server_code(16)
db.session.commit()
return server
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
api.add_resource(AppMCPServerRefreshController, "/apps/<uuid:server_id>/server/refresh")

View File

@ -628,7 +628,7 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
args = parser.parse_args()
user = current_user
return jsonable_encoder(
@ -652,7 +652,7 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return jsonable_encoder(
@ -704,8 +704,15 @@ class ToolMCPAuthApi(Resource):
authed=False,
authorization_code=args["authorization_code"],
):
MCPToolManageService.update_mcp_provider_credentials(
tenant_id=tenant_id,
provider_id=provider_id,
credentials={},
authed=True,
)
return {"result": "success"}
except MCPAuthError as e:
except MCPAuthError:
auth_provider = OAuthClientProvider(provider_id, tenant_id)
return auth(auth_provider, provider.server_url, args["authorization_code"])

View File

@ -18,6 +18,7 @@ from controllers.web.error import (
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from controllers.web.wraps import WebApiResource
from core.app.app_config.entities import VariableEntity
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
@ -175,11 +176,30 @@ class ChatMCPApi(Resource):
app = db.session.query(App).filter(App.id == server.app_id).first()
if not app:
raise NotFound("App Not Found")
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app.workflow
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
try:
user_input_form = [VariableEntity.model_validate(item) for item in user_input_form]
except ValidationError as e:
raise ValueError(f"Invalid user_input_form: {str(e)}")
try:
request = ClientRequest.model_validate(args)
except ValidationError as e:
raise ValueError(f"Invalid MCP request: {str(e)}")
mcp_server_handler = MCPServerReuqestHandler(app, request)
mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form)
return helper.compact_generate_response(mcp_server_handler.handle())

View File

@ -108,3 +108,86 @@ def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
return make_request("HEAD", url, max_retries=max_retries, **kwargs)
def create_ssrf_proxy_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
) -> httpx.Client:
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
Args:
headers: Optional headers to include in the client
timeout: Optional timeout configuration
Returns:
Configured httpx.Client with proxy settings
"""
client_kwargs = {
"verify": HTTP_REQUEST_NODE_SSL_VERIFY,
"headers": headers or {},
"timeout": timeout,
"follow_redirects": True, # Enable redirect following for MCP connections
}
if dify_config.SSRF_PROXY_ALL_URL:
client_kwargs["proxy"] = dify_config.SSRF_PROXY_ALL_URL
return httpx.Client(**client_kwargs)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
"https://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
),
}
client_kwargs["mounts"] = proxy_mounts
return httpx.Client(**client_kwargs)
else:
return httpx.Client(**client_kwargs)
def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
"""Connect to SSE endpoint with SSRF proxy protection.
This function creates an SSE connection using the configured proxy settings
to prevent SSRF attacks when connecting to external endpoints.
Args:
url: The SSE endpoint URL
max_retries: Maximum number of retry attempts
**kwargs: Additional arguments passed to the SSE connection
Returns:
EventSource object for SSE streaming
"""
from httpx_sse import connect_sse
# Extract client if provided, otherwise create one
client = kwargs.pop("client", None)
if client is None:
# Create client with SSRF proxy configuration
timeout = kwargs.pop(
"timeout",
httpx.Timeout(
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
),
)
headers = kwargs.pop("headers", {})
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
client_provided = False
else:
client_provided = True
# Extract method if provided, default to GET
method = kwargs.pop("method", "GET")
try:
return connect_sse(client, method, url, **kwargs)
except Exception as e:
# If we created the client, we need to clean it up on error
if not client_provided:
client.close()
raise

View File

@ -59,7 +59,6 @@ def start_authorization(
metadata: Optional[OAuthMetadata],
client_information: OAuthClientInformation,
redirect_url: str,
scope: Optional[str] = None,
) -> tuple[str, str]:
"""Begins the authorization flow."""
response_type = "code"
@ -85,11 +84,9 @@ def start_authorization(
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"redirect_uri": redirect_url,
"state": "/tools?provider_id=" + client_information.client_id,
}
if scope:
params["scope"] = scope
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
return authorization_url, code_verifier
@ -187,7 +184,6 @@ def auth(
provider: OAuthClientProvider,
server_url: str,
authorization_code: Optional[str] = None,
scope: Optional[str] = None,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server."""
metadata = discover_oauth_metadata(server_url)
@ -233,7 +229,6 @@ def auth(
metadata,
client_information,
provider.redirect_url,
scope or provider.client_metadata.scope,
)
provider.save_code_verifier(code_verifier)

View File

@ -1,6 +1,6 @@
from typing import Optional
from configs.app_config import DifyConfig
from configs import dify_config
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
@ -11,8 +11,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService
LATEST_PROTOCOL_VERSION = "1.0"
dify_config = DifyConfig()
class OAuthClientProvider:
provider_id: str
@ -25,7 +23,7 @@ class OAuthClientProvider:
@property
def redirect_url(self) -> str:
"""The URL to redirect the user agent to after authorization."""
return dify_config.CONSOLE_WEB_URL
return dify_config.CONSOLE_WEB_URL + "/tools"
@property
def client_metadata(self) -> OAuthClientMetadata:
@ -37,7 +35,6 @@ class OAuthClientProvider:
response_types=["code"],
client_name="Dify",
client_uri="https://github.com/langgenius/dify",
scope="read write",
)
def client_information(self) -> Optional[OAuthClientInformation]:
@ -91,7 +88,3 @@ class OAuthClientProvider:
if not mcp_provider:
return ""
return mcp_provider.credentials.get("code_verifier", "")
class UnauthorizedError(Exception):
pass

View File

@ -1,6 +1,5 @@
import logging
import queue
import threading
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
@ -8,15 +7,21 @@ from typing import Any
from urllib.parse import urljoin, urlparse
import httpx
from httpx_sse import connect_sse
from sseclient import SSEClient
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
from core.mcp import types
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.types import SessionMessage
from core.mcp.utils import create_mcp_http_client, remove_request_params
logger = logging.getLogger(__name__)
DEFAULT_QUEUE_READ_TIMEOUT = 3
def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)
@contextmanager
def sse_client(
@ -36,65 +41,60 @@ def sse_client(
read_queue = queue.Queue()
write_queue = queue.Queue()
status_queue = queue.Queue()
cancel_event = threading.Event()
with ThreadPoolExecutor() as executor:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
with create_mcp_http_client(headers=headers) as client:
with connect_sse(
client, "GET", url, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
with create_ssrf_proxy_mcp_http_client(headers=headers) as client:
with ssrf_proxy_sse_connect(
url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")
def sse_reader(status_queue: queue.Queue):
try:
while not cancel_event.is_set():
for sse in event_source.iter_sse():
if cancel_event.is_set():
break
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(f"Received endpoint URL: {endpoint_url}")
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = (
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg)
raise ValueError(error_msg)
status_queue.put(("ready", endpoint_url))
case "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"Received server message: {message}")
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
continue
session_message = SessionMessage(message)
read_queue.put(session_message)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
for sse in event_source.iter_sse():
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(f"Received endpoint URL: {endpoint_url}")
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = (
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg)
status_queue.put(("error", ValueError(error_msg)))
status_queue.put(("ready", endpoint_url))
case "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"Received server message: {message}")
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
continue
session_message = SessionMessage(message)
read_queue.put(session_message)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc:
if not cancel_event.is_set():
logger.exception("Error reading SSE messages")
read_queue.put(exc)
read_queue.put(exc)
finally:
read_queue.put(None)
def post_writer(endpoint_url: str):
try:
while not cancel_event.is_set():
while True:
try:
message = write_queue.get(timeout=5)
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if message is None:
break
response = client.post(
@ -107,14 +107,13 @@ def sse_client(
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
if cancel_event.is_set():
break
except queue.Empty:
if cancel_event.is_set():
break
continue
except Exception:
except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc:
logger.exception("Error writing messages")
write_queue.put(exc)
finally:
write_queue.put(None)
@ -125,11 +124,16 @@ def sse_client(
raise ValueError("failed to get endpoint URL")
if status != "ready":
raise ValueError("failed to get endpoint URL")
if status == "error":
raise endpoint_url
executor.submit(post_writer, endpoint_url)
try:
yield read_queue, write_queue
finally:
cancel_event.set()
yield read_queue, write_queue
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception as exc:
logger.exception("Error connecting to SSE endpoint")
raise exc

View File

@ -8,7 +8,6 @@ and session management.
import logging
import queue
import threading
from collections.abc import Callable, Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
@ -17,8 +16,9 @@ from datetime import timedelta
from typing import Any, cast
import httpx
from httpx_sse import EventSource, ServerSentEvent, connect_sse
from httpx_sse import EventSource, ServerSentEvent
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
from core.mcp.types import (
ClientMessageMetadata,
ErrorData,
@ -30,7 +30,6 @@ from core.mcp.types import (
RequestId,
SessionMessage,
)
from core.mcp.utils import create_mcp_http_client
logger = logging.getLogger(__name__)
@ -50,6 +49,8 @@ ACCEPT = "Accept"
JSON = "application/json"
SSE = "text/event-stream"
DEFAULT_QUEUE_READ_TIMEOUT = 3
class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors."""
@ -104,11 +105,6 @@ class StreamableHTTPTransport:
CONTENT_TYPE: JSON,
**self.headers,
}
self.stop_event = threading.Event()
def stop(self):
"""Signal to stop all operations."""
self.stop_event.set()
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID if available."""
@ -168,6 +164,9 @@ class StreamableHTTPTransport:
# Put exception in queue that goes to client
server_to_client_queue.put(exc)
return False
elif sse.event == "ping":
logger.debug("Received ping event")
return False
else:
logger.warning(f"Unknown SSE event: {sse.event}")
return False
@ -184,19 +183,18 @@ class StreamableHTTPTransport:
headers = self._update_headers_with_session(self.request_headers)
with connect_sse(
client,
"GET",
with ssrf_proxy_sse_connect(
self.url,
2,
headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
client=client,
method="GET",
) as event_source:
event_source.response.raise_for_status()
logger.debug("GET SSE connection established")
for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
self._handle_sse_event(sse, server_to_client_queue)
except Exception as exc:
@ -215,19 +213,18 @@ class StreamableHTTPTransport:
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
original_request_id = ctx.session_message.message.root.id
with connect_sse(
ctx.client,
"GET",
with ssrf_proxy_sse_connect(
self.url,
2,
headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
client=ctx.client,
method="GET",
) as event_source:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")
for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
@ -296,15 +293,14 @@ class StreamableHTTPTransport:
try:
event_source = EventSource(response)
for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
self._handle_sse_event(
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
)
if is_complete:
break
except Exception as e:
logger.exception("Error reading SSE stream:")
ctx.server_to_client_queue.put(e)
def _handle_unexpected_content_type(
@ -343,10 +339,10 @@ class StreamableHTTPTransport:
This method processes messages from the client_to_server_queue and sends them to the server.
Responses are written to the server_to_client_queue.
"""
while not self.stop_event.is_set():
while True:
try:
# Read message from client queue with timeout to check stop_event periodically
session_message = client_to_server_queue.get(timeout=5)
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if session_message is None:
break
@ -379,10 +375,8 @@ class StreamableHTTPTransport:
else:
self._handle_post_request(ctx)
except queue.Empty:
# Timeout - continue loop to check stop_event
continue
except Exception as exc:
# Send exception to client
server_to_client_queue.put(exc)
def terminate_session(self, client: httpx.Client) -> None:
@ -444,7 +438,7 @@ def streamablehttp_client(
try:
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
with create_mcp_http_client(
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
) as client:
@ -475,9 +469,6 @@ def streamablehttp_client(
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clean up
transport.stop()
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():

View File

@ -1,6 +1,8 @@
import logging
from collections.abc import Callable
from contextlib import ExitStack
from typing import Optional, cast
from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client
from core.mcp.client.streamable_client import streamablehttp_client
@ -19,7 +21,6 @@ class MCPClient:
tenant_id: str,
authed: bool = True,
authorization_code: Optional[str] = None,
scope: Optional[str] = None,
):
# Initialize info
self.provider_id = provider_id
@ -30,7 +31,6 @@ class MCPClient:
# Authentication info
self.authed = authed
self.authorization_code = authorization_code
self.scope = scope
if authed:
from core.mcp.auth.auth_provider import OAuthClientProvider
@ -47,7 +47,7 @@ class MCPClient:
self._initialized = False
def __enter__(self):
self._initialize(first_try=True)
self._initialize()
self._initialized = True
return self
@ -56,44 +56,54 @@ class MCPClient:
def _initialize(
self,
first_try: bool = True,
):
"""Initialize the client with fallback to SSE if streamable connection fails"""
connection_methods = [("streamablehttp_client", streamablehttp_client), ("sse_client", sse_client)]
connection_methods = {"mcp": streamablehttp_client, "sse": sse_client}
parsed_url = urlparse(self.server_url)
path = parsed_url.path
method_name = path.rstrip("/").split("/")[-1] if path else ""
try:
client_factory = connection_methods[method_name]
self.connect_server(client_factory, method_name)
except KeyError:
try:
self.connect_server(sse_client, "sse")
except MCPConnectionError:
self.connect_server(streamablehttp_client, "mcp")
def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True):
from core.mcp.auth.auth_flow import auth
for method_name, client_factory in connection_methods:
try:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else {}
)
self._streams_context = client_factory(url=self.server_url, headers=headers)
if method_name == "streamablehttp_client":
read_stream, write_stream, _ = self._streams_context.__enter__()
streams = (read_stream, write_stream)
else: # sse_client
streams = self._streams_context.__enter__()
try:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else {}
)
self._streams_context = client_factory(url=self.server_url, headers=headers)
if method_name == "mcp":
read_stream, write_stream, _ = self._streams_context.__enter__()
streams = (read_stream, write_stream)
else: # sse_client
streams = self._streams_context.__enter__()
self._session_context = ClientSession(*streams)
self._session = self._session_context.__enter__()
session = cast(ClientSession, self._session)
session.initialize()
return
self._session_context = ClientSession(*streams)
self._session = self._session_context.__enter__()
session = cast(ClientSession, self._session)
session.initialize()
return
except MCPAuthError:
if not self.authed:
raise
auth(self.provider, self.server_url, self.authorization_code, self.scope)
if first_try:
return self._initialize(first_try=False)
except MCPConnectionError:
if method_name == "streamablehttp_client":
continue
except MCPAuthError:
if not self.authed:
raise
auth(self.provider, self.server_url, self.authorization_code)
self.token = self.provider.tokens()
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)
except MCPConnectionError:
raise
def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport"""
@ -121,5 +131,6 @@ class MCPClient:
self._session = None
self._initialized = False
self.exit_stack.close()
except Exception:
except Exception as e:
logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")

View File

@ -2,30 +2,31 @@ import json
from collections.abc import Mapping
from typing import cast
from configs.app_config import DifyConfig
from configs import dify_config
from controllers.web.passport import generate_session_id
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.model import App, EndUser
from models.model import App, AppMCPServer, EndUser
from services.app_generate_service import AppGenerateService
"""
Apply to MCP HTTP streamable server with stateless http
"""
dify_config = DifyConfig()
class MCPServerReuqestHandler:
def __init__(self, app: App, request: types.ClientRequest):
def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity]):
self.app = app
self.request = request
if not self.app.mcp_server:
self.mcp_server: AppMCPServer = self.app.mcp_server
if not self.mcp_server:
raise ValueError("MCP server not found")
self.mcp_server = self.app.mcp_server
self.end_user = self.retrieve_end_user()
self.user_input_form = user_input_form
@property
def request_type(self):
@ -33,6 +34,7 @@ class MCPServerReuqestHandler:
@property
def parameter_schema(self):
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
return {
"type": "object",
"properties": {
@ -41,10 +43,11 @@ class MCPServerReuqestHandler:
"type": "object",
"description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501
"default": {},
# TODO: add input parameters
"properties": parameters,
"required": required,
},
},
"required": ["query"],
"required": "query",
}
@property
@ -152,3 +155,25 @@ class MCPServerReuqestHandler:
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first()
)
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
parameters = {}
required = []
for item in user_input_form:
if item.type in (
VariableEntityType.FILE,
VariableEntityType.FILE_LIST,
VariableEntityType.EXTERNAL_DATA_TOOL,
):
continue
if item.required:
required.append(item.variable)
parameters[item.variable]["description"] = self.mcp_server.parameters_dict[item.label]["description"]
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"
elif item.type == VariableEntityType.SELECT:
parameters[item.variable]["type"] = "string"
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "number"
return parameters, required

View File

@ -1,7 +1,5 @@
import logging
import queue
import threading
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack
@ -40,7 +38,7 @@ SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotif
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
DEFAULT_RESPONSE_READ_TIMEOUT = 5
DEFAULT_RESPONSE_READ_TIMEOUT = 1
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
@ -80,13 +78,10 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._completed = False
self._on_complete = on_complete
self._entered = False # Track if we're in a context manager
self._cancel_event = threading.Event()
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
"""Enter the context manager, enabling request cancellation tracking."""
self._entered = True
self._cancel_event = threading.Event()
self._cancel_event.clear()
return self
def __exit__(
@ -101,9 +96,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._on_complete(self)
finally:
self._entered = False
if not self._cancel_event:
raise RuntimeError("No active cancel scope")
self._cancel_event.set()
def respond(self, response: SendResultT | ErrorData) -> None:
"""Send a response for this request.
@ -117,17 +109,15 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
raise RuntimeError("RequestResponder must be used as a context manager")
assert not self._completed, "Request already responded to"
if not self.cancelled:
self._completed = True
self._completed = True
self._session._send_response(request_id=self.request_id, response=response)
self._session._send_response(request_id=self.request_id, response=response)
def cancel(self) -> None:
"""Cancel this request and mark it as completed."""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
self._cancel_event.set()
self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
self._session._send_response(
@ -135,14 +125,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
response=ErrorData(code=0, message="Request cancelled", data=None),
)
@property
def in_flight(self) -> bool:
return not self._completed and not self.cancelled
@property
def cancelled(self) -> bool:
return self._cancel_event.is_set()
class BaseSession(
Generic[
@ -184,11 +166,9 @@ class BaseSession(
self._in_flight = {}
self._exit_stack = ExitStack()
self._futures = []
self._request_id_lock = threading.Lock()
def __enter__(self) -> Self:
self._executor = ThreadPoolExecutor()
self._stop_event = threading.Event()
self._receiver_future = self._executor.submit(self._receive_loop)
return self
@ -196,21 +176,8 @@ class BaseSession(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
self._exit_stack.close()
self._stop_event.set()
self._wait_for_futures(timeout=5)
def _wait_for_futures(self, timeout=None):
end_time = time.time() + timeout if timeout else None
for future in list(self._futures):
try:
remaining = end_time - time.time() if end_time else None
if remaining is not None and remaining <= 0:
break
future.result(timeout=remaining)
except Exception as e:
print(f"Error waiting for task: {e}")
self._read_stream.put(None)
self._write_stream.put(None)
def send_request(
self,
@ -247,8 +214,12 @@ class BaseSession(
timeout = request_read_timeout_seconds.total_seconds()
elif self._session_read_timeout_seconds is not None:
timeout = self._session_read_timeout_seconds.total_seconds()
response_or_error = response_queue.get(timeout=timeout)
while True:
try:
response_or_error = response_queue.get(timeout=timeout)
break
except queue.Empty:
continue
if response_or_error is None:
raise MCPConnectionError(
@ -312,10 +283,10 @@ class BaseSession(
Main message processing loop.
In a real synchronous implementation, this would likely run in a separate thread.
"""
while not self._stop_event.is_set():
while True:
try:
# Attempt to receive a message (this would be blocking in a synchronous context)
message = self._read_stream.get(timeout=5)
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
if message is None:
break
if isinstance(message, HTTPStatusError):
@ -374,12 +345,9 @@ class BaseSession(
else:
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
except queue.Empty:
if self._stop_event.is_set():
break
continue
except Exception as e:
logging.exception("Error in message processing loop")
self._stop_event.set()
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
"""

View File

@ -3,12 +3,11 @@ from typing import Any, Protocol
from pydantic import AnyUrl, TypeAdapter
from configs.app_config import DifyConfig
from configs import dify_config
from core.mcp import types
from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
from core.mcp.session.base_session import BaseSession, RequestResponder
dify_config = DifyConfig()
DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION)

View File

@ -1177,7 +1177,6 @@ class SessionMessage:
class OAuthClientMetadata(BaseModel):
client_name: str
redirect_uris: list[str]
scope: str
grant_types: Optional[list[str]] = None
response_types: Optional[list[str]] = None
token_endpoint_auth_method: Optional[str] = None

View File

@ -1,28 +0,0 @@
from typing import Any
from urllib.parse import urljoin, urlparse
import httpx
def create_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
) -> httpx.Client:
kwargs: dict[str, Any] = {
"follow_redirects": True,
}
# Handle timeout
if timeout is None:
kwargs["timeout"] = httpx.Timeout(30.0)
else:
kwargs["timeout"] = timeout
# Handle headers
if headers is not None:
kwargs["headers"] = headers
return httpx.Client(**kwargs)
def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)

View File

@ -1,5 +1,5 @@
from datetime import datetime
from typing import Literal, Optional
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
@ -57,7 +57,7 @@ class ToolProviderApiEntity(BaseModel):
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
parameter["type"] = "files"
# -------------
optional_fields = self.optional_field("server_url", self.server_url)
return {
"id": self.id,
"author": self.author,
@ -73,4 +73,9 @@ class ToolProviderApiEntity(BaseModel):
"allow_delete": self.allow_delete,
"tools": tools,
"labels": self.labels,
**optional_fields,
}
def optional_field(self, key: str, value: Any) -> dict:
"""Return dict with key-value if value is truthy, empty dict otherwise."""
return {key: value} if value else {}

View File

@ -1,6 +1,7 @@
from collections.abc import Generator
from typing import Any, Optional
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import ImageContent, TextContent
from core.plugin.utils.converter import convert_parameters_to_plugin_format
@ -37,9 +38,14 @@ class MCPTool(Tool):
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
tool_parameters = convert_parameters_to_plugin_format(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
try:
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
tool_parameters = convert_parameters_to_plugin_format(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e:
raise ValueError("Please auth the tool first")
except MCPConnectionError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
for content in result.content:
if isinstance(content, TextContent):
yield self.create_text_message(content.text)

View File

@ -1471,6 +1471,10 @@ class AppMCPServer(Base):
return result
@property
def parameters_dict(self) -> dict[str, Any]:
return json.loads(self.parameters)
class Site(Base):
__tablename__ = "sites"

View File

@ -1,5 +1,6 @@
import json
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
@ -58,8 +59,13 @@ class MCPToolManageService:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None:
raise ValueError("MCP tool not found")
with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
tools = mcp_client.list_tools()
try:
with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
tools = mcp_client.list_tools()
except MCPAuthError as e:
raise ValueError("Please auth the tool first")
except MCPConnectionError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
mcp_provider.authed = True
db.session.commit()