dify/api/core/mcp/mcp_client.py
2025-05-23 18:12:47 +08:00

139 lines
4.7 KiB
Python

import logging
from collections.abc import Callable
from contextlib import ExitStack
from typing import Optional, cast
from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client
from core.mcp.client.streamable_client import streamablehttp_client
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.session.client_session import ClientSession
from core.mcp.types import Tool
logger = logging.getLogger(__name__)
class MCPClient:
def __init__(
self,
server_url: str,
provider_id: str,
tenant_id: str,
authed: bool = True,
authorization_code: Optional[str] = None,
scope: Optional[str] = None,
):
# Initialize info
self.provider_id = provider_id
self.tenant_id = tenant_id
self.client_type = "streamable"
self.server_url = server_url
# Authentication info
self.authed = authed
self.authorization_code = authorization_code
self.scope = scope
if authed:
from core.mcp.auth.auth_provider import OAuthClientProvider
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id)
self.token = self.provider.tokens()
# Initialize session and client objects
self._session: Optional[ClientSession] = None
self._streams_context = None
self._session_context = None
self.exit_stack = ExitStack()
# Whether the client has been initialized
self._initialized = False
def __enter__(self):
self._initialize(first_try=True)
self._initialized = True
return self
def __exit__(self, exc_type, exc_value, traceback):
self.cleanup()
def _initialize(
self,
first_try: bool = True,
):
"""Initialize the client with fallback to SSE if streamable connection fails"""
connection_methods = {"mcp": streamablehttp_client, "sse": sse_client}
parsed_url = urlparse(self.server_url)
path = parsed_url.path
method_name = path.rstrip("/").split("/")[-1] if path else ""
try:
client_factory = connection_methods[method_name]
self.connect_server(client_factory, method_name)
except KeyError:
try:
self.connect_server(streamablehttp_client, "sse")
except MCPConnectionError:
self.connect_server(sse_client, "mcp")
def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True):
from core.mcp.auth.auth_flow import auth
try:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else {}
)
self._streams_context = client_factory(url=self.server_url, headers=headers)
if method_name == "mcp":
read_stream, write_stream, _ = self._streams_context.__enter__()
streams = (read_stream, write_stream)
else: # sse_client
streams = self._streams_context.__enter__()
self._session_context = ClientSession(*streams)
self._session = self._session_context.__enter__()
session = cast(ClientSession, self._session)
session.initialize()
return
except MCPAuthError:
if not self.authed:
raise
auth(self.provider, self.server_url, self.authorization_code, self.scope)
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)
except MCPConnectionError:
raise
def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport"""
# List available tools to verify connection
if not self._initialized or not self._session:
raise ValueError("Session not initialized.")
response = self._session.list_tools()
tools = response.tools
return tools
def invoke_tool(self, tool_name: str, tool_args: dict):
"""Call a tool"""
if not self._initialized or not self._session:
raise ValueError("Session not initialized.")
return self._session.call_tool(tool_name, tool_args)
def cleanup(self):
"""Clean up resources"""
try:
if self._session:
self._session.__exit__(None, None, None)
if self._streams_context:
self._streams_context.__exit__(None, None, None)
self._session = None
self._initialized = False
self.exit_stack.close()
except Exception:
logging.exception("Error during cleanup")