feat: upgrade streamable http client

This commit is contained in:
Novice 2025-05-27 13:14:51 +08:00
parent 1fd4839eca
commit 41bbcb9540
16 changed files with 167 additions and 155 deletions

View File

@ -68,16 +68,31 @@ class AppMCPServerController(Resource):
parser.add_argument("id", type=str, required=True, location="json") parser.add_argument("id", type=str, required=True, location="json")
parser.add_argument("description", type=str, required=True, location="json") parser.add_argument("description", type=str, required=True, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json") parser.add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("status", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first() server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
if not server: if not server:
raise Forbidden() raise Forbidden()
server.description = args["description"] server.description = args["description"]
server.parameters = json.dumps(args["parameters"], ensure_ascii=False) server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
server.status = AppMCPServerStatus(args["status"]) db.session.commit()
return server
class AppMCPServerRefreshController(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_server_fields)
def get(self, server_id):
if not current_user.is_editor:
raise Forbidden()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == server_id).first()
if not server:
raise Forbidden()
server.server_code = AppMCPServer.generate_server_code(16)
db.session.commit() db.session.commit()
return server return server
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server") api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
api.add_resource(AppMCPServerRefreshController, "/apps/<uuid:server_id>/server/refresh")

View File

@ -18,6 +18,7 @@ from controllers.web.error import (
) )
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.app.app_config.entities import VariableEntity
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ( from core.errors.error import (
@ -175,11 +176,30 @@ class ChatMCPApi(Resource):
app = db.session.query(App).filter(App.id == server.app_id).first() app = db.session.query(App).filter(App.id == server.app_id).first()
if not app: if not app:
raise NotFound("App Not Found") raise NotFound("App Not Found")
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app.workflow
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
try:
user_input_form = [VariableEntity.model_validate(item) for item in user_input_form]
except ValidationError as e:
raise ValueError(f"Invalid user_input_form: {str(e)}")
try: try:
request = ClientRequest.model_validate(args) request = ClientRequest.model_validate(args)
except ValidationError as e: except ValidationError as e:
raise ValueError(f"Invalid MCP request: {str(e)}") raise ValueError(f"Invalid MCP request: {str(e)}")
mcp_server_handler = MCPServerReuqestHandler(app, request) mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form)
return helper.compact_generate_response(mcp_server_handler.handle()) return helper.compact_generate_response(mcp_server_handler.handle())

View File

@ -127,6 +127,7 @@ def create_ssrf_proxy_mcp_http_client(
"verify": HTTP_REQUEST_NODE_SSL_VERIFY, "verify": HTTP_REQUEST_NODE_SSL_VERIFY,
"headers": headers or {}, "headers": headers or {},
"timeout": timeout, "timeout": timeout,
"follow_redirects": True, # Enable redirect following for MCP connections
} }
if dify_config.SSRF_PROXY_ALL_URL: if dify_config.SSRF_PROXY_ALL_URL:

View File

@ -59,7 +59,6 @@ def start_authorization(
metadata: Optional[OAuthMetadata], metadata: Optional[OAuthMetadata],
client_information: OAuthClientInformation, client_information: OAuthClientInformation,
redirect_url: str, redirect_url: str,
scope: Optional[str] = None,
) -> tuple[str, str]: ) -> tuple[str, str]:
"""Begins the authorization flow.""" """Begins the authorization flow."""
response_type = "code" response_type = "code"
@ -85,11 +84,9 @@ def start_authorization(
"code_challenge": code_challenge, "code_challenge": code_challenge,
"code_challenge_method": code_challenge_method, "code_challenge_method": code_challenge_method,
"redirect_uri": redirect_url, "redirect_uri": redirect_url,
"state": "/tools?provider_id=" + client_information.client_id,
} }
if scope:
params["scope"] = scope
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}" authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
return authorization_url, code_verifier return authorization_url, code_verifier
@ -187,7 +184,6 @@ def auth(
provider: OAuthClientProvider, provider: OAuthClientProvider,
server_url: str, server_url: str,
authorization_code: Optional[str] = None, authorization_code: Optional[str] = None,
scope: Optional[str] = None,
) -> dict[str, str]: ) -> dict[str, str]:
"""Orchestrates the full auth flow with a server.""" """Orchestrates the full auth flow with a server."""
metadata = discover_oauth_metadata(server_url) metadata = discover_oauth_metadata(server_url)
@ -233,7 +229,6 @@ def auth(
metadata, metadata,
client_information, client_information,
provider.redirect_url, provider.redirect_url,
scope or provider.client_metadata.scope,
) )
provider.save_code_verifier(code_verifier) provider.save_code_verifier(code_verifier)

