mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-07-05 20:35:10 +08:00
137 lines
4.7 KiB
Python
137 lines
4.7 KiB
Python
import logging
|
|
from collections.abc import Callable
|
|
from contextlib import ExitStack
|
|
from typing import Optional, cast
|
|
from urllib.parse import urlparse
|
|
|
|
from core.mcp.client.sse_client import sse_client
|
|
from core.mcp.client.streamable_client import streamablehttp_client
|
|
from core.mcp.error import MCPAuthError, MCPConnectionError
|
|
from core.mcp.session.client_session import ClientSession
|
|
from core.mcp.types import Tool
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class MCPClient:
    """Synchronous client for one MCP server, meant to be used as a context manager.

    ``__enter__`` opens a transport (streamable HTTP or SSE, chosen from the server URL)
    plus a ``ClientSession`` and initializes the session; ``__exit__`` closes both.
    When ``authed`` is true, requests carry an ``Authorization`` header built from the
    tokens returned by ``OAuthClientProvider``, and one re-authentication retry is
    attempted when the server raises ``MCPAuthError``.
    """

    def __init__(
        self,
        server_url: str,
        provider_id: str,
        tenant_id: str,
        authed: bool = True,
        authorization_code: Optional[str] = None,
    ):
        # Initialize info
        self.provider_id = provider_id
        self.tenant_id = tenant_id
        self.client_type = "streamable"
        self.server_url = server_url

        # Authentication info
        self.authed = authed
        self.authorization_code = authorization_code
        if authed:
            # Imported lazily. NOTE(review): presumably to avoid an import cycle; confirm.
            from core.mcp.auth.auth_provider import OAuthClientProvider

            self.provider = OAuthClientProvider(self.provider_id, self.tenant_id)
            self.token = self.provider.tokens()

        # Initialize session and client objects
        self._session: Optional[ClientSession] = None
        self._streams_context = None
        self._session_context = None
        # Owns the transport and session contexts of the established connection, so that
        # cleanup() exits them in reverse order of entry.
        self.exit_stack = ExitStack()

        # Whether the client has been initialized
        self._initialized = False

    def __enter__(self):
        # connect_server() closes everything a failed attempt opened, so nothing leaks
        # here even though __exit__ is not called when __enter__ raises.
        self._initialize()
        self._initialized = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release the connection without masking an exception raised in the ``with`` body."""
        if exc_type is None:
            self.cleanup()
            return
        try:
            self.cleanup()
        except ValueError:
            # cleanup() has already logged the failure. Raising here would replace the
            # in-flight exception (e.g. an MCPAuthError the caller intends to catch).
            pass

    def _initialize(self):
        """Connect using the transport implied by the last segment of the URL path.

        A path ending in ``mcp`` uses streamable HTTP and one ending in ``sse`` uses SSE.
        For any other path SSE is tried first, falling back to streamable HTTP if it
        raises ``MCPConnectionError``.
        """
        connection_methods = {"mcp": streamablehttp_client, "sse": sse_client}

        path = urlparse(self.server_url).path
        method_name = path.rstrip("/").split("/")[-1]
        # Resolve the factory outside any try/except: a KeyError raised while *connecting*
        # must propagate instead of being mistaken for an unknown path.
        client_factory = connection_methods.get(method_name)
        if client_factory is not None:
            self.connect_server(client_factory, method_name)
            return

        try:
            self.connect_server(sse_client, "sse")
        except MCPConnectionError:
            self.connect_server(streamablehttp_client, "mcp")

    def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True):
        """Open a transport via ``client_factory`` and initialize an MCP session over it.

        Args:
            client_factory: Called as ``client_factory(url=..., headers=...)``; must return
                a context manager whose value provides the transport streams.
            method_name: ``"mcp"`` means the context value is a 3-tuple of which only the
                first two items are passed to ``ClientSession``; any other value means the
                context value is passed to ``ClientSession`` unpacked as-is (SSE).
            first_try: When true and the client is ``authed``, an ``MCPAuthError`` triggers
                one re-authentication followed by a single retry.

        Raises:
            MCPAuthError: if the client is not ``authed``, or if the retry after
                re-authentication is rejected as well.

        Contexts opened by a failed attempt are closed before the error propagates (or
        before the retry); on success they are handed to ``self.exit_stack``.
        """
        from core.mcp.auth.auth_flow import auth

        headers = (
            {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
            if self.authed and self.token
            else {}
        )

        attempt = ExitStack()
        try:
            self._streams_context = client_factory(url=self.server_url, headers=headers)
            streams = attempt.enter_context(self._streams_context)
            if method_name == "mcp":
                read_stream, write_stream, _ = streams
                streams = (read_stream, write_stream)

            self._session_context = ClientSession(*streams)
            self._session = attempt.enter_context(self._session_context)
            session = cast(ClientSession, self._session)
            session.initialize()
        except MCPAuthError:
            self._discard_attempt(attempt)
            # Without the `not first_try` check a second rejection would return normally
            # and __enter__ would mark an unauthenticated session as initialized.
            if not self.authed or not first_try:
                raise
            auth(self.provider, self.server_url, self.authorization_code)
            self.token = self.provider.tokens()
            return self.connect_server(client_factory, method_name, first_try=False)
        except BaseException:
            self._discard_attempt(attempt)
            raise

        # Success: transfer the opened contexts to the client-wide stack closed by cleanup().
        self.exit_stack.push(attempt.pop_all())

    def _discard_attempt(self, attempt: ExitStack) -> None:
        """Best-effort close of a failed connection attempt's contexts; never raises."""
        self._session = None
        self._session_context = None
        self._streams_context = None
        try:
            attempt.close()
        except Exception:
            logger.warning("Error while closing a failed MCP connection attempt", exc_info=True)

    def list_tools(self) -> list[Tool]:
        """Return the tools advertised by the connected MCP server.

        Raises:
            ValueError: if the session has not been initialized (used outside ``with``).
        """
        if not self._initialized or not self._session:
            raise ValueError("Session not initialized.")
        response = self._session.list_tools()
        return response.tools

    def invoke_tool(self, tool_name: str, tool_args: dict):
        """Call ``tool_name`` with ``tool_args`` and return the session's result.

        Raises:
            ValueError: if the session has not been initialized (used outside ``with``).
        """
        if not self._initialized or not self._session:
            raise ValueError("Session not initialized.")
        return self._session.call_tool(tool_name, tool_args)

    def cleanup(self):
        """Exit the session and transport contexts and reset the client state.

        Safe to call repeatedly: each context is exited at most once.

        Raises:
            ValueError: if exiting a context fails. The remaining contexts are still
                exited and the client state is still reset.
        """
        try:
            self.exit_stack.close()
        except Exception as e:
            logger.exception("Error during cleanup")
            raise ValueError(f"Error during cleanup: {e}") from e
        finally:
            self._session = None
            self._session_context = None
            self._streams_context = None
            self._initialized = False
|