dify/api/core/mcp/client/sse_client.py
2025-05-19 18:03:40 +08:00

193 lines
8.2 KiB
Python

import logging
import queue
import threading
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any
from urllib.parse import urljoin, urlparse
import httpx
from httpx_sse import connect_sse
from sseclient import SSEClient
from core.mcp import types
from core.mcp.types import SessionMessage
from core.mcp.utils import create_mcp_http_client, remove_request_params
logger = logging.getLogger(__name__)
@contextmanager
def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
) -> Generator[tuple[queue.Queue, queue.Queue], None, None]:
"""
Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
if headers is None:
headers = {}
read_queue = queue.Queue()
write_queue = queue.Queue()
status_queue = queue.Queue()
cancel_event = threading.Event()
with ThreadPoolExecutor() as executor:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
with create_mcp_http_client(headers=headers) as client:
with connect_sse(
client, "GET", url, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")
def sse_reader(status_queue: queue.Queue):
try:
while not cancel_event.is_set():
for sse in event_source.iter_sse():
if cancel_event.is_set():
break
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(f"Received endpoint URL: {endpoint_url}")
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = (
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg)
raise ValueError(error_msg)
status_queue.put(("ready", endpoint_url))
case "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"Received server message: {message}")
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
continue
session_message = SessionMessage(message)
read_queue.put(session_message)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
except Exception as exc:
if not cancel_event.is_set():
logger.exception("Error reading SSE messages")
read_queue.put(exc)
finally:
read_queue.put(None)
def post_writer(endpoint_url: str):
try:
while not cancel_event.is_set():
try:
message = write_queue.get(timeout=5)
if message is None:
break
response = client.post(
endpoint_url,
json=message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
if cancel_event.is_set():
break
except queue.Empty:
if cancel_event.is_set():
break
continue
except Exception:
logger.exception("Error writing messages")
finally:
write_queue.put(None)
executor.submit(sse_reader, status_queue)
try:
status, endpoint_url = status_queue.get(timeout=1)
except queue.Empty:
raise ValueError("failed to get endpoint URL")
if status != "ready":
raise ValueError("failed to get endpoint URL")
executor.submit(post_writer, endpoint_url)
try:
yield read_queue, write_queue
finally:
cancel_event.set()
except Exception as exc:
logger.exception("Error connecting to SSE endpoint")
raise exc
finally:
read_queue.put(None)
write_queue.put(None)
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:
"""
Send a message to the server using the provided HTTP client.
Args:
http_client: The HTTP client to use for sending
endpoint_url: The endpoint URL to send the message to
session_message: The message to send
"""
try:
response = http_client.post(
endpoint_url,
json=session_message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
except Exception as exc:
logger.exception("Error sending message")
raise
def read_messages(
sse_client: SSEClient,
) -> Generator[SessionMessage | Exception, None, None]:
"""
Read messages from the SSE client.
Args:
sse_client: The SSE client to read from
Yields:
SessionMessage or Exception for each event received
"""
try:
for sse in sse_client.events():
if sse.event == "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"Received server message: {message}")
yield SessionMessage(message)
except Exception as exc:
logger.exception("Error parsing server message")
yield exc
else:
logger.warning(f"Unknown SSE event: {sse.event}")
except Exception as exc:
logger.exception("Error reading SSE messages")
yield exc