View File

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
from configs.app_config import DifyConfig from configs import dify_config
from core.mcp.types import ( from core.mcp.types import (
OAuthClientInformation, OAuthClientInformation,
OAuthClientInformationFull, OAuthClientInformationFull,
@ -11,8 +11,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService
LATEST_PROTOCOL_VERSION = "1.0" LATEST_PROTOCOL_VERSION = "1.0"
dify_config = DifyConfig()
class OAuthClientProvider: class OAuthClientProvider:
provider_id: str provider_id: str
@ -25,7 +23,7 @@ class OAuthClientProvider:
@property @property
def redirect_url(self) -> str: def redirect_url(self) -> str:
"""The URL to redirect the user agent to after authorization.""" """The URL to redirect the user agent to after authorization."""
return dify_config.CONSOLE_WEB_URL return dify_config.CONSOLE_WEB_URL + "/tools"
@property @property
def client_metadata(self) -> OAuthClientMetadata: def client_metadata(self) -> OAuthClientMetadata:
@ -37,7 +35,6 @@ class OAuthClientProvider:
response_types=["code"], response_types=["code"],
client_name="Dify", client_name="Dify",
client_uri="https://github.com/langgenius/dify", client_uri="https://github.com/langgenius/dify",
scope="read write",
) )
def client_information(self) -> Optional[OAuthClientInformation]: def client_information(self) -> Optional[OAuthClientInformation]:
@ -91,7 +88,3 @@ class OAuthClientProvider:
if not mcp_provider: if not mcp_provider:
return "" return ""
return mcp_provider.credentials.get("code_verifier", "") return mcp_provider.credentials.get("code_verifier", "")
class UnauthorizedError(Exception):
pass

View File

