feat: upgrade streamable http client

This commit is contained in:
Novice 2025-05-27 13:14:51 +08:00
parent 1fd4839eca
commit 41bbcb9540
16 changed files with 167 additions and 155 deletions

View File

@ -68,16 +68,31 @@ class AppMCPServerController(Resource):
parser.add_argument("id", type=str, required=True, location="json")
parser.add_argument("description", type=str, required=True, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("status", type=str, required=True, location="json")
args = parser.parse_args()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
if not server:
raise Forbidden()
server.description = args["description"]
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
server.status = AppMCPServerStatus(args["status"])
db.session.commit()
return server
class AppMCPServerRefreshController(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_server_fields)
def get(self, server_id):
if not current_user.is_editor:
raise Forbidden()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == server_id).first()
if not server:
raise Forbidden()
server.server_code = AppMCPServer.generate_server_code(16)
db.session.commit()
return server
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
api.add_resource(AppMCPServerRefreshController, "/apps/<uuid:server_id>/server/refresh")

View File

@ -18,6 +18,7 @@ from controllers.web.error import (
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from controllers.web.wraps import WebApiResource
from core.app.app_config.entities import VariableEntity
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
@ -175,11 +176,30 @@ class ChatMCPApi(Resource):
app = db.session.query(App).filter(App.id == server.app_id).first()
if not app:
raise NotFound("App Not Found")
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app.workflow
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
try:
user_input_form = [VariableEntity.model_validate(item) for item in user_input_form]
except ValidationError as e:
raise ValueError(f"Invalid user_input_form: {str(e)}")
try:
request = ClientRequest.model_validate(args)
except ValidationError as e:
raise ValueError(f"Invalid MCP request: {str(e)}")
mcp_server_handler = MCPServerReuqestHandler(app, request)
mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form)
return helper.compact_generate_response(mcp_server_handler.handle())

View File

@ -127,6 +127,7 @@ def create_ssrf_proxy_mcp_http_client(
"verify": HTTP_REQUEST_NODE_SSL_VERIFY,
"headers": headers or {},
"timeout": timeout,
"follow_redirects": True, # Enable redirect following for MCP connections
}
if dify_config.SSRF_PROXY_ALL_URL:

View File

