# Mirror of https://github.com/langgenius/dify.git
# (via https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git),
# synced 2025-07-05 18:05:10 +08:00. Original file: 384 lines, 15 KiB, Python.
import logging
|
|
import queue
|
|
from collections.abc import Callable
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from contextlib import ExitStack
|
|
from datetime import timedelta
|
|
from types import TracebackType
|
|
from typing import Any, Generic, Self, TypeVar
|
|
|
|
from httpx import HTTPStatusError
|
|
from pydantic import BaseModel
|
|
|
|
from core.mcp.error import MCPAuthError, MCPConnectionError
|
|
from core.mcp.types import (
|
|
CancelledNotification,
|
|
ClientNotification,
|
|
ClientRequest,
|
|
ClientResult,
|
|
ErrorData,
|
|
JSONRPCError,
|
|
JSONRPCMessage,
|
|
JSONRPCNotification,
|
|
JSONRPCRequest,
|
|
JSONRPCResponse,
|
|
MessageMetadata,
|
|
RequestId,
|
|
RequestParams,
|
|
ServerMessageMetadata,
|
|
ServerNotification,
|
|
ServerRequest,
|
|
ServerResult,
|
|
SessionMessage,
|
|
)
|
|
|
|
# Type variables for the message kinds a session *sends*; each is constrained to
# either the client-side or the server-side variant of the MCP message family.
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)

# Type variables for the message kinds a session *receives*.
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
# Results are parsed with ``model_validate``, so any pydantic model is accepted.
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)

