feat: mcp client init

This commit is contained in:
Novice 2025-05-19 18:03:40 +08:00
parent 8de24bc16e
commit c1a58ac160
25 changed files with 6161 additions and 2123 deletions

View File

@ -9,12 +9,17 @@ from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.mcp.auth.auth_flow import auth
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value
from libs.login import login_required
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_mange_service import MCPToolManageService
from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
@ -34,7 +39,7 @@ class ToolProviderListApi(Resource):
req.add_argument(
"type",
type=str,
choices=["builtin", "model", "api", "workflow"],
choices=["builtin", "model", "api", "workflow", "mcp"],
required=False,
nullable=True,
location="args",
@ -613,6 +618,153 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels())
class ToolProviderMCPApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json")
args = parser.parse_args()
user = current_user
return jsonable_encoder(
MCPToolManageService.create_mcp_provider(
tenant_id=user.current_tenant_id,
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
user_id=user.id,
)
)
@setup_required
@login_required
@account_initialization_required
def put(self):
parser = reqparse.RequestParser()
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=True, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return jsonable_encoder(
MCPToolManageService.update_mcp_provider(
tenant_id=current_user.current_tenant_id,
name=args["name"],
server_url=args["server_url"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
provider_id=args["provider_id"],
encrypted_credentials={},
)
)
@setup_required
@login_required
@account_initialization_required
def delete(self):
parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return jsonable_encoder(
MCPToolManageService.delete_mcp_tool(
tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"]
)
)
class ToolMCPAuthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
provider_id = args["provider_id"]
tenant_id = current_user.current_tenant_id
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider:
raise ValueError("provider not found")
try:
with MCPClient(
provider.server_url,
provider_id,
tenant_id,
authed=False,
authorization_code=args["authorization_code"],
):
return {"result": "success"}
except MCPAuthError as e:
auth_provider = OAuthClientProvider(provider_id, tenant_id)
return auth(auth_provider, provider.server_url, args["authorization_code"])
class ToolMCPDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id):
user = current_user
return jsonable_encoder(
MCPToolManageService.retrieve_mcp_provider(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
)
class ToolMCPListAllApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user = current_user
tenant_id = user.current_tenant_id
return jsonable_encoder(MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id))
class ToolMCPUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id):
tenant_id = current_user.current_tenant_id
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
tenant_id=tenant_id,
provider_id=provider_id,
)
return jsonable_encoder(tools)
class ToolMCPTokenApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="args")
parser.add_argument("server_url", type=str, required=True, nullable=False, location="args")
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="args")
args = parser.parse_args()
return auth(
OAuthClientProvider(args["provider_id"], current_user.current_tenant_id),
server_url=args["server_url"],
authorization_code=args["authorization_code"],
)
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
@ -647,8 +799,15 @@ api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provid
api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
# mcp tool provider
api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp")
api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth")
api.add_resource(ToolMCPTokenApi, "/workspaces/current/tool-provider/mcp/token")
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
api.add_resource(ToolMCPListAllApi, "/workspaces/current/tools/mcp")
api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")
api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")

0
api/core/mcp/__init__.py Normal file
View File

View File

