diff --git a/src/server/app.py b/src/server/app.py index 07dd266..f18c70b 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -283,6 +283,13 @@ async def generate_prose(request: GenerateProseRequest): async def mcp_server_metadata(request: MCPServerMetadataRequest): """Get information about an MCP server.""" try: + # Set default timeout with a longer value for this endpoint + timeout = 300 # Default to 300 seconds for this endpoint + + # Use custom timeout from request if provided + if request.timeout_seconds is not None: + timeout = request.timeout_seconds + # Load tools from the MCP server using the utility function tools = await load_mcp_tools( server_type=request.transport, @@ -290,6 +297,7 @@ async def mcp_server_metadata(request: MCPServerMetadataRequest): args=request.args, url=request.url, env=request.env, + timeout_seconds=timeout, ) # Create the response with tools diff --git a/src/server/mcp_request.py b/src/server/mcp_request.py index ee02acb..85a490b 100644 --- a/src/server/mcp_request.py +++ b/src/server/mcp_request.py @@ -22,6 +22,9 @@ class MCPServerMetadataRequest(BaseModel): None, description="The URL of the SSE server (for sse type)" ) env: Optional[Dict[str, str]] = Field(None, description="Environment variables") + timeout_seconds: Optional[int] = Field( + None, description="Optional custom timeout in seconds for the operation" + ) class MCPServerMetadataResponse(BaseModel): diff --git a/src/server/mcp_utils.py b/src/server/mcp_utils.py index 62a8243..c6f3b2c 100644 --- a/src/server/mcp_utils.py +++ b/src/server/mcp_utils.py @@ -13,12 +13,13 @@ from mcp.client.sse import sse_client logger = logging.getLogger(__name__) -async def _get_tools_from_client_session(client_context_manager: Any) -> List: +async def _get_tools_from_client_session(client_context_manager: Any, timeout_seconds: int = 10) -> List: """ Helper function to get tools from a client session. Args: client_context_manager: A context manager that returns (read, write) functions + timeout_seconds: Timeout in seconds for the read operation Returns: List of available tools from the MCP server @@ -28,7 +29,7 @@ async def _get_tools_from_client_session(client_context_manager: Any) -> List: """ async with client_context_manager as (read, write): async with ClientSession( - read, write, read_timeout_seconds=timedelta(seconds=10) + read, write, read_timeout_seconds=timedelta(seconds=timeout_seconds) ) as session: # Initialize the connection await session.initialize() @@ -43,6 +44,7 @@ async def load_mcp_tools( args: Optional[List[str]] = None, url: Optional[str] = None, env: Optional[Dict[str, str]] = None, + timeout_seconds: int = 60, # Longer default timeout for first-time executions ) -> List: """ Load tools from an MCP server. @@ -53,6 +55,7 @@ async def load_mcp_tools( args: Command arguments (for stdio type) url: The URL of the SSE server (for sse type) env: Environment variables + timeout_seconds: Timeout in seconds (default: 60 for first-time executions) Returns: List of available tools from the MCP server @@ -73,7 +76,7 @@ async def load_mcp_tools( env=env, # Optional environment variables ) - return await _get_tools_from_client_session(stdio_client(server_params)) + return await _get_tools_from_client_session(stdio_client(server_params), timeout_seconds) elif server_type == "sse": if not url: @@ -81,7 +84,7 @@ async def load_mcp_tools( status_code=400, detail="URL is required for sse type" ) - return await _get_tools_from_client_session(sse_client(url=url)) + return await _get_tools_from_client_session(sse_client(url=url), timeout_seconds) else: raise HTTPException(