# Seconds used as the ``queue.get`` timeout when polling the read stream and
# when waiting for a response to an outgoing request.
DEFAULT_RESPONSE_READ_TIMEOUT = 1
|
|
|
|
|
|
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
    """Handles responding to MCP requests and manages request lifecycle.

    This class MUST be used as a context manager to ensure proper cleanup and
    cancellation handling:

    Example:
        with request_responder as resp:
            resp.respond(result)

    The context manager ensures:
    1. A response or cancellation can only be sent while the block is active
    2. Request completion tracking (at most one reply per request)
    3. Cleanup of in-flight requests: ``on_complete`` is invoked on exit once
       the request has been responded to or cancelled
    """

    def __init__(
        self,
        request_id: RequestId,
        request_meta: RequestParams.Meta | None,
        request: ReceiveRequestT,
        session: """BaseSession[
            SendRequestT,
            SendNotificationT,
            SendResultT,
            ReceiveRequestT,
            ReceiveNotificationT
        ]""",
        on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
    ) -> None:
        self.request_id = request_id
        self.request_meta = request_meta
        self.request = request
        self._session = session
        self._completed = False  # True once respond() or cancel() has sent a reply
        self._on_complete = on_complete
        self._entered = False  # Track if we're in a context manager

    def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
        """Enter the context manager, enabling request cancellation tracking."""
        self._entered = True
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit the context manager, performing cleanup and notifying completion."""
        try:
            # Only completed requests are reported; unanswered ones stay registered.
            if self._completed:
                self._on_complete(self)
        finally:
            self._entered = False

    def _require_entered(self) -> None:
        """Raise RuntimeError unless called inside the ``with`` block."""
        if not self._entered:
            raise RuntimeError("RequestResponder must be used as a context manager")

    def respond(self, response: SendResultT | ErrorData) -> None:
        """Send a response for this request.

        Must be called within a context manager block.

        Args:
            response: A result model, or ErrorData to reply with a JSON-RPC error.

        Raises:
            RuntimeError: If not used within a context manager
            AssertionError: If request was already responded to (or cancelled)
        """
        self._require_entered()
        if self._completed:
            # Explicit raise instead of ``assert`` so the guard survives ``python -O``.
            raise AssertionError("Request already responded to")

        self._completed = True

        self._session._send_response(request_id=self.request_id, response=response)

    def cancel(self) -> None:
        """Cancel this request and mark it as completed.

        Replies with ``ErrorData(code=0, message="Request cancelled")``. If the
        request was already responded to or cancelled, this does nothing, so a
        request never receives more than one JSON-RPC response.

        Raises:
            RuntimeError: If not used within a context manager
        """
        self._require_entered()
        if self._completed:
            return

        self._completed = True  # Mark as completed so it's removed from in_flight
        # Send an error response to indicate cancellation
        self._session._send_response(
            request_id=self.request_id,
            response=ErrorData(code=0, message="Request cancelled", data=None),
        )
|
|
|
|
|
|
class BaseSession(
    Generic[
        SendRequestT,
        SendNotificationT,
        SendResultT,
        ReceiveRequestT,
        ReceiveNotificationT,
    ],
):
    """
    Implements an MCP "session" on top of read/write streams, including features
    like request/response linking, notifications, and progress.

    This class is a context manager that automatically starts processing
    messages when entered: ``__enter__`` runs ``_receive_loop`` on a thread-pool
    worker, and ``__exit__`` puts a ``None`` sentinel on both streams to stop it.
    """

    # Per-request queues fed by the receive loop. ``None`` on a queue means no
    # response will arrive (the receive loop has stopped).
    _response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError | None]]
    # Next outgoing request id; ids are allocated sequentially from 0.
    _request_id: int
    # Incoming requests that have been received but not yet completed.
    _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]

    def __init__(
        self,
        read_stream: queue.Queue,
        write_stream: queue.Queue,
        receive_request_type: type[ReceiveRequestT],
        receive_notification_type: type[ReceiveNotificationT],
        # If none, reading will never time out
        read_timeout_seconds: timedelta | None = None,
    ) -> None:
        self._read_stream = read_stream
        self._write_stream = write_stream
        self._response_streams = {}
        self._request_id = 0
        self._receive_request_type = receive_request_type
        self._receive_notification_type = receive_notification_type
        self._session_read_timeout_seconds = read_timeout_seconds
        self._in_flight = {}
        self._exit_stack = ExitStack()
        self._futures = []
        # Created by __enter__; None until then so __exit__/send_request can tell.
        self._executor: ThreadPoolExecutor | None = None
        self._receiver_future = None

    def __enter__(self) -> Self:
        self._executor = ThreadPoolExecutor()
        self._receiver_future = self._executor.submit(self._receive_loop)
        return self

    def __exit__(
        self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
    ) -> None:
        try:
            self._exit_stack.close()
        finally:
            # Sentinels: None stops _receive_loop; the write-stream consumer is
            # presumably the transport writer, which is expected to stop on None too.
            self._read_stream.put(None)
            self._write_stream.put(None)
            if self._executor is not None:
                # Don't block here: the receive loop exits by itself once it sees the sentinel.
                self._executor.shutdown(wait=False)

    def send_request(
        self,
        request: SendRequestT,
        result_type: type[ReceiveResultT],
        request_read_timeout_seconds: timedelta | None = None,
        metadata: MessageMetadata = None,
    ) -> ReceiveResultT:
        """
        Sends a request and waits for a response, validated as ``result_type``.

        If a request read timeout is provided, it takes precedence over the
        session read timeout. When neither is set, this waits for as long as
        the receive loop is running.

        Do not use this method to emit notifications! Use send_notification()
        instead.

        Raises:
            MCPAuthError: The peer replied with a JSON-RPC error with code 401.
            MCPConnectionError: The peer replied with any other JSON-RPC error,
                the read timeout expired (code 408), or the session stopped
                receiving before a response arrived (code 500).
        """
        request_id = self._request_id
        self._request_id = request_id + 1

        response_queue: queue.Queue[JSONRPCResponse | JSONRPCError | None] = queue.Queue()
        self._response_streams[request_id] = response_queue

        try:
            jsonrpc_request = JSONRPCRequest(
                jsonrpc="2.0",
                id=request_id,
                **request.model_dump(by_alias=True, mode="json", exclude_none=True),
            )
            self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))

            response_or_error = self._wait_for_response(request_id, response_queue, request_read_timeout_seconds)

            if response_or_error is None:
                raise MCPConnectionError(
                    ErrorData(
                        code=500,
                        message="No response received",
                    )
                )
            if isinstance(response_or_error, JSONRPCError):
                error = ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
                if error.code == 401:
                    raise MCPAuthError(error)
                raise MCPConnectionError(error)
            return result_type.model_validate(response_or_error.result)
        finally:
            # Unregister so a late response is reported as unknown instead of queued.
            self._response_streams.pop(request_id, None)

    def _wait_for_response(
        self,
        request_id: RequestId,
        response_queue: queue.Queue,
        request_read_timeout_seconds: timedelta | None,
    ) -> JSONRPCResponse | JSONRPCError | None:
        """Block until the reply for ``request_id`` arrives on ``response_queue``.

        Returns None when the receive loop has stopped and no reply can arrive.

        Raises:
            MCPConnectionError: (code 408) when an explicit request or session
                read timeout expires first.
        """
        if request_read_timeout_seconds is not None:
            read_timeout = request_read_timeout_seconds
        else:
            read_timeout = self._session_read_timeout_seconds

        if read_timeout is not None:
            # Clamp: queue.get() rejects negative timeouts.
            timeout_seconds = max(read_timeout.total_seconds(), 0.0)
            try:
                return response_queue.get(timeout=timeout_seconds)
            except queue.Empty:
                raise MCPConnectionError(
                    ErrorData(
                        code=408,
                        message=f"Timed out after {timeout_seconds} seconds waiting for a response to request "
                        f"{request_id}",
                    )
                ) from None

        # No read timeout configured: poll, so a stopped receive loop is noticed
        # instead of blocking this caller forever.
        while True:
            try:
                return response_queue.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
            except queue.Empty:
                if not self._receiver_stopped():
                    continue
            # The receiver is gone: take whatever it delivered before stopping, if anything.
            try:
                return response_queue.get_nowait()
            except queue.Empty:
                return None

    def _receiver_stopped(self) -> bool:
        """Return True once the receive loop submitted by __enter__ has finished."""
        return self._receiver_future is not None and self._receiver_future.done()

    def send_notification(
        self,
        notification: SendNotificationT,
        related_request_id: RequestId | None = None,
    ) -> None:
        """
        Emits a notification, which is a one-way message that does not expect
        a response.

        Args:
            notification: The notification model to serialize and send.
            related_request_id: Optional id of the request that triggered this
                notification; attached as ServerMessageMetadata.
        """
        # Some transport implementations may need to set the related_request_id
        # to attribute to the notifications to the request that triggered them.
        jsonrpc_notification = JSONRPCNotification(
            jsonrpc="2.0",
            **notification.model_dump(by_alias=True, mode="json", exclude_none=True),
        )
        # `is not None`, not truthiness: 0 is a valid request id (the first one issued).
        metadata = ServerMessageMetadata(related_request_id=related_request_id) if related_request_id is not None else None
        session_message = SessionMessage(
            message=JSONRPCMessage(jsonrpc_notification),
            metadata=metadata,
        )
        self._write_stream.put(session_message)

    def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
        """Write a JSON-RPC response for ``request_id``, or a JSON-RPC error for ErrorData."""
        if isinstance(response, ErrorData):
            payload: JSONRPCError | JSONRPCResponse = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
        else:
            payload = JSONRPCResponse(
                jsonrpc="2.0",
                id=request_id,
                result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
            )
        self._write_stream.put(SessionMessage(message=JSONRPCMessage(payload)))

    def _receive_loop(self) -> None:
        """
        Main message processing loop.

        Runs on the worker thread submitted by ``__enter__``. Reads from the read
        stream until a ``None`` sentinel arrives. An error raised while handling
        one message is logged and does not stop the loop. On exit, every request
        still waiting for a response is woken with ``None``.
        """
        try:
            while True:
                try:
                    message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
                except queue.Empty:
                    continue
                if message is None:
                    break
                try:
                    self._dispatch_message(message)
                except Exception:
                    logging.exception("Error in message processing loop")
        finally:
            self._fail_pending_requests()

    def _dispatch_message(self, message: Any) -> None:
        """Route one item taken from the read stream to the matching handler.

        Items are either exceptions (presumably pushed by the transport) or
        SessionMessage objects wrapping a JSON-RPC request, notification,
        response or error.
        """
        if isinstance(message, HTTPStatusError):
            # No JSON-RPC id is available here, so the error is attributed to the
            # most recently issued request (assumes one outstanding request — TODO confirm).
            last_request_id = self._request_id - 1
            response_queue = self._response_streams.get(last_request_id)
            if response_queue is not None:
                response_queue.put(
                    JSONRPCError(
                        jsonrpc="2.0",
                        id=last_request_id,
                        error=ErrorData(code=message.response.status_code, message=message.args[0]),
                    )
                )
            else:
                self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
            return

        if isinstance(message, Exception):
            self._handle_incoming(message)
            return

        root = message.message.root
        if isinstance(root, JSONRPCRequest):
            self._dispatch_request(root)
        elif isinstance(root, JSONRPCNotification):
            self._dispatch_notification(root)
        else:  # Response or error
            response_queue = self._response_streams.get(root.id)
            if response_queue is not None:
                response_queue.put(root)
            else:
                self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))

    def _dispatch_request(self, root: JSONRPCRequest) -> None:
        """Validate an incoming request, track it as in flight, and hand it to the handlers."""
        validated_request = self._receive_request_type.model_validate(
            root.model_dump(by_alias=True, mode="json", exclude_none=True)
        )

        responder = RequestResponder(
            request_id=root.id,
            request_meta=validated_request.root.params.meta if validated_request.root.params else None,
            request=validated_request,
            session=self,
            on_complete=lambda r: self._in_flight.pop(r.request_id, None),
        )

        self._in_flight[responder.request_id] = responder
        self._received_request(responder)

        # Requests already answered inside _received_request are not forwarded.
        if not responder._completed:
            self._handle_incoming(responder)

    def _dispatch_notification(self, root: JSONRPCNotification) -> None:
        """Validate an incoming notification and apply or forward it."""
        try:
            notification = self._receive_notification_type.model_validate(
                root.model_dump(by_alias=True, mode="json", exclude_none=True)
            )
        except Exception as e:
            # Notifications that fail validation are logged and dropped.
            logging.warning("Failed to validate notification: %s. Message was: %s", e, root)
            return

        # Handle cancellation notifications
        if isinstance(notification.root, CancelledNotification):
            # .get() instead of `in` + index: on_complete may pop the entry concurrently.
            responder = self._in_flight.get(notification.root.params.requestId)
            if responder is not None:
                responder.cancel()
        else:
            self._received_notification(notification)
            self._handle_incoming(notification)

    def _fail_pending_requests(self) -> None:
        """Wake every request still waiting on a response queue with ``None``."""
        # Snapshot the values: send_request() pops entries from other threads.
        for response_queue in list(self._response_streams.values()):
            response_queue.put(None)

    def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
        """
        Can be overridden by subclasses to handle a request without needing to
        listen on the message stream.

        If the request is responded to within this method, it will not be
        forwarded on to the message stream.
        """
        pass

    def _received_notification(self, notification: ReceiveNotificationT) -> None:
        """
        Can be overridden by subclasses to handle a notification without needing
        to listen on the message stream.
        """
        pass

    def send_progress_notification(
        self, progress_token: str | int, progress: float, total: float | None = None
    ) -> None:
        """
        Sends a progress notification for a request that is currently being
        processed. The base implementation does nothing.
        """
        pass

    def _handle_incoming(
        self,
        req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
    ) -> None:
        """A generic handler for incoming messages. Overwritten by subclasses."""
        pass
|