@ -1,6 +1,5 @@
import logging import logging
import queue import queue
import threading
from collections.abc import Generator from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
@ -42,7 +41,7 @@ def sse_client(
read_queue = queue.Queue() read_queue = queue.Queue()
write_queue = queue.Queue() write_queue = queue.Queue()
status_queue = queue.Queue() status_queue = queue.Queue()
cancel_event = threading.Event()
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor() as executor:
try: try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
@ -51,54 +50,49 @@ def sse_client(
url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source: ) as event_source:
event_source.response.raise_for_status() event_source.response.raise_for_status()
logger.debug("SSE connection established")
def sse_reader(status_queue: queue.Queue): def sse_reader(status_queue: queue.Queue):
try: try:
while not cancel_event.is_set(): for sse in event_source.iter_sse():
for sse in event_source.iter_sse(): match sse.event:
if cancel_event.is_set(): case "endpoint":
break endpoint_url = urljoin(url, sse.data)
match sse.event: logger.info(f"Received endpoint URL: {endpoint_url}")
case "endpoint": url_parsed = urlparse(url)
endpoint_url = urljoin(url, sse.data) endpoint_parsed = urlparse(endpoint_url)
logger.info(f"Received endpoint URL: {endpoint_url}")
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = (
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg)
raise ValueError(error_msg)
status_queue.put(("ready", endpoint_url))
case "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"Received server message: {message}")
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
continue
session_message = SessionMessage(message)
read_queue.put(session_message)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = (
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg)
status_queue.put(("error", ValueError(error_msg)))
status_queue.put(("ready", endpoint_url))
case "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"Received server message: {message}")
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
continue
session_message = SessionMessage(message)
read_queue.put(session_message)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc: except Exception as exc:
if not cancel_event.is_set(): read_queue.put(exc)
logger.exception("Error reading SSE messages")
read_queue.put(exc)
finally: finally:
read_queue.put(None) read_queue.put(None)
def post_writer(endpoint_url: str): def post_writer(endpoint_url: str):
try: try:
while not cancel_event.is_set(): while True:
try: try:
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if message is None: if message is None:
@ -113,14 +107,13 @@ def sse_client(
) )
response.raise_for_status() response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}") logger.debug(f"Client message sent successfully: {response.status_code}")
if cancel_event.is_set():
break
except queue.Empty: except queue.Empty:
if cancel_event.is_set():
break
continue continue
except Exception: except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc:
logger.exception("Error writing messages") logger.exception("Error writing messages")
write_queue.put(exc)
finally: finally:
write_queue.put(None) write_queue.put(None)
@ -131,11 +124,12 @@ def sse_client(
raise ValueError("failed to get endpoint URL") raise ValueError("failed to get endpoint URL")
if status != "ready": if status != "ready":
raise ValueError("failed to get endpoint URL") raise ValueError("failed to get endpoint URL")
if status == "error":
raise endpoint_url
executor.submit(post_writer, endpoint_url) executor.submit(post_writer, endpoint_url)
try:
yield read_queue, write_queue yield read_queue, write_queue
finally:
cancel_event.set()
except httpx.HTTPStatusError as exc: except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401: if exc.response.status_code == 401:
raise MCPAuthError() raise MCPAuthError()

View File

@ -8,7 +8,6 @@ and session management.
import logging import logging
import queue import queue
import threading
from collections.abc import Callable, Generator from collections.abc import Callable, Generator
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
@ -106,11 +105,6 @@ class StreamableHTTPTransport:
CONTENT_TYPE: JSON, CONTENT_TYPE: JSON,
**self.headers, **self.headers,
} }
self.stop_event = threading.Event()
def stop(self):
"""Signal to stop all operations."""
self.stop_event.set()
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]: def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID if available.""" """Update headers with session ID if available."""
@ -170,6 +164,9 @@ class StreamableHTTPTransport:
# Put exception in queue that goes to client # Put exception in queue that goes to client
server_to_client_queue.put(exc) server_to_client_queue.put(exc)
return False return False
elif sse.event == "ping":
logger.debug("Received ping event")
return False
else: else:
logger.warning(f"Unknown SSE event: {sse.event}") logger.warning(f"Unknown SSE event: {sse.event}")
return False return False
@ -198,8 +195,6 @@ class StreamableHTTPTransport:
logger.debug("GET SSE connection established") logger.debug("GET SSE connection established")
for sse in event_source.iter_sse(): for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
self._handle_sse_event(sse, server_to_client_queue) self._handle_sse_event(sse, server_to_client_queue)
except Exception as exc: except Exception as exc:
@ -230,8 +225,6 @@ class StreamableHTTPTransport:
logger.debug("Resumption GET SSE connection established") logger.debug("Resumption GET SSE connection established")
for sse in event_source.iter_sse(): for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
is_complete = self._handle_sse_event( is_complete = self._handle_sse_event(
sse, sse,
ctx.server_to_client_queue, ctx.server_to_client_queue,
@ -300,13 +293,13 @@ class StreamableHTTPTransport:
try: try:
event_source = EventSource(response) event_source = EventSource(response)
for sse in event_source.iter_sse(): for sse in event_source.iter_sse():
if self.stop_event.is_set(): is_complete = self._handle_sse_event(
break
self._handle_sse_event(
sse, sse,
ctx.server_to_client_queue, ctx.server_to_client_queue,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
) )
if is_complete:
break
except Exception as e: except Exception as e:
ctx.server_to_client_queue.put(e) ctx.server_to_client_queue.put(e)
@ -346,7 +339,7 @@ class StreamableHTTPTransport:
This method processes messages from the client_to_server_queue and sends them to the server. This method processes messages from the client_to_server_queue and sends them to the server.
Responses are written to the server_to_client_queue. Responses are written to the server_to_client_queue.
""" """
while not self.stop_event.is_set(): while True:
try: try:
# Read message from client queue with timeout to check stop_event periodically # Read message from client queue with timeout to check stop_event periodically
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT) session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
@ -382,10 +375,8 @@ class StreamableHTTPTransport:
else: else:
self._handle_post_request(ctx) self._handle_post_request(ctx)
except queue.Empty: except queue.Empty:
# Timeout - continue loop to check stop_event
continue continue
except Exception as exc: except Exception as exc:
# Send exception to client
server_to_client_queue.put(exc) server_to_client_queue.put(exc)
def terminate_session(self, client: httpx.Client) -> None: def terminate_session(self, client: httpx.Client) -> None:
@ -478,9 +469,6 @@ def streamablehttp_client(
# Signal threads to stop # Signal threads to stop
client_to_server_queue.put(None) client_to_server_queue.put(None)
finally: finally:
# Clean up
transport.stop()
# Clear any remaining items and add None sentinel to unblock any waiting threads # Clear any remaining items and add None sentinel to unblock any waiting threads
try: try:
while not client_to_server_queue.empty(): while not client_to_server_queue.empty():

View File

@ -21,7 +21,6 @@ class MCPClient:
tenant_id: str, tenant_id: str,
authed: bool = True, authed: bool = True,
authorization_code: Optional[str] = None, authorization_code: Optional[str] = None,
scope: Optional[str] = None,
): ):
# Initialize info # Initialize info
self.provider_id = provider_id self.provider_id = provider_id
@ -32,7 +31,6 @@ class MCPClient:
# Authentication info # Authentication info
self.authed = authed self.authed = authed
self.authorization_code = authorization_code self.authorization_code = authorization_code
self.scope = scope
if authed: if authed:
from core.mcp.auth.auth_provider import OAuthClientProvider from core.mcp.auth.auth_provider import OAuthClientProvider
@ -49,7 +47,7 @@ class MCPClient:
self._initialized = False self._initialized = False
def __enter__(self): def __enter__(self):
self._initialize(first_try=True) self._initialize()
self._initialized = True self._initialized = True
return self return self
@ -58,7 +56,6 @@ class MCPClient:
def _initialize( def _initialize(
self, self,
first_try: bool = True,
): ):
"""Initialize the client with fallback to SSE if streamable connection fails""" """Initialize the client with fallback to SSE if streamable connection fails"""
connection_methods = {"mcp": streamablehttp_client, "sse": sse_client} connection_methods = {"mcp": streamablehttp_client, "sse": sse_client}
@ -71,9 +68,9 @@ class MCPClient:
self.connect_server(client_factory, method_name) self.connect_server(client_factory, method_name)
except KeyError: except KeyError:
try: try:
self.connect_server(streamablehttp_client, "sse") self.connect_server(sse_client, "sse")
except MCPConnectionError: except MCPConnectionError:
self.connect_server(sse_client, "mcp") self.connect_server(streamablehttp_client, "mcp")
def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True): def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True):
from core.mcp.auth.auth_flow import auth from core.mcp.auth.auth_flow import auth
@ -100,8 +97,8 @@ class MCPClient:
except MCPAuthError: except MCPAuthError:
if not self.authed: if not self.authed:
raise raise
auth(self.provider, self.server_url, self.authorization_code)
auth(self.provider, self.server_url, self.authorization_code, self.scope) self.token = self.provider.tokens()
if first_try: if first_try:
return self.connect_server(client_factory, method_name, first_try=False) return self.connect_server(client_factory, method_name, first_try=False)
@ -134,5 +131,6 @@ class MCPClient:
self._session = None self._session = None
self._initialized = False self._initialized = False
self.exit_stack.close() self.exit_stack.close()
except Exception: except Exception as e:
logging.exception("Error during cleanup") logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")

View File

@ -2,30 +2,31 @@ import json
from collections.abc import Mapping from collections.abc import Mapping
from typing import cast from typing import cast
from configs.app_config import DifyConfig from configs import dify_config
from controllers.web.passport import generate_session_id from controllers.web.passport import generate_session_id
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.mcp import types from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, EndUser from models.model import App, AppMCPServer, EndUser
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
""" """
Apply to MCP HTTP streamable server with stateless http Apply to MCP HTTP streamable server with stateless http
""" """
dify_config = DifyConfig()
class MCPServerReuqestHandler: class MCPServerReuqestHandler:
def __init__(self, app: App, request: types.ClientRequest): def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity]):
self.app = app self.app = app
self.request = request self.request = request
if not self.app.mcp_server: self.mcp_server: AppMCPServer = self.app.mcp_server
if not self.mcp_server:
raise ValueError("MCP server not found") raise ValueError("MCP server not found")
self.mcp_server = self.app.mcp_server
self.end_user = self.retrieve_end_user() self.end_user = self.retrieve_end_user()
self.user_input_form = user_input_form
@property @property
def request_type(self): def request_type(self):
@ -33,6 +34,7 @@ class MCPServerReuqestHandler:
@property @property
def parameter_schema(self): def parameter_schema(self):
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
return { return {
"type": "object", "type": "object",
"properties": { "properties": {
@ -41,10 +43,11 @@ class MCPServerReuqestHandler:
"type": "object", "type": "object",
"description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501 "description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501
"default": {}, "default": {},
# TODO: add input parameters "properties": parameters,
"required": required,
}, },
}, },
"required": ["query"], "required": "query",
} }
@property @property
@ -152,3 +155,25 @@ class MCPServerReuqestHandler:
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") .filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first() .first()
) )
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
parameters = {}
required = []
for item in user_input_form:
if item.type in (
VariableEntityType.FILE,
VariableEntityType.FILE_LIST,
VariableEntityType.EXTERNAL_DATA_TOOL,
):
continue
if item.required:
required.append(item.variable)
parameters[item.variable]["description"] = self.mcp_server.parameters_dict[item.label]["description"]
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"
elif item.type == VariableEntityType.SELECT:
parameters[item.variable]["type"] = "string"
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "number"
return parameters, required

View File

@ -1,7 +1,5 @@
import logging import logging
import queue import queue
import threading
import time
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack from contextlib import ExitStack
@ -80,13 +78,10 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._completed = False self._completed = False
self._on_complete = on_complete self._on_complete = on_complete
self._entered = False # Track if we're in a context manager self._entered = False # Track if we're in a context manager
self._cancel_event = threading.Event()
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]": def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
"""Enter the context manager, enabling request cancellation tracking.""" """Enter the context manager, enabling request cancellation tracking."""
self._entered = True self._entered = True
self._cancel_event = threading.Event()
self._cancel_event.clear()
return self return self
def __exit__( def __exit__(
@ -101,9 +96,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._on_complete(self) self._on_complete(self)
finally: finally:
self._entered = False self._entered = False
if not self._cancel_event:
raise RuntimeError("No active cancel scope")
self._cancel_event.set()
def respond(self, response: SendResultT | ErrorData) -> None: def respond(self, response: SendResultT | ErrorData) -> None:
"""Send a response for this request. """Send a response for this request.
@ -117,17 +109,15 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
raise RuntimeError("RequestResponder must be used as a context manager") raise RuntimeError("RequestResponder must be used as a context manager")
assert not self._completed, "Request already responded to" assert not self._completed, "Request already responded to"
if not self.cancelled: self._completed = True
self._completed = True
self._session._send_response(request_id=self.request_id, response=response) self._session._send_response(request_id=self.request_id, response=response)
def cancel(self) -> None: def cancel(self) -> None:
"""Cancel this request and mark it as completed.""" """Cancel this request and mark it as completed."""
if not self._entered: if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager") raise RuntimeError("RequestResponder must be used as a context manager")
self._cancel_event.set()
self._completed = True # Mark as completed so it's removed from in_flight self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation # Send an error response to indicate cancellation
self._session._send_response( self._session._send_response(
@ -135,14 +125,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
response=ErrorData(code=0, message="Request cancelled", data=None), response=ErrorData(code=0, message="Request cancelled", data=None),
) )
@property
def in_flight(self) -> bool:
return not self._completed and not self.cancelled
@property
def cancelled(self) -> bool:
return self._cancel_event.is_set()
class BaseSession( class BaseSession(
Generic[ Generic[
@ -184,11 +166,9 @@ class BaseSession(
self._in_flight = {} self._in_flight = {}
self._exit_stack = ExitStack() self._exit_stack = ExitStack()
self._futures = [] self._futures = []
self._request_id_lock = threading.Lock()
def __enter__(self) -> Self: def __enter__(self) -> Self:
self._executor = ThreadPoolExecutor() self._executor = ThreadPoolExecutor()
self._stop_event = threading.Event()
self._receiver_future = self._executor.submit(self._receive_loop) self._receiver_future = self._executor.submit(self._receive_loop)
return self return self
@ -196,21 +176,8 @@ class BaseSession(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None: ) -> None:
self._exit_stack.close() self._exit_stack.close()
self._stop_event.set() self._read_stream.put(None)
self._wait_for_futures(timeout=5) self._write_stream.put(None)
def _wait_for_futures(self, timeout=None):
end_time = time.time() + timeout if timeout else None
for future in list(self._futures):
try:
remaining = end_time - time.time() if end_time else None
if remaining is not None and remaining <= 0:
break
future.result(timeout=remaining)
except Exception as e:
logging.exception(f"Error waiting for task: {e}")
def send_request( def send_request(
self, self,
@ -247,7 +214,7 @@ class BaseSession(
timeout = request_read_timeout_seconds.total_seconds() timeout = request_read_timeout_seconds.total_seconds()
elif self._session_read_timeout_seconds is not None: elif self._session_read_timeout_seconds is not None:
timeout = self._session_read_timeout_seconds.total_seconds() timeout = self._session_read_timeout_seconds.total_seconds()
while not self._stop_event.is_set(): while True:
try: try:
response_or_error = response_queue.get(timeout=timeout) response_or_error = response_queue.get(timeout=timeout)
break break
@ -316,7 +283,7 @@ class BaseSession(
Main message processing loop. Main message processing loop.
In a real synchronous implementation, this would likely run in a separate thread. In a real synchronous implementation, this would likely run in a separate thread.
""" """
while not self._stop_event.is_set(): while True:
try: try:
# Attempt to receive a message (this would be blocking in a synchronous context) # Attempt to receive a message (this would be blocking in a synchronous context)
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT) message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
@ -378,12 +345,9 @@ class BaseSession(
else: else:
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}")) self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
except queue.Empty: except queue.Empty:
if self._stop_event.is_set():
break
continue continue
except Exception as e: except Exception as e:
logging.exception("Error in message processing loop") logging.exception("Error in message processing loop")
self._stop_event.set()
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
""" """

View File

@ -3,12 +3,11 @@ from typing import Any, Protocol
from pydantic import AnyUrl, TypeAdapter from pydantic import AnyUrl, TypeAdapter
from configs.app_config import DifyConfig from configs import dify_config
from core.mcp import types from core.mcp import types
from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
from core.mcp.session.base_session import BaseSession, RequestResponder from core.mcp.session.base_session import BaseSession, RequestResponder
dify_config = DifyConfig()
DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION) DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION)

View File

@ -1177,7 +1177,6 @@ class SessionMessage:
class OAuthClientMetadata(BaseModel): class OAuthClientMetadata(BaseModel):
client_name: str client_name: str
redirect_uris: list[str] redirect_uris: list[str]
scope: str
grant_types: Optional[list[str]] = None grant_types: Optional[list[str]] = None
response_types: Optional[list[str]] = None response_types: Optional[list[str]] = None
token_endpoint_auth_method: Optional[str] = None token_endpoint_auth_method: Optional[str] = None

View File

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Literal, Optional from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -57,7 +57,7 @@ class ToolProviderApiEntity(BaseModel):
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value: if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
parameter["type"] = "files" parameter["type"] = "files"
# ------------- # -------------
optional_fields = self.optional_field("server_url", self.server_url)
return { return {
"id": self.id, "id": self.id,
"author": self.author, "author": self.author,
@ -73,4 +73,9 @@ class ToolProviderApiEntity(BaseModel):
"allow_delete": self.allow_delete, "allow_delete": self.allow_delete,
"tools": tools, "tools": tools,
"labels": self.labels, "labels": self.labels,
**optional_fields,
} }
def optional_field(self, key: str, value: Any) -> dict:
"""Return dict with key-value if value is truthy, empty dict otherwise."""
return {key: value} if value else {}

View File

@ -1,6 +1,7 @@
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional from typing import Any, Optional
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient from core.mcp.mcp_client import MCPClient
from core.mcp.types import ImageContent, TextContent from core.mcp.types import ImageContent, TextContent
from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.plugin.utils.converter import convert_parameters_to_plugin_format
@ -37,9 +38,14 @@ class MCPTool(Tool):
app_id: Optional[str] = None, app_id: Optional[str] = None,
message_id: Optional[str] = None, message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]: ) -> Generator[ToolInvokeMessage, None, None]:
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client: try:
tool_parameters = convert_parameters_to_plugin_format(tool_parameters) with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters) tool_parameters = convert_parameters_to_plugin_format(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e:
raise ValueError("Please auth the tool first")
except MCPConnectionError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
for content in result.content: for content in result.content:
if isinstance(content, TextContent): if isinstance(content, TextContent):
yield self.create_text_message(content.text) yield self.create_text_message(content.text)

View File

@ -1461,6 +1461,10 @@ class AppMCPServer(Base):
return result return result
@property
def parameters_dict(self) -> dict[str, Any]:
return json.loads(self.parameters)
class Site(Base): class Site(Base):
__tablename__ = "sites" __tablename__ = "sites"

View File

@ -1,5 +1,6 @@
import json import json
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient from core.mcp.mcp_client import MCPClient
from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -58,8 +59,13 @@ class MCPToolManageService:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None: if mcp_provider is None:
raise ValueError("MCP tool not found") raise ValueError("MCP tool not found")
with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client: try:
tools = mcp_client.list_tools() with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
tools = mcp_client.list_tools()
except MCPAuthError as e:
raise ValueError("Please auth the tool first")
except MCPConnectionError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
mcp_provider.authed = True mcp_provider.authed = True
db.session.commit() db.session.commit()