mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 06:45:58 +08:00
feat: upgrade streamable http client
This commit is contained in:
parent
1fd4839eca
commit
41bbcb9540
@ -68,16 +68,31 @@ class AppMCPServerController(Resource):
|
||||
parser.add_argument("id", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=str, required=True, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
parser.add_argument("status", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
|
||||
if not server:
|
||||
raise Forbidden()
|
||||
server.description = args["description"]
|
||||
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
||||
server.status = AppMCPServerStatus(args["status"])
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
class AppMCPServerRefreshController(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, server_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == server_id).first()
|
||||
if not server:
|
||||
raise Forbidden()
|
||||
server.server_code = AppMCPServer.generate_server_code(16)
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
|
||||
api.add_resource(AppMCPServerRefreshController, "/apps/<uuid:server_id>/server/refresh")
|
||||
|
@ -18,6 +18,7 @@ from controllers.web.error import (
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
@ -175,11 +176,30 @@ class ChatMCPApi(Resource):
|
||||
app = db.session.query(App).filter(App.id == server.app_id).first()
|
||||
if not app:
|
||||
raise NotFound("App Not Found")
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
try:
|
||||
user_input_form = [VariableEntity.model_validate(item) for item in user_input_form]
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid user_input_form: {str(e)}")
|
||||
try:
|
||||
request = ClientRequest.model_validate(args)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid MCP request: {str(e)}")
|
||||
mcp_server_handler = MCPServerReuqestHandler(app, request)
|
||||
mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form)
|
||||
return helper.compact_generate_response(mcp_server_handler.handle())
|
||||
|
||||
|
||||
|
@ -127,6 +127,7 @@ def create_ssrf_proxy_mcp_http_client(
|
||||
"verify": HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
"headers": headers or {},
|
||||
"timeout": timeout,
|
||||
"follow_redirects": True, # Enable redirect following for MCP connections
|
||||
}
|
||||
|
||||
if dify_config.SSRF_PROXY_ALL_URL:
|
||||
|
@ -59,7 +59,6 @@ def start_authorization(
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_information: OAuthClientInformation,
|
||||
redirect_url: str,
|
||||
scope: Optional[str] = None,
|
||||
) -> tuple[str, str]:
|
||||
"""Begins the authorization flow."""
|
||||
response_type = "code"
|
||||
@ -85,11 +84,9 @@ def start_authorization(
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": code_challenge_method,
|
||||
"redirect_uri": redirect_url,
|
||||
"state": "/tools?provider_id=" + client_information.client_id,
|
||||
}
|
||||
|
||||
if scope:
|
||||
params["scope"] = scope
|
||||
|
||||
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
|
||||
return authorization_url, code_verifier
|
||||
|
||||
@ -187,7 +184,6 @@ def auth(
|
||||
provider: OAuthClientProvider,
|
||||
server_url: str,
|
||||
authorization_code: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
) -> dict[str, str]:
|
||||
"""Orchestrates the full auth flow with a server."""
|
||||
metadata = discover_oauth_metadata(server_url)
|
||||
@ -233,7 +229,6 @@ def auth(
|
||||
metadata,
|
||||
client_information,
|
||||
provider.redirect_url,
|
||||
scope or provider.client_metadata.scope,
|
||||
)
|
||||
|
||||
provider.save_code_verifier(code_verifier)
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from configs.app_config import DifyConfig
|
||||
from configs import dify_config
|
||||
from core.mcp.types import (
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
@ -11,8 +11,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
|
||||
LATEST_PROTOCOL_VERSION = "1.0"
|
||||
|
||||
dify_config = DifyConfig()
|
||||
|
||||
|
||||
class OAuthClientProvider:
|
||||
provider_id: str
|
||||
@ -25,7 +23,7 @@ class OAuthClientProvider:
|
||||
@property
|
||||
def redirect_url(self) -> str:
|
||||
"""The URL to redirect the user agent to after authorization."""
|
||||
return dify_config.CONSOLE_WEB_URL
|
||||
return dify_config.CONSOLE_WEB_URL + "/tools"
|
||||
|
||||
@property
|
||||
def client_metadata(self) -> OAuthClientMetadata:
|
||||
@ -37,7 +35,6 @@ class OAuthClientProvider:
|
||||
response_types=["code"],
|
||||
client_name="Dify",
|
||||
client_uri="https://github.com/langgenius/dify",
|
||||
scope="read write",
|
||||
)
|
||||
|
||||
def client_information(self) -> Optional[OAuthClientInformation]:
|
||||
@ -91,7 +88,3 @@ class OAuthClientProvider:
|
||||
if not mcp_provider:
|
||||
return ""
|
||||
return mcp_provider.credentials.get("code_verifier", "")
|
||||
|
||||
|
||||
class UnauthorizedError(Exception):
|
||||
pass
|
||||
|
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
@ -42,7 +41,7 @@ def sse_client(
|
||||
read_queue = queue.Queue()
|
||||
write_queue = queue.Queue()
|
||||
status_queue = queue.Queue()
|
||||
cancel_event = threading.Event()
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
try:
|
||||
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
|
||||
@ -51,14 +50,10 @@ def sse_client(
|
||||
url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("SSE connection established")
|
||||
|
||||
def sse_reader(status_queue: queue.Queue):
|
||||
try:
|
||||
while not cancel_event.is_set():
|
||||
for sse in event_source.iter_sse():
|
||||
if cancel_event.is_set():
|
||||
break
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
endpoint_url = urljoin(url, sse.data)
|
||||
@ -74,7 +69,7 @@ def sse_client(
|
||||
f"Endpoint origin does not match connection origin: {endpoint_url}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
status_queue.put(("error", ValueError(error_msg)))
|
||||
status_queue.put(("ready", endpoint_url))
|
||||
case "message":
|
||||
try:
|
||||
@ -88,17 +83,16 @@ def sse_client(
|
||||
read_queue.put(session_message)
|
||||
case _:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug(f"SSE reader shutting down normally: {exc}")
|
||||
except Exception as exc:
|
||||
if not cancel_event.is_set():
|
||||
logger.exception("Error reading SSE messages")
|
||||
read_queue.put(exc)
|
||||
finally:
|
||||
read_queue.put(None)
|
||||
|
||||
def post_writer(endpoint_url: str):
|
||||
try:
|
||||
while not cancel_event.is_set():
|
||||
while True:
|
||||
try:
|
||||
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if message is None:
|
||||
@ -113,14 +107,13 @@ def sse_client(
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(f"Client message sent successfully: {response.status_code}")
|
||||
if cancel_event.is_set():
|
||||
break
|
||||
except queue.Empty:
|
||||
if cancel_event.is_set():
|
||||
break
|
||||
continue
|
||||
except Exception:
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug(f"SSE reader shutting down normally: {exc}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error writing messages")
|
||||
write_queue.put(exc)
|
||||
finally:
|
||||
write_queue.put(None)
|
||||
|
||||
@ -131,11 +124,12 @@ def sse_client(
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
if status != "ready":
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
if status == "error":
|
||||
raise endpoint_url
|
||||
executor.submit(post_writer, endpoint_url)
|
||||
try:
|
||||
|
||||
yield read_queue, write_queue
|
||||
finally:
|
||||
cancel_event.set()
|
||||
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError()
|
||||
|
@ -8,7 +8,6 @@ and session management.
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Callable, Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
@ -106,11 +105,6 @@ class StreamableHTTPTransport:
|
||||
CONTENT_TYPE: JSON,
|
||||
**self.headers,
|
||||
}
|
||||
self.stop_event = threading.Event()
|
||||
|
||||
def stop(self):
|
||||
"""Signal to stop all operations."""
|
||||
self.stop_event.set()
|
||||
|
||||
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
|
||||
"""Update headers with session ID if available."""
|
||||
@ -170,6 +164,9 @@ class StreamableHTTPTransport:
|
||||
# Put exception in queue that goes to client
|
||||
server_to_client_queue.put(exc)
|
||||
return False
|
||||
elif sse.event == "ping":
|
||||
logger.debug("Received ping event")
|
||||
return False
|
||||
else:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
return False
|
||||
@ -198,8 +195,6 @@ class StreamableHTTPTransport:
|
||||
logger.debug("GET SSE connection established")
|
||||
|
||||
for sse in event_source.iter_sse():
|
||||
if self.stop_event.is_set():
|
||||
break
|
||||
self._handle_sse_event(sse, server_to_client_queue)
|
||||
|
||||
except Exception as exc:
|
||||
@ -230,8 +225,6 @@ class StreamableHTTPTransport:
|
||||
logger.debug("Resumption GET SSE connection established")
|
||||
|
||||
for sse in event_source.iter_sse():
|
||||
if self.stop_event.is_set():
|
||||
break
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
@ -300,13 +293,13 @@ class StreamableHTTPTransport:
|
||||
try:
|
||||
event_source = EventSource(response)
|
||||
for sse in event_source.iter_sse():
|
||||
if self.stop_event.is_set():
|
||||
break
|
||||
self._handle_sse_event(
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
except Exception as e:
|
||||
ctx.server_to_client_queue.put(e)
|
||||
|
||||
@ -346,7 +339,7 @@ class StreamableHTTPTransport:
|
||||
This method processes messages from the client_to_server_queue and sends them to the server.
|
||||
Responses are written to the server_to_client_queue.
|
||||
"""
|
||||
while not self.stop_event.is_set():
|
||||
while True:
|
||||
try:
|
||||
# Read message from client queue with timeout to check stop_event periodically
|
||||
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
@ -382,10 +375,8 @@ class StreamableHTTPTransport:
|
||||
else:
|
||||
self._handle_post_request(ctx)
|
||||
except queue.Empty:
|
||||
# Timeout - continue loop to check stop_event
|
||||
continue
|
||||
except Exception as exc:
|
||||
# Send exception to client
|
||||
server_to_client_queue.put(exc)
|
||||
|
||||
def terminate_session(self, client: httpx.Client) -> None:
|
||||
@ -478,9 +469,6 @@ def streamablehttp_client(
|
||||
# Signal threads to stop
|
||||
client_to_server_queue.put(None)
|
||||
finally:
|
||||
# Clean up
|
||||
transport.stop()
|
||||
|
||||
# Clear any remaining items and add None sentinel to unblock any waiting threads
|
||||
try:
|
||||
while not client_to_server_queue.empty():
|
||||
|
@ -21,7 +21,6 @@ class MCPClient:
|
||||
tenant_id: str,
|
||||
authed: bool = True,
|
||||
authorization_code: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
):
|
||||
# Initialize info
|
||||
self.provider_id = provider_id
|
||||
@ -32,7 +31,6 @@ class MCPClient:
|
||||
# Authentication info
|
||||
self.authed = authed
|
||||
self.authorization_code = authorization_code
|
||||
self.scope = scope
|
||||
if authed:
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
|
||||
@ -49,7 +47,7 @@ class MCPClient:
|
||||
self._initialized = False
|
||||
|
||||
def __enter__(self):
|
||||
self._initialize(first_try=True)
|
||||
self._initialize()
|
||||
self._initialized = True
|
||||
return self
|
||||
|
||||
@ -58,7 +56,6 @@ class MCPClient:
|
||||
|
||||
def _initialize(
|
||||
self,
|
||||
first_try: bool = True,
|
||||
):
|
||||
"""Initialize the client with fallback to SSE if streamable connection fails"""
|
||||
connection_methods = {"mcp": streamablehttp_client, "sse": sse_client}
|
||||
@ -71,9 +68,9 @@ class MCPClient:
|
||||
self.connect_server(client_factory, method_name)
|
||||
except KeyError:
|
||||
try:
|
||||
self.connect_server(streamablehttp_client, "sse")
|
||||
self.connect_server(sse_client, "sse")
|
||||
except MCPConnectionError:
|
||||
self.connect_server(sse_client, "mcp")
|
||||
self.connect_server(streamablehttp_client, "mcp")
|
||||
|
||||
def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True):
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
@ -100,8 +97,8 @@ class MCPClient:
|
||||
except MCPAuthError:
|
||||
if not self.authed:
|
||||
raise
|
||||
|
||||
auth(self.provider, self.server_url, self.authorization_code, self.scope)
|
||||
auth(self.provider, self.server_url, self.authorization_code)
|
||||
self.token = self.provider.tokens()
|
||||
if first_try:
|
||||
return self.connect_server(client_factory, method_name, first_try=False)
|
||||
|
||||
@ -134,5 +131,6 @@ class MCPClient:
|
||||
self._session = None
|
||||
self._initialized = False
|
||||
self.exit_stack.close()
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logging.exception("Error during cleanup")
|
||||
raise ValueError(f"Error during cleanup: {e}")
|
||||
|
@ -2,30 +2,31 @@ import json
|
||||
from collections.abc import Mapping
|
||||
from typing import cast
|
||||
|
||||
from configs.app_config import DifyConfig
|
||||
from configs import dify_config
|
||||
from controllers.web.passport import generate_session_id
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.mcp import types
|
||||
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, EndUser
|
||||
from models.model import App, AppMCPServer, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
"""
|
||||
Apply to MCP HTTP streamable server with stateless http
|
||||
"""
|
||||
dify_config = DifyConfig()
|
||||
|
||||
|
||||
class MCPServerReuqestHandler:
|
||||
def __init__(self, app: App, request: types.ClientRequest):
|
||||
def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity]):
|
||||
self.app = app
|
||||
self.request = request
|
||||
if not self.app.mcp_server:
|
||||
self.mcp_server: AppMCPServer = self.app.mcp_server
|
||||
if not self.mcp_server:
|
||||
raise ValueError("MCP server not found")
|
||||
self.mcp_server = self.app.mcp_server
|
||||
self.end_user = self.retrieve_end_user()
|
||||
self.user_input_form = user_input_form
|
||||
|
||||
@property
|
||||
def request_type(self):
|
||||
@ -33,6 +34,7 @@ class MCPServerReuqestHandler:
|
||||
|
||||
@property
|
||||
def parameter_schema(self):
|
||||
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@ -41,10 +43,11 @@ class MCPServerReuqestHandler:
|
||||
"type": "object",
|
||||
"description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501
|
||||
"default": {},
|
||||
# TODO: add input parameters
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
"required": "query",
|
||||
}
|
||||
|
||||
@property
|
||||
@ -152,3 +155,25 @@ class MCPServerReuqestHandler:
|
||||
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
|
||||
.first()
|
||||
)
|
||||
|
||||
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
|
||||
parameters = {}
|
||||
required = []
|
||||
for item in user_input_form:
|
||||
if item.type in (
|
||||
VariableEntityType.FILE,
|
||||
VariableEntityType.FILE_LIST,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
):
|
||||
continue
|
||||
if item.required:
|
||||
required.append(item.variable)
|
||||
parameters[item.variable]["description"] = self.mcp_server.parameters_dict[item.label]["description"]
|
||||
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||
parameters[item.variable]["type"] = "string"
|
||||
elif item.type == VariableEntityType.SELECT:
|
||||
parameters[item.variable]["type"] = "string"
|
||||
parameters[item.variable]["enum"] = item.options
|
||||
elif item.type == VariableEntityType.NUMBER:
|
||||
parameters[item.variable]["type"] = "number"
|
||||
return parameters, required
|
||||
|
@ -1,7 +1,5 @@
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import ExitStack
|
||||
@ -80,13 +78,10 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
self._completed = False
|
||||
self._on_complete = on_complete
|
||||
self._entered = False # Track if we're in a context manager
|
||||
self._cancel_event = threading.Event()
|
||||
|
||||
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
|
||||
"""Enter the context manager, enabling request cancellation tracking."""
|
||||
self._entered = True
|
||||
self._cancel_event = threading.Event()
|
||||
self._cancel_event.clear()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
@ -101,9 +96,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
self._on_complete(self)
|
||||
finally:
|
||||
self._entered = False
|
||||
if not self._cancel_event:
|
||||
raise RuntimeError("No active cancel scope")
|
||||
self._cancel_event.set()
|
||||
|
||||
def respond(self, response: SendResultT | ErrorData) -> None:
|
||||
"""Send a response for this request.
|
||||
@ -117,7 +109,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
assert not self._completed, "Request already responded to"
|
||||
|
||||
if not self.cancelled:
|
||||
self._completed = True
|
||||
|
||||
self._session._send_response(request_id=self.request_id, response=response)
|
||||
@ -127,7 +118,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
if not self._entered:
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
|
||||
self._cancel_event.set()
|
||||
self._completed = True # Mark as completed so it's removed from in_flight
|
||||
# Send an error response to indicate cancellation
|
||||
self._session._send_response(
|
||||
@ -135,14 +125,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
response=ErrorData(code=0, message="Request cancelled", data=None),
|
||||
)
|
||||
|
||||
@property
|
||||
def in_flight(self) -> bool:
|
||||
return not self._completed and not self.cancelled
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
return self._cancel_event.is_set()
|
||||
|
||||
|
||||
class BaseSession(
|
||||
Generic[
|
||||
@ -184,11 +166,9 @@ class BaseSession(
|
||||
self._in_flight = {}
|
||||
self._exit_stack = ExitStack()
|
||||
self._futures = []
|
||||
self._request_id_lock = threading.Lock()
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
self._executor = ThreadPoolExecutor()
|
||||
self._stop_event = threading.Event()
|
||||
self._receiver_future = self._executor.submit(self._receive_loop)
|
||||
return self
|
||||
|
||||
@ -196,21 +176,8 @@ class BaseSession(
|
||||
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
|
||||
) -> None:
|
||||
self._exit_stack.close()
|
||||
self._stop_event.set()
|
||||
self._wait_for_futures(timeout=5)
|
||||
|
||||
def _wait_for_futures(self, timeout=None):
|
||||
end_time = time.time() + timeout if timeout else None
|
||||
|
||||
for future in list(self._futures):
|
||||
try:
|
||||
remaining = end_time - time.time() if end_time else None
|
||||
if remaining is not None and remaining <= 0:
|
||||
break
|
||||
|
||||
future.result(timeout=remaining)
|
||||
except Exception as e:
|
||||
logging.exception(f"Error waiting for task: {e}")
|
||||
self._read_stream.put(None)
|
||||
self._write_stream.put(None)
|
||||
|
||||
def send_request(
|
||||
self,
|
||||
@ -247,7 +214,7 @@ class BaseSession(
|
||||
timeout = request_read_timeout_seconds.total_seconds()
|
||||
elif self._session_read_timeout_seconds is not None:
|
||||
timeout = self._session_read_timeout_seconds.total_seconds()
|
||||
while not self._stop_event.is_set():
|
||||
while True:
|
||||
try:
|
||||
response_or_error = response_queue.get(timeout=timeout)
|
||||
break
|
||||
@ -316,7 +283,7 @@ class BaseSession(
|
||||
Main message processing loop.
|
||||
In a real synchronous implementation, this would likely run in a separate thread.
|
||||
"""
|
||||
while not self._stop_event.is_set():
|
||||
while True:
|
||||
try:
|
||||
# Attempt to receive a message (this would be blocking in a synchronous context)
|
||||
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
|
||||
@ -378,12 +345,9 @@ class BaseSession(
|
||||
else:
|
||||
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
|
||||
except queue.Empty:
|
||||
if self._stop_event.is_set():
|
||||
break
|
||||
continue
|
||||
except Exception as e:
|
||||
logging.exception("Error in message processing loop")
|
||||
self._stop_event.set()
|
||||
|
||||
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
|
||||
"""
|
||||
|
@ -3,12 +3,11 @@ from typing import Any, Protocol
|
||||
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
from configs.app_config import DifyConfig
|
||||
from configs import dify_config
|
||||
from core.mcp import types
|
||||
from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
|
||||
from core.mcp.session.base_session import BaseSession, RequestResponder
|
||||
|
||||
dify_config = DifyConfig()
|
||||
DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION)
|
||||
|
||||
|
||||
|
@ -1177,7 +1177,6 @@ class SessionMessage:
|
||||
class OAuthClientMetadata(BaseModel):
|
||||
client_name: str
|
||||
redirect_uris: list[str]
|
||||
scope: str
|
||||
grant_types: Optional[list[str]] = None
|
||||
response_types: Optional[list[str]] = None
|
||||
token_endpoint_auth_method: Optional[str] = None
|
||||
|
@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
@ -57,7 +57,7 @@ class ToolProviderApiEntity(BaseModel):
|
||||
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
|
||||
parameter["type"] = "files"
|
||||
# -------------
|
||||
|
||||
optional_fields = self.optional_field("server_url", self.server_url)
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
@ -73,4 +73,9 @@ class ToolProviderApiEntity(BaseModel):
|
||||
"allow_delete": self.allow_delete,
|
||||
"tools": tools,
|
||||
"labels": self.labels,
|
||||
**optional_fields,
|
||||
}
|
||||
|
||||
def optional_field(self, key: str, value: Any) -> dict:
|
||||
"""Return dict with key-value if value is truthy, empty dict otherwise."""
|
||||
return {key: value} if value else {}
|
||||
|
@ -1,6 +1,7 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import ImageContent, TextContent
|
||||
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||
@ -37,9 +38,14 @@ class MCPTool(Tool):
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
try:
|
||||
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
|
||||
tool_parameters = convert_parameters_to_plugin_format(tool_parameters)
|
||||
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPAuthError as e:
|
||||
raise ValueError("Please auth the tool first")
|
||||
except MCPConnectionError as e:
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
yield self.create_text_message(content.text)
|
||||
|
@ -1461,6 +1461,10 @@ class AppMCPServer(Base):
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def parameters_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.parameters)
|
||||
|
||||
|
||||
class Site(Base):
|
||||
__tablename__ = "sites"
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@ -58,8 +59,13 @@ class MCPToolManageService:
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if mcp_provider is None:
|
||||
raise ValueError("MCP tool not found")
|
||||
try:
|
||||
with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
except MCPAuthError as e:
|
||||
raise ValueError("Please auth the tool first")
|
||||
except MCPConnectionError as e:
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
mcp_provider.authed = True
|
||||
db.session.commit()
|
||||
|
Loading…
x
Reference in New Issue
Block a user