@ -59,7 +59,6 @@ def start_authorization(
metadata: Optional[OAuthMetadata],
client_information: OAuthClientInformation,
redirect_url: str,
scope: Optional[str] = None,
) -> tuple[str, str]:
"""Begins the authorization flow."""
response_type = "code"
@ -85,11 +84,9 @@ def start_authorization(
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"redirect_uri": redirect_url,
"state": "/tools?provider_id=" + client_information.client_id,
}
if scope:
params["scope"] = scope
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
return authorization_url, code_verifier
@ -187,7 +184,6 @@ def auth(
provider: OAuthClientProvider,
server_url: str,
authorization_code: Optional[str] = None,
scope: Optional[str] = None,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server."""
metadata = discover_oauth_metadata(server_url)
@ -233,7 +229,6 @@ def auth(
metadata,
client_information,
provider.redirect_url,
scope or provider.client_metadata.scope,
)
provider.save_code_verifier(code_verifier)

View File

@ -1,6 +1,6 @@
from typing import Optional
from configs.app_config import DifyConfig
from configs import dify_config
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
@ -11,8 +11,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService
LATEST_PROTOCOL_VERSION = "1.0"
dify_config = DifyConfig()
class OAuthClientProvider:
provider_id: str
@ -25,7 +23,7 @@ class OAuthClientProvider:
@property
def redirect_url(self) -> str:
"""The URL to redirect the user agent to after authorization."""
return dify_config.CONSOLE_WEB_URL
return dify_config.CONSOLE_WEB_URL + "/tools"
@property
def client_metadata(self) -> OAuthClientMetadata:
@ -37,7 +35,6 @@ class OAuthClientProvider:
response_types=["code"],
client_name="Dify",
client_uri="https://github.com/langgenius/dify",
scope="read write",
)
def client_information(self) -> Optional[OAuthClientInformation]:
@ -91,7 +88,3 @@ class OAuthClientProvider:
if not mcp_provider:
return ""
return mcp_provider.credentials.get("code_verifier", "")
class UnauthorizedError(Exception):
pass

View File

@ -1,6 +1,5 @@
import logging
import queue
import threading
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
@ -42,7 +41,7 @@ def sse_client(
read_queue = queue.Queue()
write_queue = queue.Queue()
status_queue = queue.Queue()
cancel_event = threading.Event()
with ThreadPoolExecutor() as executor:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
@ -51,54 +50,49 @@ def sse_client(
url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")
def sse_reader(status_queue: queue.Queue):
try:
while not cancel_event.is_set():
for sse in event_source.iter_sse():
if cancel_event.is_set():
break
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(f"Received endpoint URL: {endpoint_url}")
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = (
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg)
raise ValueError(error_msg)
status_queue.put(("ready", endpoint_url))
case "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"Received server message: {message}")
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
continue
session_message = SessionMessage(message)
read_queue.put(session_message)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
for sse in event_source.iter_sse():
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(f"Received endpoint URL: {endpoint_url}")
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = (
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg)
status_queue.put(("error", ValueError(error_msg)))
status_queue.put(("ready", endpoint_url))
case "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"Received server message: {message}")
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
continue
session_message = SessionMessage(message)
read_queue.put(session_message)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc:
if not cancel_event.is_set():
logger.exception("Error reading SSE messages")
read_queue.put(exc)
read_queue.put(exc)
finally:
read_queue.put(None)
def post_writer(endpoint_url: str):
try:
while not cancel_event.is_set():
while True:
try:
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if message is None:
@ -113,14 +107,13 @@ def sse_client(
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
if cancel_event.is_set():
break
except queue.Empty:
if cancel_event.is_set():
break
continue
except Exception:
except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc:
logger.exception("Error writing messages")
write_queue.put(exc)
finally:
write_queue.put(None)
@ -131,11 +124,12 @@ def sse_client(
raise ValueError("failed to get endpoint URL")
if status != "ready":
raise ValueError("failed to get endpoint URL")
if status == "error":
raise endpoint_url
executor.submit(post_writer, endpoint_url)
try:
yield read_queue, write_queue
finally:
cancel_event.set()
yield read_queue, write_queue
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()

View File

@ -8,7 +8,6 @@ and session management.
import logging
import queue
import threading
from collections.abc import Callable, Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
@ -106,11 +105,6 @@ class StreamableHTTPTransport:
CONTENT_TYPE: JSON,
**self.headers,
}
self.stop_event = threading.Event()
def stop(self):
"""Signal to stop all operations."""
self.stop_event.set()
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID if available."""
@ -170,6 +164,9 @@ class StreamableHTTPTransport:
# Put exception in queue that goes to client
server_to_client_queue.put(exc)
return False
elif sse.event == "ping":
logger.debug("Received ping event")
return False
else:
logger.warning(f"Unknown SSE event: {sse.event}")
return False
@ -198,8 +195,6 @@ class StreamableHTTPTransport:
logger.debug("GET SSE connection established")
for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
self._handle_sse_event(sse, server_to_client_queue)
except Exception as exc:
@ -230,8 +225,6 @@ class StreamableHTTPTransport:
logger.debug("Resumption GET SSE connection established")
for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
@ -300,13 +293,13 @@ class StreamableHTTPTransport:
try:
event_source = EventSource(response)
for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
self._handle_sse_event(
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
)
if is_complete:
break
except Exception as e:
ctx.server_to_client_queue.put(e)
@ -346,7 +339,7 @@ class StreamableHTTPTransport:
This method processes messages from the client_to_server_queue and sends them to the server.
Responses are written to the server_to_client_queue.
"""
while not self.stop_event.is_set():
while True:
try:
# Read message from client queue with timeout to check stop_event periodically
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
@ -382,10 +375,8 @@ class StreamableHTTPTransport:
else:
self._handle_post_request(ctx)
except queue.Empty:
# Timeout - continue loop to check stop_event
continue
except Exception as exc:
# Send exception to client
server_to_client_queue.put(exc)
def terminate_session(self, client: httpx.Client) -> None:
@ -478,9 +469,6 @@ def streamablehttp_client(
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clean up
transport.stop()
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():

View File

@ -21,7 +21,6 @@ class MCPClient:
tenant_id: str,
authed: bool = True,
authorization_code: Optional[str] = None,
scope: Optional[str] = None,
):
# Initialize info
self.provider_id = provider_id
@ -32,7 +31,6 @@ class MCPClient:
# Authentication info
self.authed = authed
self.authorization_code = authorization_code
self.scope = scope
if authed:
from core.mcp.auth.auth_provider import OAuthClientProvider
@ -49,7 +47,7 @@ class MCPClient:
self._initialized = False
def __enter__(self):
self._initialize(first_try=True)
self._initialize()
self._initialized = True
return self
@ -58,7 +56,6 @@ class MCPClient:
def _initialize(
self,
first_try: bool = True,
):
"""Initialize the client with fallback to SSE if streamable connection fails"""
connection_methods = {"mcp": streamablehttp_client, "sse": sse_client}
@ -71,9 +68,9 @@ class MCPClient:
self.connect_server(client_factory, method_name)
except KeyError:
try:
self.connect_server(streamablehttp_client, "sse")
self.connect_server(sse_client, "sse")
except MCPConnectionError:
self.connect_server(sse_client, "mcp")
self.connect_server(streamablehttp_client, "mcp")
def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True):
from core.mcp.auth.auth_flow import auth
@ -100,8 +97,8 @@ class MCPClient:
except MCPAuthError:
if not self.authed:
raise
auth(self.provider, self.server_url, self.authorization_code, self.scope)
auth(self.provider, self.server_url, self.authorization_code)
self.token = self.provider.tokens()
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)
@ -134,5 +131,6 @@ class MCPClient:
self._session = None
self._initialized = False
self.exit_stack.close()
except Exception:
except Exception as e:
logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")

View File

@ -2,30 +2,31 @@ import json
from collections.abc import Mapping
from typing import cast
from configs.app_config import DifyConfig
from configs import dify_config
from controllers.web.passport import generate_session_id
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.model import App, EndUser
from models.model import App, AppMCPServer, EndUser
from services.app_generate_service import AppGenerateService
"""
Apply to MCP HTTP streamable server with stateless http
"""
dify_config = DifyConfig()
class MCPServerReuqestHandler:
def __init__(self, app: App, request: types.ClientRequest):
def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity]):
self.app = app
self.request = request
if not self.app.mcp_server:
self.mcp_server: AppMCPServer = self.app.mcp_server
if not self.mcp_server:
raise ValueError("MCP server not found")
self.mcp_server = self.app.mcp_server
self.end_user = self.retrieve_end_user()
self.user_input_form = user_input_form
@property
def request_type(self):
@ -33,6 +34,7 @@ class MCPServerReuqestHandler:
@property
def parameter_schema(self):
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
return {
"type": "object",
"properties": {
@ -41,10 +43,11 @@ class MCPServerReuqestHandler:
"type": "object",
"description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501
"default": {},
# TODO: add input parameters
"properties": parameters,
"required": required,
},
},
"required": ["query"],
"required": "query",
}
@property
@ -152,3 +155,25 @@ class MCPServerReuqestHandler:
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first()
)
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
parameters = {}
required = []
for item in user_input_form:
if item.type in (
VariableEntityType.FILE,
VariableEntityType.FILE_LIST,
VariableEntityType.EXTERNAL_DATA_TOOL,
):
continue
if item.required:
required.append(item.variable)
parameters[item.variable]["description"] = self.mcp_server.parameters_dict[item.label]["description"]
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"
elif item.type == VariableEntityType.SELECT:
parameters[item.variable]["type"] = "string"
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "number"
return parameters, required

View File

@ -1,7 +1,5 @@
import logging
import queue
import threading
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack
@ -80,13 +78,10 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._completed = False
self._on_complete = on_complete
self._entered = False # Track if we're in a context manager
self._cancel_event = threading.Event()
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
"""Enter the context manager, enabling request cancellation tracking."""
self._entered = True
self._cancel_event = threading.Event()
self._cancel_event.clear()
return self
def __exit__(
@ -101,9 +96,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._on_complete(self)
finally:
self._entered = False
if not self._cancel_event:
raise RuntimeError("No active cancel scope")
self._cancel_event.set()
def respond(self, response: SendResultT | ErrorData) -> None:
"""Send a response for this request.
@ -117,17 +109,15 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
raise RuntimeError("RequestResponder must be used as a context manager")
assert not self._completed, "Request already responded to"
if not self.cancelled:
self._completed = True
self._completed = True
self._session._send_response(request_id=self.request_id, response=response)
self._session._send_response(request_id=self.request_id, response=response)
def cancel(self) -> None:
"""Cancel this request and mark it as completed."""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
self._cancel_event.set()
self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
self._session._send_response(
@ -135,14 +125,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
response=ErrorData(code=0, message="Request cancelled", data=None),
)
@property
def in_flight(self) -> bool:
return not self._completed and not self.cancelled
@property
def cancelled(self) -> bool:
return self._cancel_event.is_set()
class BaseSession(
Generic[
@ -184,11 +166,9 @@ class BaseSession(
self._in_flight = {}
self._exit_stack = ExitStack()
self._futures = []
self._request_id_lock = threading.Lock()
def __enter__(self) -> Self:
self._executor = ThreadPoolExecutor()
self._stop_event = threading.Event()
self._receiver_future = self._executor.submit(self._receive_loop)
return self
@ -196,21 +176,8 @@ class BaseSession(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
self._exit_stack.close()
self._stop_event.set()
self._wait_for_futures(timeout=5)
def _wait_for_futures(self, timeout=None):
end_time = time.time() + timeout if timeout else None
for future in list(self._futures):
try:
remaining = end_time - time.time() if end_time else None
if remaining is not None and remaining <= 0:
break
future.result(timeout=remaining)
except Exception as e:
logging.exception(f"Error waiting for task: {e}")
self._read_stream.put(None)
self._write_stream.put(None)
def send_request(
self,
@ -247,7 +214,7 @@ class BaseSession(
timeout = request_read_timeout_seconds.total_seconds()
elif self._session_read_timeout_seconds is not None:
timeout = self._session_read_timeout_seconds.total_seconds()
while not self._stop_event.is_set():
while True:
try:
response_or_error = response_queue.get(timeout=timeout)
break
@ -316,7 +283,7 @@ class BaseSession(
Main message processing loop.
In a real synchronous implementation, this would likely run in a separate thread.
"""
while not self._stop_event.is_set():
while True:
try:
# Attempt to receive a message (this would be blocking in a synchronous context)
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
@ -378,12 +345,9 @@ class BaseSession(
else:
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
except queue.Empty:
if self._stop_event.is_set():
break
continue
except Exception as e:
logging.exception("Error in message processing loop")
self._stop_event.set()
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
"""

View File

@ -3,12 +3,11 @@ from typing import Any, Protocol
from pydantic import AnyUrl, TypeAdapter
from configs.app_config import DifyConfig
from configs import dify_config
from core.mcp import types
from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
from core.mcp.session.base_session import BaseSession, RequestResponder
dify_config = DifyConfig()
DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION)

View File

@ -1177,7 +1177,6 @@ class SessionMessage:
class OAuthClientMetadata(BaseModel):
client_name: str
redirect_uris: list[str]
scope: str
grant_types: Optional[list[str]] = None
response_types: Optional[list[str]] = None
token_endpoint_auth_method: Optional[str] = None

View File

@ -1,5 +1,5 @@
from datetime import datetime
from typing import Literal, Optional
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
@ -57,7 +57,7 @@ class ToolProviderApiEntity(BaseModel):
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
parameter["type"] = "files"
# -------------
optional_fields = self.optional_field("server_url", self.server_url)
return {
"id": self.id,
"author": self.author,
@ -73,4 +73,9 @@ class ToolProviderApiEntity(BaseModel):
"allow_delete": self.allow_delete,
"tools": tools,
"labels": self.labels,
**optional_fields,
}
def optional_field(self, key: str, value: Any) -> dict:
"""Return dict with key-value if value is truthy, empty dict otherwise."""
return {key: value} if value else {}

View File

@ -1,6 +1,7 @@
from collections.abc import Generator
from typing import Any, Optional
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.mcp.types import ImageContent, TextContent
from core.plugin.utils.converter import convert_parameters_to_plugin_format
@ -37,9 +38,14 @@ class MCPTool(Tool):
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
tool_parameters = convert_parameters_to_plugin_format(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
try:
with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
tool_parameters = convert_parameters_to_plugin_format(tool_parameters)
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
except MCPAuthError as e:
raise ValueError("Please auth the tool first")
except MCPConnectionError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
for content in result.content:
if isinstance(content, TextContent):
yield self.create_text_message(content.text)

View File

@ -1461,6 +1461,10 @@ class AppMCPServer(Base):
return result
@property
def parameters_dict(self) -> dict[str, Any]:
return json.loads(self.parameters)
class Site(Base):
__tablename__ = "sites"

View File

@ -1,5 +1,6 @@
import json
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.mcp_client import MCPClient
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
@ -58,8 +59,13 @@ class MCPToolManageService:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None:
raise ValueError("MCP tool not found")
with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
tools = mcp_client.list_tools()
try:
with MCPClient(mcp_provider.server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
tools = mcp_client.list_tools()
except MCPAuthError as e:
raise ValueError("Please auth the tool first")
except MCPConnectionError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
mcp_provider.authed = True
db.session.commit()