@ -0,0 +1,240 @@
import base64
import hashlib
import os
import urllib.parse
from typing import Optional
from urllib.parse import urljoin
import requests
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
)
LATEST_PROTOCOL_VERSION = "1.0"
def generate_pkce_challenge() -> tuple[str, str]:
"""Generate PKCE challenge and verifier."""
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_")
code_challenge = hashlib.sha256(code_verifier.encode("utf-8")).digest()
code_challenge = base64.urlsafe_b64encode(code_challenge).decode("utf-8")
code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_")
return code_verifier, code_challenge
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
try:
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
response = requests.get(url, headers=headers)
if response.status_code == 404:
return None
if not response.ok:
raise Exception(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
except requests.RequestException as e:
if isinstance(e, requests.ConnectionError):
response = requests.get(url)
if response.status_code == 404:
return None
if not response.ok:
raise Exception(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
raise
def start_authorization(
server_url: str,
metadata: Optional[OAuthMetadata],
client_information: OAuthClientInformation,
redirect_url: str,
scope: Optional[str] = None,
) -> tuple[str, str]:
"""Begins the authorization flow."""
response_type = "code"
code_challenge_method = "S256"
if metadata:
authorization_url = metadata.authorization_endpoint
if response_type not in metadata.response_types_supported:
raise Exception(f"Incompatible auth server: does not support response type {response_type}")
if (
not metadata.code_challenge_methods_supported
or code_challenge_method not in metadata.code_challenge_methods_supported
):
raise Exception(f"Incompatible auth server: does not support code challenge method {code_challenge_method}")
else:
authorization_url = urljoin(server_url, "/authorize")
code_verifier, code_challenge = generate_pkce_challenge()
params = {
"response_type": response_type,
"client_id": client_information.client_id,
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"redirect_uri": redirect_url,
}
if scope:
params["scope"] = scope
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
return authorization_url, code_verifier
def exchange_authorization(
server_url: str,
metadata: Optional[OAuthMetadata],
client_information: OAuthClientInformation,
authorization_code: str,
code_verifier: str,
redirect_uri: str,
) -> OAuthTokens:
"""Exchanges an authorization code for an access token."""
grant_type = "authorization_code"
if metadata:
token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise Exception(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
params = {
"grant_type": grant_type,
"client_id": client_information.client_id,
"code": authorization_code,
"code_verifier": code_verifier,
"redirect_uri": redirect_uri,
}
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = requests.post(token_url, data=params)
if not response.ok:
raise Exception(f"Token exchange failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json())
def refresh_authorization(
server_url: str,
metadata: Optional[OAuthMetadata],
client_information: OAuthClientInformation,
refresh_token: str,
) -> OAuthTokens:
"""Exchange a refresh token for an updated access token."""
grant_type = "refresh_token"
if metadata:
token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise Exception(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
params = {
"grant_type": grant_type,
"client_id": client_information.client_id,
"refresh_token": refresh_token,
}
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = requests.post(token_url, data=params)
if not response.ok:
raise Exception(f"Token refresh failed: HTTP {response.status_code}")
return OAuthTokens.parse_obj(response.json())
def register_client(
server_url: str,
metadata: Optional[OAuthMetadata],
client_metadata: OAuthClientMetadata,
) -> OAuthClientInformationFull:
"""Performs OAuth 2.0 Dynamic Client Registration."""
if metadata:
if not metadata.registration_endpoint:
raise Exception("Incompatible auth server: does not support dynamic client registration")
registration_url = metadata.registration_endpoint
else:
registration_url = urljoin(server_url, "/register")
response = requests.post(
registration_url,
json=client_metadata.model_dump(),
headers={"Content-Type": "application/json"},
)
if not response.ok:
response.raise_for_status()
return OAuthClientInformationFull.model_validate(response.json())
def auth(
provider: OAuthClientProvider,
server_url: str,
authorization_code: Optional[str] = None,
scope: Optional[str] = None,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server."""
metadata = discover_oauth_metadata(server_url)
# Handle client registration if needed
client_information = provider.client_information()
if not client_information:
if authorization_code is not None:
raise Exception("Existing OAuth client information is required when exchanging an authorization code")
full_information = register_client(server_url, metadata, provider.client_metadata)
provider.save_client_information(full_information)
client_information = full_information
# Exchange authorization code for tokens
if authorization_code is not None:
code_verifier = provider.code_verifier()
tokens = exchange_authorization(
server_url,
metadata,
client_information,
authorization_code,
code_verifier,
provider.redirect_url,
)
provider.save_tokens(tokens)
return {"result": "success"}
tokens = provider.tokens()
# Handle token refresh or new authorization
if tokens and tokens.refresh_token:
try:
new_tokens = refresh_authorization(server_url, metadata, client_information, tokens.refresh_token)
provider.save_tokens(new_tokens)
return {"result": "success"}
except Exception as e:
print(f"Could not refresh OAuth tokens: {e}")
# Start new authorization flow
authorization_url, code_verifier = start_authorization(
server_url,
metadata,
client_information,
provider.redirect_url,
scope or provider.client_metadata.scope,
)
provider.save_code_verifier(code_verifier)
return {"authorization_url": authorization_url}

View File

@ -0,0 +1,97 @@
from typing import Optional
from configs.app_config import DifyConfig
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthTokens,
)
from services.tools.mcp_tools_mange_service import MCPToolManageService
LATEST_PROTOCOL_VERSION = "1.0"
dify_config = DifyConfig()
class OAuthClientProvider:
provider_id: str
tenant_id: str
def __init__(self, provider_id: str, tenant_id: str):
self.provider_id = provider_id
self.tenant_id = tenant_id
@property
def redirect_url(self) -> str:
"""The URL to redirect the user agent to after authorization."""
return dify_config.CONSOLE_WEB_URL
@property
def client_metadata(self) -> OAuthClientMetadata:
"""Metadata about this OAuth client."""
return OAuthClientMetadata(
redirect_uris=[self.redirect_url],
token_endpoint_auth_method="none",
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
client_name="Dify",
client_uri="https://github.com/langgenius/dify",
scope="read write",
)
def client_information(self) -> Optional[OAuthClientInformation]:
"""Loads information about this OAuth client."""
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
if not mcp_provider:
return None
client_information = mcp_provider.credentials.get("client_information", {})
if not client_information:
return None
return OAuthClientInformation.model_validate(client_information)
def save_client_information(self, client_information: OAuthClientInformationFull) -> None:
"""Saves client information after dynamic registration."""
MCPToolManageService.update_mcp_provider_credentials(
self.tenant_id, self.provider_id, {"client_information": client_information.model_dump()}
)
def tokens(self) -> Optional[OAuthTokens]:
"""Loads any existing OAuth tokens for the current session."""
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
if not mcp_provider:
return None
credentials = mcp_provider.credentials
if not credentials:
return None
return OAuthTokens(
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", "Bearer"),
expires_in=credentials.get("expires_in", 3600),
refresh_token=credentials.get("refresh_token", ""),
)
def save_tokens(self, tokens: OAuthTokens) -> None:
"""Stores new OAuth tokens for the current session."""
# update mcp provider credentials
token_dict = tokens.model_dump()
MCPToolManageService.update_mcp_provider_credentials(self.tenant_id, self.provider_id, token_dict, authed=True)
def save_code_verifier(self, code_verifier: str) -> None:
"""Saves a PKCE code verifier for the current session."""
# update mcp provider credentials
MCPToolManageService.update_mcp_provider_credentials(
self.tenant_id, self.provider_id, {"code_verifier": code_verifier}
)
def code_verifier(self) -> str:
"""Loads the PKCE code verifier for the current session."""
# get code verifier from mcp provider credentials
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
if not mcp_provider:
return ""
return mcp_provider.credentials.get("code_verifier", "")
class UnauthorizedError(Exception):
pass

View File

@ -0,0 +1,192 @@
import logging
import queue
import threading
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any
from urllib.parse import urljoin, urlparse
import httpx
from httpx_sse import connect_sse
from sseclient import SSEClient
from core.mcp import types
from core.mcp.types import SessionMessage
from core.mcp.utils import create_mcp_http_client, remove_request_params
logger = logging.getLogger(__name__)
@contextmanager
def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
) -> Generator[tuple[queue.Queue, queue.Queue], None, None]:
"""
Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
if headers is None:
headers = {}
read_queue = queue.Queue()
write_queue = queue.Queue()
status_queue = queue.Queue()
cancel_event = threading.Event()
with ThreadPoolExecutor() as executor:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
with create_mcp_http_client(headers=headers) as client:
with connect_sse(
client, "GET", url, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")
def sse_reader(status_queue: queue.Queue):
try:
while not cancel_event.is_set():
for sse in event_source.iter_sse():
if cancel_event.is_set():
break
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(f"Received endpoint URL: {endpoint_url}")
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = (
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg)
raise ValueError(error_msg)
status_queue.put(("ready", endpoint_url))
case "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"Received server message: {message}")
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
continue
session_message = SessionMessage(message)
read_queue.put(session_message)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
except Exception as exc:
if not cancel_event.is_set():
logger.exception("Error reading SSE messages")
read_queue.put(exc)
finally:
read_queue.put(None)
def post_writer(endpoint_url: str):
try:
while not cancel_event.is_set():
try:
message = write_queue.get(timeout=5)
if message is None:
break
response = client.post(
endpoint_url,
json=message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
if cancel_event.is_set():
break
except queue.Empty:
if cancel_event.is_set():
break
continue
except Exception:
logger.exception("Error writing messages")
finally:
write_queue.put(None)
executor.submit(sse_reader, status_queue)
try:
status, endpoint_url = status_queue.get(timeout=1)
except queue.Empty:
raise ValueError("failed to get endpoint URL")
if status != "ready":
raise ValueError("failed to get endpoint URL")
executor.submit(post_writer, endpoint_url)
try:
yield read_queue, write_queue
finally:
cancel_event.set()
except Exception as exc:
logger.exception("Error connecting to SSE endpoint")
raise exc
finally:
read_queue.put(None)
write_queue.put(None)
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:
"""
Send a message to the server using the provided HTTP client.
Args:
http_client: The HTTP client to use for sending
endpoint_url: The endpoint URL to send the message to
session_message: The message to send
"""
try:
response = http_client.post(
endpoint_url,
json=session_message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
except Exception as exc:
logger.exception("Error sending message")
raise
def read_messages(
sse_client: SSEClient,
) -> Generator[SessionMessage | Exception, None, None]:
"""
Read messages from the SSE client.
Args:
sse_client: The SSE client to read from
Yields:
SessionMessage or Exception for each event received
"""
try:
for sse in sse_client.events():
if sse.event == "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"Received server message: {message}")
yield SessionMessage(message)
except Exception as exc:
logger.exception("Error parsing server message")
yield exc
else:
logger.warning(f"Unknown SSE event: {sse.event}")
except Exception as exc:
logger.exception("Error reading SSE messages")
yield exc

View File

@ -0,0 +1,489 @@
"""
StreamableHTTP Client Transport Module
This module implements the StreamableHTTP transport for MCP clients,
providing support for HTTP POST requests with optional SSE streaming responses
and session management.
"""
import logging
import queue
import threading
from collections.abc import Callable, Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, cast
import httpx
from httpx_sse import EventSource, ServerSentEvent, connect_sse
from core.mcp.types import (
ClientMessageMetadata,
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
RequestId,
SessionMessage,
)
from core.mcp.utils import create_mcp_http_client
logger = logging.getLogger(__name__)
SessionMessageOrError = SessionMessage | Exception | None
# Queue types with clearer names for their roles
ServerToClientQueue = queue.Queue[SessionMessageOrError] # Server to client messages
ClientToServerQueue = queue.Queue[SessionMessage | None] # Client to server messages
GetSessionIdCallback = Callable[[], str | None]
MCP_SESSION_ID = "mcp-session-id"
LAST_EVENT_ID = "last-event-id"
CONTENT_TYPE = "content-type"
ACCEPT = "Accept"
JSON = "application/json"
SSE = "text/event-stream"
class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors."""
pass
class ResumptionError(StreamableHTTPError):
"""Raised when resumption request is invalid."""
pass
@dataclass
class RequestContext:
"""Context for a request operation."""
client: httpx.Client
headers: dict[str, str]
session_id: str | None
session_message: SessionMessage
metadata: ClientMessageMetadata | None
server_to_client_queue: ServerToClientQueue # Renamed for clarity
sse_read_timeout: timedelta
class StreamableHTTPTransport:
"""StreamableHTTP client transport implementation."""
def __init__(
self,
url: str,
headers: dict[str, Any] | None = None,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
) -> None:
"""Initialize the StreamableHTTP transport.
Args:
url: The endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
"""
self.url = url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.session_id: str | None = None
self.request_headers = {
ACCEPT: f"{JSON}, {SSE}",
CONTENT_TYPE: JSON,
**self.headers,
}
self.stop_event = threading.Event()
def stop(self):
"""Signal to stop all operations."""
self.stop_event.set()
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID if available."""
headers = base_headers.copy()
if self.session_id:
headers[MCP_SESSION_ID] = self.session_id
return headers
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialization request."""
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialized notification."""
return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
def _maybe_extract_session_id_from_response(
self,
response: httpx.Response,
) -> None:
"""Extract and store session ID from response headers."""
new_session_id = response.headers.get(MCP_SESSION_ID)
if new_session_id:
self.session_id = new_session_id
logger.info(f"Received session ID: {self.session_id}")
def _handle_sse_event(
self,
sse: ServerSentEvent,
server_to_client_queue: ServerToClientQueue,
original_request_id: RequestId | None = None,
resumption_callback: Callable[[str], None] | None = None,
) -> bool:
"""Handle an SSE event, returning True if the response is complete."""
if sse.event == "message":
try:
message = JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"SSE message: {message}")
# If this is a response and we have original_request_id, replace it
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
message.root.id = original_request_id
session_message = SessionMessage(message)
# Put message in queue that goes to client
server_to_client_queue.put(session_message)
# Call resumption token callback if we have an ID
if sse.id and resumption_callback:
resumption_callback(sse.id)
# If this is a response or error return True indicating completion
# Otherwise, return False to continue listening
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
except Exception as exc:
# Put exception in queue that goes to client
server_to_client_queue.put(exc)
return False
else:
logger.warning(f"Unknown SSE event: {sse.event}")
return False
def handle_get_stream(
self,
client: httpx.Client,
server_to_client_queue: ServerToClientQueue,
) -> None:
"""Handle GET stream for server-initiated messages."""
try:
if not self.session_id:
return
headers = self._update_headers_with_session(self.request_headers)
with connect_sse(
client,
"GET",
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
) as event_source:
event_source.response.raise_for_status()
logger.debug("GET SSE connection established")
for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
self._handle_sse_event(sse, server_to_client_queue)
except Exception as exc:
logger.debug(f"GET stream error (non-fatal): {exc}")
def _handle_resumption_request(self, ctx: RequestContext) -> None:
"""Handle a resumption request using GET with SSE."""
headers = self._update_headers_with_session(ctx.headers)
if ctx.metadata and ctx.metadata.resumption_token:
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
else:
raise ResumptionError("Resumption request requires a resumption token")
# Extract original request ID to map responses
original_request_id = None
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
original_request_id = ctx.session_message.message.root.id
with connect_sse(
ctx.client,
"GET",
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
) as event_source:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")
for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break
def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
headers = self._update_headers_with_session(ctx.headers)
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)
with ctx.client.stream(
"POST",
self.url,
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
headers=headers,
) as response:
if response.status_code == 202:
logger.debug("Received 202 Accepted")
return
if response.status_code == 404:
if isinstance(message.root, JSONRPCRequest):
self._send_session_terminated_error(
ctx.server_to_client_queue,
message.root.id,
)
return
response.raise_for_status()
if is_initialization:
self._maybe_extract_session_id_from_response(response)
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
if content_type.startswith(JSON):
self._handle_json_response(response, ctx.server_to_client_queue)
elif content_type.startswith(SSE):
self._handle_sse_response(response, ctx)
else:
self._handle_unexpected_content_type(
content_type,
ctx.server_to_client_queue,
)
def _handle_json_response(
self,
response: httpx.Response,
server_to_client_queue: ServerToClientQueue,
) -> None:
"""Handle JSON response from the server."""
try:
content = response.read()
message = JSONRPCMessage.model_validate_json(content)
session_message = SessionMessage(message)
server_to_client_queue.put(session_message)
except Exception as exc:
server_to_client_queue.put(exc)
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
"""Handle SSE response from the server."""
try:
event_source = EventSource(response)
for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
self._handle_sse_event(
sse,
ctx.server_to_client_queue,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
)
except Exception as e:
logger.exception("Error reading SSE stream:")
ctx.server_to_client_queue.put(e)
def _handle_unexpected_content_type(
self,
content_type: str,
server_to_client_queue: ServerToClientQueue,
) -> None:
"""Handle unexpected content type in response."""
error_msg = f"Unexpected content type: {content_type}"
logger.error(error_msg)
server_to_client_queue.put(ValueError(error_msg))
def _send_session_terminated_error(
self,
server_to_client_queue: ServerToClientQueue,
request_id: RequestId,
) -> None:
"""Send a session terminated error response."""
jsonrpc_error = JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=ErrorData(code=32600, message="Session terminated"),
)
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
server_to_client_queue.put(session_message)
def post_writer(
self,
client: httpx.Client,
client_to_server_queue: ClientToServerQueue,
server_to_client_queue: ServerToClientQueue,
start_get_stream: Callable[[], None],
) -> None:
"""Handle writing requests to the server.
This method processes messages from the client_to_server_queue and sends them to the server.
Responses are written to the server_to_client_queue.
"""
while not self.stop_event.is_set():
try:
# Read message from client queue with timeout to check stop_event periodically
session_message = client_to_server_queue.get(timeout=5)
if session_message is None:
break
message = session_message.message
metadata = (
session_message.metadata if isinstance(session_message.metadata, ClientMessageMetadata) else None
)
# Check if this is a resumption request
is_resumption = bool(metadata and metadata.resumption_token)
logger.debug(f"Sending client message: {message}")
# Handle initialized notification
if self._is_initialized_notification(message):
start_get_stream()
ctx = RequestContext(
client=client,
headers=self.request_headers,
session_id=self.session_id,
session_message=session_message,
metadata=metadata,
server_to_client_queue=server_to_client_queue, # Queue to write responses to client
sse_read_timeout=self.sse_read_timeout,
)
if is_resumption:
self._handle_resumption_request(ctx)
else:
self._handle_post_request(ctx)
except queue.Empty:
# Timeout - continue loop to check stop_event
continue
except Exception as exc:
# Send exception to client
server_to_client_queue.put(exc)
def terminate_session(self, client: httpx.Client) -> None:
"""Terminate the session by sending a DELETE request."""
if not self.session_id:
return
try:
headers = self._update_headers_with_session(self.request_headers)
response = client.delete(self.url, headers=headers)
if response.status_code == 405:
logger.debug("Server does not allow session termination")
elif response.status_code != 200:
logger.warning(f"Session termination failed: {response.status_code}")
except Exception as exc:
logger.warning(f"Session termination failed: {exc}")
def get_session_id(self) -> str | None:
"""Get the current session ID."""
return self.session_id
@contextmanager
def streamablehttp_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
terminate_on_close: bool = True,
) -> Generator[
tuple[
ServerToClientQueue, # Queue for receiving messages FROM server
ClientToServerQueue, # Queue for sending messages TO server
GetSessionIdCallback,
],
None,
None,
]:
"""
Client transport for StreamableHTTP.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
Yields:
Tuple containing:
- server_to_client_queue: Queue for reading messages FROM the server
- client_to_server_queue: Queue for sending messages TO the server
- get_session_id_callback: Function to retrieve the current session ID
"""
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
# Create queues with clear directional meaning
server_to_client_queue = queue.Queue() # For messages FROM server TO client
client_to_server_queue = queue.Queue() # For messages FROM client TO server
with ThreadPoolExecutor(max_workers=2) as executor:
try:
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
with create_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream() -> None:
"""Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
# Start the post_writer worker thread
executor.submit(
transport.post_writer,
client,
client_to_server_queue, # Queue for messages FROM client TO server
server_to_client_queue, # Queue for messages FROM server TO client
start_get_stream,
)
try:
yield (
server_to_client_queue, # Queue for receiving messages FROM server
client_to_server_queue, # Queue for sending messages TO server
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
transport.terminate_session(client)
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clean up
transport.stop()
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():
client_to_server_queue.get_nowait()
except queue.Empty:
pass
client_to_server_queue.put(None)
server_to_client_queue.put(None)

19
api/core/mcp/entities.py Normal file
View File

@ -0,0 +1,19 @@
from dataclasses import dataclass
from typing import Any, Generic, TypeVar
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
SUPPORTED_PROTOCOL_VERSIONS: tuple[int, str] = (1, LATEST_PROTOCOL_VERSION)
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
@dataclass
class RequestContext(Generic[SessionT, LifespanContextT]):
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT

10
api/core/mcp/error.py Normal file
View File

@ -0,0 +1,10 @@
class MCPError(Exception):
pass
class MCPConnectionError(MCPError):
pass
class MCPAuthError(MCPConnectionError):
pass

125
api/core/mcp/mcp_client.py Normal file
View File

@ -0,0 +1,125 @@
import logging
from contextlib import ExitStack
from typing import Optional, cast
from core.mcp.client.sse_client import sse_client
from core.mcp.client.streamable_client import streamablehttp_client
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.session.client_session import ClientSession
from core.mcp.types import Tool
logger = logging.getLogger(__name__)
class MCPClient:
def __init__(
self,
server_url: str,
provider_id: str,
tenant_id: str,
authed: bool = True,
authorization_code: Optional[str] = None,
scope: Optional[str] = None,
):
# Initialize info
self.provider_id = provider_id
self.tenant_id = tenant_id
self.client_type = "streamable"
self.server_url = server_url
# Authentication info
self.authed = authed
self.authorization_code = authorization_code
self.scope = scope
if authed:
from core.mcp.auth.auth_provider import OAuthClientProvider
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id)
self.token = self.provider.tokens()
# Initialize session and client objects
self._session: Optional[ClientSession] = None
self._streams_context = None
self._session_context = None
self.exit_stack = ExitStack()
# Whether the client has been initialized
self._initialized = False
def __enter__(self):
self._initialize(first_try=True)
self._initialized = True
return self
def __exit__(self, exc_type, exc_value, traceback):
self.cleanup()
def _initialize(
self,
first_try: bool = True,
):
"""Initialize the client with fallback to SSE if streamable connection fails"""
connection_methods = [("streamablehttp_client", streamablehttp_client), ("sse_client", sse_client)]
from core.mcp.auth.auth_flow import auth
for method_name, client_factory in connection_methods:
try:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else {}
)
self._streams_context = client_factory(url=self.server_url, headers=headers)
if method_name == "streamablehttp_client":
read_stream, write_stream, _ = self._streams_context.__enter__()
streams = (read_stream, write_stream)
else: # sse_client
streams = self._streams_context.__enter__()
self._session_context = ClientSession(*streams)
self._session = self._session_context.__enter__()
session = cast(ClientSession, self._session)
session.initialize()
return
except MCPAuthError:
if not self.authed:
raise
auth(self.provider, self.server_url, self.authorization_code, self.scope)
if first_try:
return self._initialize(first_try=False)
except MCPConnectionError:
if method_name == "streamablehttp_client":
continue
raise
def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport"""
# List available tools to verify connection
if not self._initialized or not self._session:
raise ValueError("Session not initialized.")
response = self._session.list_tools()
tools = response.tools
return tools
def invoke_tool(self, tool_name: str, tool_args: dict):
"""Call a tool"""
if not self._initialized or not self._session:
raise ValueError("Session not initialized.")
return self._session.call_tool(tool_name, tool_args)
def cleanup(self):
"""Clean up resources"""
try:
if self._session:
self._session.__exit__(None, None, None)
if self._streams_context:
self._streams_context.__exit__(None, None, None)
self._session = None
self._initialized = False
self.exit_stack.close()
except Exception:
logging.exception("Error during cleanup")

View File

@ -0,0 +1,415 @@
import logging
import queue
import threading
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, Self, TypeVar
from httpx import HTTPStatusError
from pydantic import BaseModel
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.types import (
CancelledNotification,
ClientNotification,
ClientRequest,
ClientResult,
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
MessageMetadata,
RequestId,
RequestParams,
ServerMessageMetadata,
ServerNotification,
ServerRequest,
ServerResult,
SessionMessage,
)
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
DEFAULT_RESPONSE_READ_TIMEOUT = 5
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""Handles responding to MCP requests and manages request lifecycle.
This class MUST be used as a context manager to ensure proper cleanup and
cancellation handling:
Example:
with request_responder as resp:
resp.respond(result)
The context manager ensures:
1. Proper cancellation scope setup and cleanup
2. Request completion tracking
3. Cleanup of in-flight requests
"""
def __init__(
self,
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: """BaseSession[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT
]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self._session = session
self._completed = False
self._on_complete = on_complete
self._entered = False # Track if we're in a context manager
self._cancel_event = threading.Event()
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
"""Enter the context manager, enabling request cancellation tracking."""
self._entered = True
self._cancel_event = threading.Event()
self._cancel_event.clear()
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit the context manager, performing cleanup and notifying completion."""
try:
if self._completed:
self._on_complete(self)
finally:
self._entered = False
if not self._cancel_event:
raise RuntimeError("No active cancel scope")
self._cancel_event.set()
def respond(self, response: SendResultT | ErrorData) -> None:
"""Send a response for this request.
Must be called within a context manager block.
Raises:
RuntimeError: If not used within a context manager
AssertionError: If request was already responded to
"""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
assert not self._completed, "Request already responded to"
if not self.cancelled:
self._completed = True
self._session._send_response(request_id=self.request_id, response=response)
def cancel(self) -> None:
"""Cancel this request and mark it as completed."""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
self._cancel_event.set()
self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
self._session._send_response(
request_id=self.request_id,
response=ErrorData(code=0, message="Request cancelled", data=None),
)
@property
def in_flight(self) -> bool:
return not self._completed and not self.cancelled
@property
def cancelled(self) -> bool:
return self._cancel_event.is_set()
class BaseSession(
Generic[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
],
):
"""
Implements an MCP "session" on top of read/write streams, including features
like request/response linking, notifications, and progress.
This class is a context manager that automatically starts processing
messages when entered.
"""
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]]
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
def __init__(
self,
read_stream: queue.Queue,
write_stream: queue.Queue,
receive_request_type: type[ReceiveRequestT],
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
read_timeout_seconds: timedelta | None = None,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
self._response_streams = {}
self._request_id = 0
self._receive_request_type = receive_request_type
self._receive_notification_type = receive_notification_type
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._exit_stack = ExitStack()
self._futures = []
self._request_id_lock = threading.Lock()
def __enter__(self) -> Self:
self._executor = ThreadPoolExecutor()
self._stop_event = threading.Event()
self._receiver_future = self._executor.submit(self._receive_loop)
return self
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
self._exit_stack.close()
self._stop_event.set()
self._wait_for_futures(timeout=5)
def _wait_for_futures(self, timeout=None):
end_time = time.time() + timeout if timeout else None
for future in list(self._futures):
try:
remaining = end_time - time.time() if end_time else None
if remaining is not None and remaining <= 0:
break
future.result(timeout=remaining)
except Exception as e:
print(f"Error waiting for task: {e}")
def send_request(
self,
request: SendRequestT,
result_type: type[ReceiveResultT],
request_read_timeout_seconds: timedelta | None = None,
metadata: MessageMetadata = None,
) -> ReceiveResultT:
"""
Sends a request and wait for a response. Raises an McpError if the
response contains an error. If a request read timeout is provided, it
will take precedence over the session read timeout.
Do not use this method to emit notifications! Use send_notification()
instead.
"""
request_id = self._request_id
self._request_id = request_id + 1
response_queue = queue.Queue()
self._response_streams[request_id] = response_queue
try:
jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
)
self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
timeout = DEFAULT_RESPONSE_READ_TIMEOUT
if request_read_timeout_seconds is not None:
timeout = request_read_timeout_seconds.total_seconds()
elif self._session_read_timeout_seconds is not None:
timeout = self._session_read_timeout_seconds.total_seconds()
response_or_error = response_queue.get(timeout=timeout)
if response_or_error is None:
raise MCPConnectionError(
ErrorData(
code=500,
message="No response received",
)
)
elif isinstance(response_or_error, JSONRPCError):
if response_or_error.error.code == 401:
raise MCPAuthError(
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
)
else:
raise MCPConnectionError(
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
)
else:
return result_type.model_validate(response_or_error.result)
finally:
self._response_streams.pop(request_id, None)
def send_notification(
self,
notification: SendNotificationT,
related_request_id: RequestId | None = None,
) -> None:
"""
Emits a notification, which is a one-way message that does not expect
a response.
"""
# Some transport implementations may need to set the related_request_id
# to attribute to the notifications to the request that triggered them.
jsonrpc_notification = JSONRPCNotification(
jsonrpc="2.0",
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
)
session_message = SessionMessage(
message=JSONRPCMessage(jsonrpc_notification),
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
)
self._write_stream.put(session_message)
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
self._write_stream.put(session_message)
else:
jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0",
id=request_id,
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
self._write_stream.put(session_message)
def _receive_loop(self) -> None:
"""
Main message processing loop.
In a real synchronous implementation, this would likely run in a separate thread.
"""
while not self._stop_event.is_set():
try:
# Attempt to receive a message (this would be blocking in a synchronous context)
message = self._read_stream.get(timeout=5)
if message is None:
break
if isinstance(message, HTTPStatusError):
response_queue = self._response_streams.get(self._request_id - 1)
if response_queue is not None:
response_queue.put(
JSONRPCError(
jsonrpc="2.0",
id=self._request_id - 1,
error=ErrorData(code=message.response.status_code, message=message.args[0]),
)
)
else:
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
elif isinstance(message, Exception):
self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
)
self._in_flight[responder.request_id] = responder
self._received_request(responder)
if not responder._completed:
self._handle_incoming(responder)
elif isinstance(message.message.root, JSONRPCNotification):
try:
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
self._in_flight[cancelled_id].cancel()
else:
self._received_notification(notification)
self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning(f"Failed to validate notification: {e}. Message was: {message.message.root}")
else: # Response or error
response_queue = self._response_streams.get(message.message.root.id)
if response_queue is not None:
response_queue.put(message.message.root)
else:
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
except queue.Empty:
if self._stop_event.is_set():
break
continue
except Exception as e:
logging.exception("Error in message processing loop")
self._stop_event.set()
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
"""
Can be overridden by subclasses to handle a request without needing to
listen on the message stream.
If the request is responded to within this method, it will not be
forwarded on to the message stream.
"""
pass
def _received_notification(self, notification: ReceiveNotificationT) -> None:
"""
Can be overridden by subclasses to handle a notification without needing
to listen on the message stream.
"""
pass
def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""
Sends a progress notification for a request that is currently being
processed.
"""
pass
def _handle_incoming(
self,
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
) -> None:
"""A generic handler for incoming messages. Overwritten by subclasses."""
pass

View File

@ -0,0 +1,364 @@
from datetime import timedelta
from typing import Any, Protocol
from pydantic import AnyUrl, TypeAdapter
from core.mcp import types
from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
from core.mcp.session.base_session import BaseSession, RequestResponder
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
class SamplingFnT(Protocol):
def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData: ...
class ListRootsFnT(Protocol):
def __call__(self, context: RequestContext["ClientSession", Any]) -> types.ListRootsResult | types.ErrorData: ...
class LoggingFnT(Protocol):
def __call__(
self,
params: types.LoggingMessageNotificationParams,
) -> None: ...
class MessageHandlerFnT(Protocol):
def __call__(
self,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None: ...
def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message
elif isinstance(message, (types.ServerNotification | RequestResponder)):
pass
def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Sampling not supported",
)
def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="List roots not supported",
)
def _default_logging_callback(
params: types.LoggingMessageNotificationParams,
) -> None:
pass
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
class ClientSession(
BaseSession[
types.ClientRequest,
types.ClientNotification,
types.ClientResult,
types.ServerRequest,
types.ServerNotification,
]
):
def __init__(
self,
read_stream,
write_stream,
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
) -> None:
super().__init__(
read_stream,
write_stream,
types.ServerRequest,
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability()
roots = types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True,
)
result = self.send_request(
types.ClientRequest(
types.InitializeRequest(
method="initialize",
params=types.InitializeRequestParams(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=sampling,
experimental=None,
roots=roots,
),
clientInfo=self._client_info,
),
)
),
types.InitializeResult,
)
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
self.send_notification(
types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
)
return result
def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
return self.send_request(
types.ClientRequest(
types.PingRequest(
method="ping",
)
),
types.EmptyResult,
)
def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""Send a progress notification."""
self.send_notification(
types.ClientNotification(
types.ProgressNotification(
method="notifications/progress",
params=types.ProgressNotificationParams(
progressToken=progress_token,
progress=progress,
total=total,
),
),
)
)
def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
"""Send a logging/setLevel request."""
return self.send_request(
types.ClientRequest(
types.SetLevelRequest(
method="logging/setLevel",
params=types.SetLevelRequestParams(level=level),
)
),
types.EmptyResult,
)
def list_resources(self) -> types.ListResourcesResult:
"""Send a resources/list request."""
return self.send_request(
types.ClientRequest(
types.ListResourcesRequest(
method="resources/list",
)
),
types.ListResourcesResult,
)
def list_resource_templates(self) -> types.ListResourceTemplatesResult:
"""Send a resources/templates/list request."""
return self.send_request(
types.ClientRequest(
types.ListResourceTemplatesRequest(
method="resources/templates/list",
)
),
types.ListResourceTemplatesResult,
)
def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
"""Send a resources/read request."""
return self.send_request(
types.ClientRequest(
types.ReadResourceRequest(
method="resources/read",
params=types.ReadResourceRequestParams(uri=uri),
)
),
types.ReadResourceResult,
)
def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/subscribe request."""
return self.send_request(
types.ClientRequest(
types.SubscribeRequest(
method="resources/subscribe",
params=types.SubscribeRequestParams(uri=uri),
)
),
types.EmptyResult,
)
def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/unsubscribe request."""
return self.send_request(
types.ClientRequest(
types.UnsubscribeRequest(
method="resources/unsubscribe",
params=types.UnsubscribeRequestParams(uri=uri),
)
),
types.EmptyResult,
)
def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
) -> types.CallToolResult:
"""Send a tools/call request."""
return self.send_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(name=name, arguments=arguments),
)
),
types.CallToolResult,
request_read_timeout_seconds=read_timeout_seconds,
)
def list_prompts(self) -> types.ListPromptsResult:
"""Send a prompts/list request."""
return self.send_request(
types.ClientRequest(
types.ListPromptsRequest(
method="prompts/list",
)
),
types.ListPromptsResult,
)
def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
"""Send a prompts/get request."""
return self.send_request(
types.ClientRequest(
types.GetPromptRequest(
method="prompts/get",
params=types.GetPromptRequestParams(name=name, arguments=arguments),
)
),
types.GetPromptResult,
)
def complete(
self,
ref: types.ResourceReference | types.PromptReference,
argument: dict[str, str],
) -> types.CompleteResult:
"""Send a completion/complete request."""
return self.send_request(
types.ClientRequest(
types.CompleteRequest(
method="completion/complete",
params=types.CompleteRequestParams(
ref=ref,
argument=types.CompletionArgument(**argument),
),
)
),
types.CompleteResult,
)
def list_tools(self) -> types.ListToolsResult:
"""Send a tools/list request."""
return self.send_request(
types.ClientRequest(
types.ListToolsRequest(
method="tools/list",
)
),
types.ListToolsResult,
)
def send_roots_list_changed(self) -> None:
"""Send a roots/list_changed notification."""
self.send_notification(
types.ClientNotification(
types.RootsListChangedNotification(
method="notifications/roots/list_changed",
)
)
)
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id,
meta=responder.request_meta,
session=self,
lifespan_context=None,
)
match responder.request.root:
case types.CreateMessageRequest(params=params):
with responder:
response = self._sampling_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
responder.respond(client_response)
case types.ListRootsRequest():
with responder:
response = self._list_roots_callback(ctx)
client_response = ClientResponse.validate_python(response)
responder.respond(client_response)
case types.PingRequest():
with responder:
return responder.respond(types.ClientResult(root=types.EmptyResult()))
def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
"""Handle incoming messages by forwarding to the message handler."""
self._message_handler(req)
def _received_notification(self, notification: types.ServerNotification) -> None:
"""Handle notifications from the server."""
# Process specific notification types
match notification.root:
case types.LoggingMessageNotification(params=params):
self._logging_callback(params)
case _:
pass

1215
api/core/mcp/types.py Normal file

File diff suppressed because it is too large Load Diff

28
api/core/mcp/utils.py Normal file
View File

@ -0,0 +1,28 @@
from typing import Any
from urllib.parse import urljoin, urlparse
import httpx
def create_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
) -> httpx.Client:
kwargs: dict[str, Any] = {
"follow_redirects": True,
}
# Handle timeout
if timeout is None:
kwargs["timeout"] = httpx.Timeout(30.0)
else:
kwargs["timeout"] = timeout
# Handle headers
if headers is not None:
kwargs["headers"] = headers
return httpx.Client(**kwargs)
def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)

View File

@ -1,3 +1,4 @@
from datetime import datetime
from typing import Literal, Optional
from pydantic import BaseModel, Field, field_validator
@ -18,7 +19,7 @@ class ToolApiEntity(BaseModel):
output_schema: Optional[dict] = None
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow", "mcp"]]
class ToolProviderApiEntity(BaseModel):
@ -37,6 +38,9 @@ class ToolProviderApiEntity(BaseModel):
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
tools: list[ToolApiEntity] = Field(default_factory=list)
labels: list[str] = Field(default_factory=list)
# MCP
server_url: Optional[str] = Field(default="", description="The server url of the tool")
updated_at: datetime = Field(default_factory=datetime.now)
@field_validator("tools", mode="before")
@classmethod

View File

@ -49,6 +49,7 @@ class ToolProviderType(enum.StrEnum):
API = "api"
APP = "app"
DATASET_RETRIEVAL = "dataset-retrieval"
MCP = "mcp"
@classmethod
def value_of(cls, value: str) -> "ToolProviderType":

View File

@ -0,0 +1,130 @@
import json
from typing import Any
from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ToolDescription,
ToolEntity,
ToolIdentity,
ToolProviderEntityWithPlugin,
ToolProviderIdentity,
ToolProviderType,
)
from core.tools.mcp_tool.tool import MCPTool
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
class MCPToolProviderController(ToolProviderController):
provider_id: str
entity: ToolProviderEntityWithPlugin
def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None:
super().__init__(entity)
self.entity = entity
self.tenant_id = tenant_id
self.provider_id = provider_id
self.server_url = server_url
@property
def provider_type(self) -> ToolProviderType:
"""
returns the type of the provider
:return: type of the provider
"""
return ToolProviderType.MCP
@classmethod
def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController":
"""
from db provider
"""
tools = []
tools_data = json.loads(db_provider.tools)
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
tools = [
ToolEntity(
identity=ToolIdentity(
author=db_provider.user.name if db_provider.user else "Anonymous",
name=remote_mcp_tool.name,
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
provider=db_provider.name,
icon=db_provider.icon,
),
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
description=ToolDescription(
human=I18nObject(
en_US=remote_mcp_tool.description or "", zh_Hans=remote_mcp_tool.description or ""
),
llm=remote_mcp_tool.description or "",
),
output_schema=None,
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
)
for remote_mcp_tool in remote_mcp_tools
]
return cls(
entity=ToolProviderEntityWithPlugin(
identity=ToolProviderIdentity(
author=db_provider.user.name if db_provider.user else "Anonymous",
name=db_provider.name,
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US="", zh_Hans=""),
icon=db_provider.icon,
),
plugin_id=None,
credentials_schema=[],
tools=tools,
),
provider_id=db_provider.id or "",
tenant_id=db_provider.tenant_id or "",
server_url=db_provider.server_url,
)
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
"""
pass
def get_tool(self, tool_name: str) -> MCPTool: # type: ignore
"""
return tool with given name
"""
tool_entity = next(
(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
)
if not tool_entity:
raise ValueError(f"Tool with name {tool_name} not found")
return MCPTool(
entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
server_url=self.server_url,
provider_id=self.provider_id,
)
def get_tools(self) -> list[MCPTool]: # type: ignore
"""
get all tools
"""
return [
MCPTool(
entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
server_url=self.server_url,
provider_id=self.provider_id,
)
for tool_entity in self.entity.tools
]

View File

@ -0,0 +1,57 @@
from collections.abc import Generator
from typing import Any, Optional
from core.mcp.mcp_client import MCPClient
from core.mcp.types import ImageContent, TextContent
from core.plugin.utils.converter import convert_parameters_to_plugin_format
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
class MCPTool(Tool):
tenant_id: str
icon: str
runtime_parameters: Optional[list[ToolParameter]]
server_url: str
provider_id: str
def __init__(
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.runtime_parameters = None
self.server_url = server_url
self.provider_id = provider_id
def tool_provider_type(self) -> ToolProviderType:
return ToolProviderType.MCP
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
tool_parameters = convert_parameters_to_plugin_format(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
for content in result.content:
if isinstance(content, TextContent):
yield self.create_text_message(content.text)
elif isinstance(content, ImageContent):
yield self.create_image_message(content.data)
def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
return MCPTool(
entity=self.entity,
runtime=runtime,
tenant_id=self.tenant_id,
icon=self.icon,
server_url=self.server_url,
provider_id=self.provider_id,
)

View File

@ -13,9 +13,12 @@ from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from services.tools.mcp_tools_mange_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
@ -49,7 +52,7 @@ from core.tools.utils.configuration import (
)
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
@ -156,7 +159,7 @@ class ToolManager:
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]:
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
"""
get the tool runtime
@ -292,6 +295,8 @@ class ToolManager:
raise NotImplementedError("app provider not implemented")
elif provider_type == ToolProviderType.PLUGIN:
return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
elif provider_type == ToolProviderType.MCP:
return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
else:
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
@ -424,6 +429,25 @@ class ToolManager:
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@classmethod
def get_tool_runtime_from_mcp(
cls,
tenant_id: str,
provider_id: str,
tool_name: str,
) -> Tool:
"""
get tool runtime from mcp
"""
return cls.get_tool_runtime(
provider_type=ToolProviderType.MCP,
provider_id=provider_id,
tool_name=tool_name,
tenant_id=tenant_id,
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.PLUGIN,
)
@classmethod
def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]:
"""
@ -528,7 +552,7 @@ class ToolManager:
yield provider
except Exception:
logger.exception(f"load builtin provider {provider}")
logger.exception(f"load builtin provider {provider_path}")
continue
# set builtin providers loaded
cls._builtin_providers_loaded = True
@ -569,7 +593,7 @@ class ToolManager:
filters = []
if not typ:
filters.extend(["builtin", "api", "workflow"])
filters.extend(["builtin", "api", "workflow", "mcp"])
else:
filters.append(typ)
@ -663,6 +687,10 @@ class ToolManager:
labels=labels.get(provider_controller.provider_id, []),
)
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
if "mcp" in filters:
mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id)
for provider in mcp_providers:
result_providers[f"mcp_provider.{provider.name}"] = provider
return BuiltinToolProviderSort.sort(list(result_providers.values()))
@ -698,6 +726,32 @@ class ToolManager:
return controller, provider.credentials
@classmethod
def get_mcp_provider_controller(cls, tenant_id: str, provider_id: str) -> MCPToolProviderController:
"""
get the api provider
:param tenant_id: the id of the tenant
:param provider_id: the id of the provider
:return: the provider controller, the credentials
"""
provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
.filter(
MCPToolProvider.id == provider_id,
MCPToolProvider.tenant_id == tenant_id,
)
.first()
)
if provider is None:
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
controller = MCPToolProviderController._from_db(provider)
return controller
@classmethod
def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
"""
@ -863,6 +917,8 @@ class ToolManager:
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
raise ValueError(f"plugin provider {provider_id} not found")
elif provider_type == ToolProviderType.MCP:
return {"background": "#252525", "content": "\ud83d\ude01"}
else:
raise ValueError(f"provider type {provider_type} not found")

View File

@ -6,7 +6,12 @@ from typing import Any, Optional, Union, cast
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
from core.tools.entities.tool_entities import (
ToolEntity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
@ -244,3 +249,84 @@ class WorkflowTool(Tool):
elif transfer_method == FileTransferMethod.LOCAL_FILE:
file_dict["upload_file_id"] = file_dict.get("related_id")
return file_dict
# class MCPTool(Tool):
# """
# MCP tool.
# """
# def __init__(self, entity: ToolEntity, runtime: ToolRuntime):
# super().__init__(entity=entity, runtime=runtime)
# def _invoke(
# self,
# user_id: str,
# tool_parameters: dict[str, Any],
# conversation_id: Optional[str] = None,
# app_id: Optional[str] = None,
# message_id: Optional[str] = None,
# ) -> Generator[ToolInvokeMessage, None, None]:
# """
# invoke the tool
# """
# # Retrieve staff duty schedule
# client = MCPClient()
# res = [
# client.invoke_tool_sync(
# tool_name="NOTION_GET_ABOUT_ME",
# parameters={"params": {}},
# server_url="https://mcp.composio.dev/notion/attractive-fresh-egypt-zkfePp",
# )
# ]
# for r in res:
# for c in r.content:
# if c.type == "text":
# yield self.create_text_message(c.text)
# try:
# yield self.create_json_message(json.loads(c.text))
# except Exception:
# pass
# elif c.type == "json":
# yield self.create_json_message(c.json)
# def _get_tool_runtime(self) -> ToolRuntime:
# """
# get the tool runtime
# """
# return self.runtime
# def tool_provider_type(self) -> ToolProviderType:
# """
# get the tool provider type
# """
# return ToolProviderType.MCP
# @classmethod
# def get_tool_from_runtime(cls, runtime: ToolRuntime) -> "MCPTool":
# """
# get the tool from the runtime
# """
# if runtime.tool_id is None:
# raise ValueError("tool id is required")
# tool_name = MCPToolManageService.get_mcp_tool(runtime.tenant_id, runtime.tool_id)
# if tool_name is None:
# raise ValueError("tool not found")
# entity = ToolEntity(
# identity=ToolIdentity(
# author="dify",
# name=tool_name.name,
# label=I18nObject(
# en_US="MCP",
# zh_Hans="MCP",
# ),
# provider="mcp",
# icon=None,
# ),
# parameters=[],
# description=None,
# output_schema=None,
# has_runtime_parameters=False,
# )
# return cls(entity=entity, runtime=runtime)

View File

@ -0,0 +1,43 @@
"""add mcp provider
Revision ID: 1e67f2654a08
Revises: 6a9f914f656c
Create Date: 2025-05-07 17:40:58.448440
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '1e67f2654a08'
down_revision = '6a9f914f656c'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tool_mcp_providers',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('name', sa.String(length=40), nullable=False),
sa.Column('server_url', sa.String(length=255), nullable=False),
sa.Column('icon', sa.String(length=255), nullable=True),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('user_id', models.types.StringUUID(), nullable=False),
sa.Column('encrypted_credentials', sa.Text(), nullable=True),
sa.Column('authed', sa.Boolean(), nullable=False),
sa.Column('tools', sa.Text(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_mcp_provider_pkey'),
sa.UniqueConstraint('name', 'tenant_id', name='unique_mcp_tool_provider')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tool_mcp_providers')
# ### end Alembic commands ###

View File

@ -7,6 +7,7 @@ from deprecated import deprecated
from sqlalchemy import ForeignKey, func
from sqlalchemy.orm import Mapped, mapped_column
from core.mcp.types import Tool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
@ -193,6 +194,66 @@ class WorkflowToolProvider(Base):
return db.session.query(App).filter(App.id == self.app_id).first()
class MCPToolProvider(Base):
"""
The table stores the mcp providers.
"""
__tablename__ = "tool_mcp_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
db.UniqueConstraint("name", "tenant_id", name="unique_mcp_tool_provider"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# name of the mcp provider
name: Mapped[str] = mapped_column(db.String(40), nullable=False)
# url of the mcp provider
server_url: Mapped[str] = mapped_column(db.String(255), nullable=False)
# icon of the mcp provider
icon: Mapped[str] = mapped_column(db.String(255), nullable=True)
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# who created this tool
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# encrypted credentials
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=False)
# authed
authed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False)
# tools
tools: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]")
created_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
@property
def user(self) -> Account | None:
return db.session.query(Account).filter(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant | None:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
@property
def credentials(self) -> dict:
try:
return cast(dict, json.loads(self.encrypted_credentials)) or {}
except Exception:
return {}
@property
def mcp_tools(self) -> list[Tool]:
return [Tool(**tool) for tool in json.loads(self.tools)]
@property
def provider_icon(self) -> str:
icon_dict = json.loads(self.icon)
return icon_dict
class ToolModelInvoke(Base):
"""
store the invoke logs from tool invoke

View File

@ -84,6 +84,8 @@ dependencies = [
"weave~=0.51.34",
"yarl~=1.18.3",
"webvtt-py~=0.5.1",
"sseclient-py>=1.8.0",
"httpx-sse>=0.4.0",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.

View File

@ -0,0 +1,134 @@
import json
from core.mcp.mcp_client import MCPClient
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from extensions.ext_database import db
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
class MCPToolManageService:
"""
Service class for managing mcp tools.
"""
@staticmethod
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider | None:
return (
db.session.query(MCPToolProvider)
.filter(
MCPToolProvider.id == provider_id,
MCPToolProvider.tenant_id == tenant_id,
)
.first()
)
@staticmethod
def create_mcp_provider(
tenant_id: str, name: str, server_url: str, user_id: str, icon: str, icon_type: str, icon_background: str
) -> dict:
if (
db.session.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.name == name)
.first()
):
raise ValueError(f"MCP tool {name} already exists")
mcp_tool = MCPToolProvider(
tenant_id=tenant_id,
name=name,
server_url=server_url,
user_id=user_id,
authed=False,
tools="[]",
icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
)
db.session.add(mcp_tool)
db.session.commit()
return {"result": "success"}
@staticmethod
def retrieve_mcp_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
mcp_providers = db.session.query(MCPToolProvider).filter(MCPToolProvider.tenant_id == tenant_id).all()
return [ToolTransformService.mcp_provider_to_user_provider(mcp_provider) for mcp_provider in mcp_providers]
@classmethod
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None:
raise ValueError("MCP tool not found")
with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
tools = mcp_client.list_tools()
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
mcp_provider.authed = True
db.session.commit()
return ToolProviderApiEntity(
id=mcp_provider.id,
name=mcp_provider.name,
tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
type=ToolProviderType.MCP,
icon=mcp_provider.icon,
author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
server_url=mcp_provider.server_url,
updated_at=mcp_provider.updated_at,
description=I18nObject(en_US="", zh_Hans=""),
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
)
@classmethod
def retrieve_mcp_provider(cls, tenant_id: str, provider_id: str):
provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if provider is None:
raise ValueError("MCP tool not found")
return ToolTransformService.mcp_provider_to_user_provider(provider).to_dict()
@classmethod
def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_tool is None:
raise ValueError("MCP tool not found")
db.session.delete(mcp_tool)
db.session.commit()
return {"result": "success"}
@classmethod
def update_mcp_provider(
cls,
tenant_id: str,
provider_id: str,
name: str,
server_url: str,
icon: str,
icon_type: str,
icon_background: str,
encrypted_credentials: dict,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None:
raise ValueError("MCP tool not found")
mcp_provider.name = name
mcp_provider.server_url = server_url
mcp_provider.icon = (
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
)
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **encrypted_credentials})
db.session.commit()
return {"result": "success"}
@classmethod
def update_mcp_provider_credentials(cls, tenant_id: str, provider_id: str, credentials: dict, authed: bool = False):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None:
raise ValueError("MCP tool not found")
mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
mcp_provider.authed = authed
db.session.commit()
return {"result": "success"}
@classmethod
def get_mcp_token(cls, provider_id: str, tenant_id: str):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None:
raise ValueError("MCP provider not found")
return mcp_provider.credentials.get("access_token", None)

View File

@ -5,6 +5,7 @@ from typing import Optional, Union, cast
from yarl import URL
from configs import dify_config
from core.mcp.types import Tool as MCPTool
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
@ -21,7 +22,7 @@ from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
logger = logging.getLogger(__name__)
@ -187,6 +188,38 @@ class ToolTransformService:
labels=labels or [],
)
@staticmethod
def mcp_provider_to_user_provider(db_provider: MCPToolProvider) -> ToolProviderApiEntity:
return ToolProviderApiEntity(
id=db_provider.id,
author=db_provider.user.name if db_provider.user else "Anonymous",
name=db_provider.name,
icon=db_provider.provider_icon,
type=ToolProviderType.MCP,
is_team_authorization=db_provider.authed,
server_url=db_provider.server_url,
tools=ToolTransformService.mcp_tool_to_user_tool(
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
),
updated_at=db_provider.updated_at,
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US="", zh_Hans=""),
)
@staticmethod
def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
return [
ToolApiEntity(
author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
name=tool.name,
label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
description=I18nObject(en_US=tool.description, zh_Hans=tool.description),
parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
labels=[],
)
for tool in tools
]
@classmethod
def api_provider_to_user_provider(
cls,
@ -304,3 +337,59 @@ class ToolTransformService:
parameters=tool.parameters,
labels=labels or [],
)
@staticmethod
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
"""
Convert MCP JSON schema to tool parameters
:param schema: JSON schema dictionary
:return: list of ToolParameter instances
"""
def create_parameter(name: str, description: str, param_type: str, required: bool) -> ToolParameter:
"""Create a ToolParameter instance with given attributes"""
return ToolParameter(
name=name,
llm_description=description,
label=I18nObject(en_US=name),
form=ToolParameter.ToolParameterForm.LLM,
required=required,
type=ToolParameter.ToolParameterType(param_type),
human_description=I18nObject(en_US=description),
)
def process_array(name: str, description: str, items: dict, required: bool) -> list[ToolParameter]:
"""Process array type properties"""
item_type = items.get("type", "string")
if item_type == "object" and "properties" in items:
return process_properties(items["properties"], items.get("required", []), f"{name}[0]")
return [create_parameter(name, description, item_type, required)]
def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
"""Process properties recursively"""
parameters = []
for name, prop in props.items():
current_name = f"{prefix}.{name}" if prefix else name
current_description = prop.get("description", "")
prop_type = prop.get("type", "string")
if isinstance(prop_type, list):
prop_type = prop_type[0]
if prop_type == "integer":
prop_type = "number"
if prop_type == "array":
parameters.extend(
process_array(current_name, current_description, prop.get("items", {}), name in required)
)
elif prop_type == "object" and "properties" in prop:
parameters.extend(process_properties(prop["properties"], prop.get("required", []), current_name))
else:
parameters.append(create_parameter(current_name, current_description, prop_type, name in required))
return parameters
if schema.get("type") == "object" and "properties" in schema:
return process_properties(schema["properties"], schema.get("required", []))
return []

4250
api/uv.lock generated

File diff suppressed because it is too large Load Diff