feat: implement tools loading api

This commit is contained in:
He Tao 2025-04-23 14:38:04 +08:00
parent 8129370bbd
commit dae036f583
7 changed files with 265 additions and 71 deletions

View File

@ -10,7 +10,7 @@ lint:
uv run black --check .
serve:
uv run server.py
uv run server.py --reload
test:
uv run pytest tests/

View File

@ -11,10 +11,8 @@ For stdio type:
{
"type": "stdio",
"command": "npx",
"args": ["@agentdeskai/browser-tools-mcp@1.2.0"]
"env": {
"MCP_SERVER_ID": "mcp-github-trending"
}
"args": ["-y", "tavily-mcp@0.1.3"],
"env": {"TAVILY_API_KEY": "tvly-dev-xxx"}
}
```

View File

@ -30,6 +30,7 @@ dependencies = [
"duckduckgo-search>=8.0.0",
"inquirerpy>=0.3.4",
"arxiv>=2.2.0",
"mcp>=1.6.0",
]
[project.optional-dependencies]

View File

@ -13,6 +13,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from langchain_core.messages import AIMessageChunk, ToolMessage
from langgraph.types import Command
from mcp import ClientSession
from src.graph.builder import build_graph_with_memory
from src.podcast.graph.builder import build_graph as build_podcast_graph
@ -24,6 +25,8 @@ from src.server.chat_request import (
GeneratePPTRequest,
TTSRequest,
)
from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse
from src.server.mcp_utils import load_mcp_tools
from src.tools import VolcengineTTS
logger = logging.getLogger(__name__)
@ -244,3 +247,34 @@ async def generate_ppt(request: GeneratePPTRequest):
except Exception as e:
logger.exception(f"Error occurred during ppt generation: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/mcp/server/metadata", response_model=MCPServerMetadataResponse)
async def mcp_server_metadata(request: MCPServerMetadataRequest):
"""Get information about an MCP server."""
try:
# Load tools from the MCP server using the utility function
tools = await load_mcp_tools(
server_type=request.type,
command=request.command,
args=request.args,
url=request.url,
env=request.env,
)
# Create the response with tools
response = MCPServerMetadataResponse(
type=request.type,
command=request.command,
args=request.args,
url=request.url,
env=request.env,
tools=tools,
)
return response
except Exception as e:
if not isinstance(e, HTTPException):
logger.exception(f"Error in MCP server metadata endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
raise

45
src/server/mcp_request.py Normal file
View File

@ -0,0 +1,45 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
class MCPServerMetadataRequest(BaseModel):
"""Request model for MCP server metadata."""
type: str = Field(
..., description="The type of MCP server connection (stdio or sse)"
)
command: Optional[str] = Field(
None, description="The command to execute (for stdio type)"
)
args: Optional[List[str]] = Field(
None, description="Command arguments (for stdio type)"
)
url: Optional[str] = Field(
None, description="The URL of the SSE server (for sse type)"
)
env: Optional[Dict[str, str]] = Field(None, description="Environment variables")
class MCPServerMetadataResponse(BaseModel):
"""Response model for MCP server metadata."""
type: str = Field(
..., description="The type of MCP server connection (stdio or sse)"
)
command: Optional[str] = Field(
None, description="The command to execute (for stdio type)"
)
args: Optional[List[str]] = Field(
None, description="Command arguments (for stdio type)"
)
url: Optional[str] = Field(
None, description="The URL of the SSE server (for sse type)"
)
env: Optional[Dict[str, str]] = Field(None, description="Environment variables")
tools: List = Field(
default_factory=list, description="Available tools from the MCP server"
)

95
src/server/mcp_utils.py Normal file
View File

@ -0,0 +1,95 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import logging
from datetime import timedelta
from typing import Any, Dict, List, Optional, Tuple
from fastapi import HTTPException
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
logger = logging.getLogger(__name__)
async def _get_tools_from_client_session(client_context_manager: Any) -> List:
"""
Helper function to get tools from a client session.
Args:
client_context_manager: A context manager that returns (read, write) functions
Returns:
List of available tools from the MCP server
Raises:
Exception: If there's an error during the process
"""
async with client_context_manager as (read, write):
async with ClientSession(
read, write, read_timeout_seconds=timedelta(seconds=10)
) as session:
# Initialize the connection
await session.initialize()
# List available tools
listed_tools = await session.list_tools()
return listed_tools.tools
async def load_mcp_tools(
server_type: str,
command: Optional[str] = None,
args: Optional[List[str]] = None,
url: Optional[str] = None,
env: Optional[Dict[str, str]] = None,
) -> List:
"""
Load tools from an MCP server.
Args:
server_type: The type of MCP server connection (stdio or sse)
command: The command to execute (for stdio type)
args: Command arguments (for stdio type)
url: The URL of the SSE server (for sse type)
env: Environment variables
Returns:
List of available tools from the MCP server
Raises:
HTTPException: If there's an error loading the tools
"""
try:
if server_type == "stdio":
if not command:
raise HTTPException(
status_code=400, detail="Command is required for stdio type"
)
server_params = StdioServerParameters(
command=command, # Executable
args=args, # Optional command line arguments
env=env, # Optional environment variables
)
return await _get_tools_from_client_session(stdio_client(server_params))
elif server_type == "sse":
if not url:
raise HTTPException(
status_code=400, detail="URL is required for sse type"
)
return await _get_tools_from_client_session(sse_client(url=url))
else:
raise HTTPException(
status_code=400, detail=f"Unsupported server type: {server_type}"
)
except Exception as e:
if not isinstance(e, HTTPException):
logger.exception(f"Error loading MCP tools: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
raise

153
uv.lock generated
View File

@ -309,6 +309,74 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686 },
]
[[package]]
name = "deer-flow"
version = "0.1.0"
source = { editable = "." }
dependencies = [
{ name = "arxiv" },
{ name = "duckduckgo-search" },
{ name = "fastapi" },
{ name = "httpx" },
{ name = "inquirerpy" },
{ name = "jinja2" },
{ name = "json-repair" },
{ name = "langchain-community" },
{ name = "langchain-experimental" },
{ name = "langchain-openai" },
{ name = "langgraph" },
{ name = "litellm" },
{ name = "markdownify" },
{ name = "mcp" },
{ name = "numpy" },
{ name = "pandas" },
{ name = "python-dotenv" },
{ name = "readabilipy" },
{ name = "socksio" },
{ name = "sse-starlette" },
{ name = "uvicorn" },
{ name = "yfinance" },
]
[package.optional-dependencies]
dev = [
{ name = "black" },
]
test = [
{ name = "pytest" },
{ name = "pytest-cov" },
]
[package.metadata]
requires-dist = [
{ name = "arxiv", specifier = ">=2.2.0" },
{ name = "black", marker = "extra == 'dev'", specifier = ">=24.2.0" },
{ name = "duckduckgo-search", specifier = ">=8.0.0" },
{ name = "fastapi", specifier = ">=0.110.0" },
{ name = "httpx", specifier = ">=0.28.1" },
{ name = "inquirerpy", specifier = ">=0.3.4" },
{ name = "jinja2", specifier = ">=3.1.3" },
{ name = "json-repair", specifier = ">=0.7.0" },
{ name = "langchain-community", specifier = ">=0.3.19" },
{ name = "langchain-experimental", specifier = ">=0.3.4" },
{ name = "langchain-openai", specifier = ">=0.3.8" },
{ name = "langgraph", specifier = ">=0.3.5" },
{ name = "litellm", specifier = ">=1.63.11" },
{ name = "markdownify", specifier = ">=1.1.0" },
{ name = "mcp", specifier = ">=1.6.0" },
{ name = "numpy", specifier = ">=2.2.3" },
{ name = "pandas", specifier = ">=2.2.3" },
{ name = "pytest", marker = "extra == 'test'", specifier = ">=7.4.0" },
{ name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.1.0" },
{ name = "python-dotenv", specifier = ">=1.0.1" },
{ name = "readabilipy", specifier = ">=0.3.0" },
{ name = "socksio", specifier = ">=1.0.0" },
{ name = "sse-starlette", specifier = ">=1.6.5" },
{ name = "uvicorn", specifier = ">=0.27.1" },
{ name = "yfinance", specifier = ">=0.2.54" },
]
provides-extras = ["dev", "test"]
[[package]]
name = "distro"
version = "1.9.0"
@ -853,72 +921,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e0/09/3f909694aa0b104a611444959227832206864d92703e191a0f4b2a27d55b/langsmith-0.3.13-py3-none-any.whl", hash = "sha256:73aaf52bbc293b9415fff4f6dad68df40658081eb26c9cb2c7bd1ff57cedd695", size = 339683 },
]
[[package]]
name = "deer-flow"
version = "0.1.0"
source = { editable = "." }
dependencies = [
{ name = "arxiv" },
{ name = "duckduckgo-search" },
{ name = "fastapi" },
{ name = "httpx" },
{ name = "inquirerpy" },
{ name = "jinja2" },
{ name = "json-repair" },
{ name = "langchain-community" },
{ name = "langchain-experimental" },
{ name = "langchain-openai" },
{ name = "langgraph" },
{ name = "litellm" },
{ name = "markdownify" },
{ name = "numpy" },
{ name = "pandas" },
{ name = "python-dotenv" },
{ name = "readabilipy" },
{ name = "socksio" },
{ name = "sse-starlette" },
{ name = "uvicorn" },
{ name = "yfinance" },
]
[package.optional-dependencies]
dev = [
{ name = "black" },
]
test = [
{ name = "pytest" },
{ name = "pytest-cov" },
]
[package.metadata]
requires-dist = [
{ name = "arxiv", specifier = ">=2.2.0" },
{ name = "black", marker = "extra == 'dev'", specifier = ">=24.2.0" },
{ name = "duckduckgo-search", specifier = ">=8.0.0" },
{ name = "fastapi", specifier = ">=0.110.0" },
{ name = "httpx", specifier = ">=0.28.1" },
{ name = "inquirerpy", specifier = ">=0.3.4" },
{ name = "jinja2", specifier = ">=3.1.3" },
{ name = "json-repair", specifier = ">=0.7.0" },
{ name = "langchain-community", specifier = ">=0.3.19" },
{ name = "langchain-experimental", specifier = ">=0.3.4" },
{ name = "langchain-openai", specifier = ">=0.3.8" },
{ name = "langgraph", specifier = ">=0.3.5" },
{ name = "litellm", specifier = ">=1.63.11" },
{ name = "markdownify", specifier = ">=1.1.0" },
{ name = "numpy", specifier = ">=2.2.3" },
{ name = "pandas", specifier = ">=2.2.3" },
{ name = "pytest", marker = "extra == 'test'", specifier = ">=7.4.0" },
{ name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.1.0" },
{ name = "python-dotenv", specifier = ">=1.0.1" },
{ name = "readabilipy", specifier = ">=0.3.0" },
{ name = "socksio", specifier = ">=1.0.0" },
{ name = "sse-starlette", specifier = ">=1.6.5" },
{ name = "uvicorn", specifier = ">=0.27.1" },
{ name = "yfinance", specifier = ">=0.2.54" },
]
provides-extras = ["dev", "test"]
[[package]]
name = "litellm"
version = "1.63.11"
@ -1046,6 +1048,25 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/34/75/51952c7b2d3873b44a0028b1bd26a25078c18f92f256608e8d1dc61b39fd/marshmallow-3.26.1-py3-none-any.whl", hash = "sha256:3350409f20a70a7e4e11a27661187b77cdcaeb20abca41c1454fe33636bea09c", size = 50878 },
]
[[package]]
name = "mcp"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "httpx" },
{ name = "httpx-sse" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "sse-starlette" },
{ name = "starlette" },
{ name = "uvicorn" },
]
sdist = { url = "https://files.pythonhosted.org/packages/95/d2/f587cb965a56e992634bebc8611c5b579af912b74e04eb9164bd49527d21/mcp-1.6.0.tar.gz", hash = "sha256:d9324876de2c5637369f43161cd71eebfd803df5a95e46225cab8d280e366723", size = 200031 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/10/30/20a7f33b0b884a9d14dd3aa94ff1ac9da1479fe2ad66dd9e2736075d2506/mcp-1.6.0-py3-none-any.whl", hash = "sha256:7bd24c6ea042dbec44c754f100984d186620d8b841ec30f1b19eda9b93a634d0", size = 76077 },
]
[[package]]
name = "msgpack"
version = "1.1.0"