mirror of
https://git.mirrors.martin98.com/https://github.com/bytedance/deer-flow
synced 2025-08-20 00:49:05 +08:00
feat: implement tools loading api
This commit is contained in:
parent
8129370bbd
commit
dae036f583
2
Makefile
2
Makefile
@ -10,7 +10,7 @@ lint:
|
||||
uv run black --check .
|
||||
|
||||
serve:
|
||||
uv run server.py
|
||||
uv run server.py --reload
|
||||
|
||||
test:
|
||||
uv run pytest tests/
|
||||
|
@ -11,10 +11,8 @@ For stdio type:
|
||||
{
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["@agentdeskai/browser-tools-mcp@1.2.0"]
|
||||
"env": {
|
||||
"MCP_SERVER_ID": "mcp-github-trending"
|
||||
}
|
||||
"args": ["-y", "tavily-mcp@0.1.3"],
|
||||
"env": {"TAVILY_API_KEY": "tvly-dev-xxx"}
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -30,6 +30,7 @@ dependencies = [
|
||||
"duckduckgo-search>=8.0.0",
|
||||
"inquirerpy>=0.3.4",
|
||||
"arxiv>=2.2.0",
|
||||
"mcp>=1.6.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
@ -13,6 +13,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from langchain_core.messages import AIMessageChunk, ToolMessage
|
||||
from langgraph.types import Command
|
||||
from mcp import ClientSession
|
||||
|
||||
from src.graph.builder import build_graph_with_memory
|
||||
from src.podcast.graph.builder import build_graph as build_podcast_graph
|
||||
@ -24,6 +25,8 @@ from src.server.chat_request import (
|
||||
GeneratePPTRequest,
|
||||
TTSRequest,
|
||||
)
|
||||
from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse
|
||||
from src.server.mcp_utils import load_mcp_tools
|
||||
from src.tools import VolcengineTTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -244,3 +247,34 @@ async def generate_ppt(request: GeneratePPTRequest):
|
||||
except Exception as e:
|
||||
logger.exception(f"Error occurred during ppt generation: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/api/mcp/server/metadata", response_model=MCPServerMetadataResponse)
|
||||
async def mcp_server_metadata(request: MCPServerMetadataRequest):
|
||||
"""Get information about an MCP server."""
|
||||
try:
|
||||
# Load tools from the MCP server using the utility function
|
||||
tools = await load_mcp_tools(
|
||||
server_type=request.type,
|
||||
command=request.command,
|
||||
args=request.args,
|
||||
url=request.url,
|
||||
env=request.env,
|
||||
)
|
||||
|
||||
# Create the response with tools
|
||||
response = MCPServerMetadataResponse(
|
||||
type=request.type,
|
||||
command=request.command,
|
||||
args=request.args,
|
||||
url=request.url,
|
||||
env=request.env,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
if not isinstance(e, HTTPException):
|
||||
logger.exception(f"Error in MCP server metadata endpoint: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise
|
||||
|
45
src/server/mcp_request.py
Normal file
45
src/server/mcp_request.py
Normal file
@ -0,0 +1,45 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MCPServerMetadataRequest(BaseModel):
|
||||
"""Request model for MCP server metadata."""
|
||||
|
||||
type: str = Field(
|
||||
..., description="The type of MCP server connection (stdio or sse)"
|
||||
)
|
||||
command: Optional[str] = Field(
|
||||
None, description="The command to execute (for stdio type)"
|
||||
)
|
||||
args: Optional[List[str]] = Field(
|
||||
None, description="Command arguments (for stdio type)"
|
||||
)
|
||||
url: Optional[str] = Field(
|
||||
None, description="The URL of the SSE server (for sse type)"
|
||||
)
|
||||
env: Optional[Dict[str, str]] = Field(None, description="Environment variables")
|
||||
|
||||
|
||||
class MCPServerMetadataResponse(BaseModel):
|
||||
"""Response model for MCP server metadata."""
|
||||
|
||||
type: str = Field(
|
||||
..., description="The type of MCP server connection (stdio or sse)"
|
||||
)
|
||||
command: Optional[str] = Field(
|
||||
None, description="The command to execute (for stdio type)"
|
||||
)
|
||||
args: Optional[List[str]] = Field(
|
||||
None, description="Command arguments (for stdio type)"
|
||||
)
|
||||
url: Optional[str] = Field(
|
||||
None, description="The URL of the SSE server (for sse type)"
|
||||
)
|
||||
env: Optional[Dict[str, str]] = Field(None, description="Environment variables")
|
||||
tools: List = Field(
|
||||
default_factory=list, description="Available tools from the MCP server"
|
||||
)
|
95
src/server/mcp_utils.py
Normal file
95
src/server/mcp_utils.py
Normal file
@ -0,0 +1,95 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_tools_from_client_session(client_context_manager: Any) -> List:
|
||||
"""
|
||||
Helper function to get tools from a client session.
|
||||
|
||||
Args:
|
||||
client_context_manager: A context manager that returns (read, write) functions
|
||||
|
||||
Returns:
|
||||
List of available tools from the MCP server
|
||||
|
||||
Raises:
|
||||
Exception: If there's an error during the process
|
||||
"""
|
||||
async with client_context_manager as (read, write):
|
||||
async with ClientSession(
|
||||
read, write, read_timeout_seconds=timedelta(seconds=10)
|
||||
) as session:
|
||||
# Initialize the connection
|
||||
await session.initialize()
|
||||
# List available tools
|
||||
listed_tools = await session.list_tools()
|
||||
return listed_tools.tools
|
||||
|
||||
|
||||
async def load_mcp_tools(
|
||||
server_type: str,
|
||||
command: Optional[str] = None,
|
||||
args: Optional[List[str]] = None,
|
||||
url: Optional[str] = None,
|
||||
env: Optional[Dict[str, str]] = None,
|
||||
) -> List:
|
||||
"""
|
||||
Load tools from an MCP server.
|
||||
|
||||
Args:
|
||||
server_type: The type of MCP server connection (stdio or sse)
|
||||
command: The command to execute (for stdio type)
|
||||
args: Command arguments (for stdio type)
|
||||
url: The URL of the SSE server (for sse type)
|
||||
env: Environment variables
|
||||
|
||||
Returns:
|
||||
List of available tools from the MCP server
|
||||
|
||||
Raises:
|
||||
HTTPException: If there's an error loading the tools
|
||||
"""
|
||||
try:
|
||||
if server_type == "stdio":
|
||||
if not command:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Command is required for stdio type"
|
||||
)
|
||||
|
||||
server_params = StdioServerParameters(
|
||||
command=command, # Executable
|
||||
args=args, # Optional command line arguments
|
||||
env=env, # Optional environment variables
|
||||
)
|
||||
|
||||
return await _get_tools_from_client_session(stdio_client(server_params))
|
||||
|
||||
elif server_type == "sse":
|
||||
if not url:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="URL is required for sse type"
|
||||
)
|
||||
|
||||
return await _get_tools_from_client_session(sse_client(url=url))
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Unsupported server type: {server_type}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if not isinstance(e, HTTPException):
|
||||
logger.exception(f"Error loading MCP tools: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise
|
153
uv.lock
generated
153
uv.lock
generated
@ -309,6 +309,74 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deer-flow"
|
||||
version = "0.1.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "arxiv" },
|
||||
{ name = "duckduckgo-search" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "httpx" },
|
||||
{ name = "inquirerpy" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "json-repair" },
|
||||
{ name = "langchain-community" },
|
||||
{ name = "langchain-experimental" },
|
||||
{ name = "langchain-openai" },
|
||||
{ name = "langgraph" },
|
||||
{ name = "litellm" },
|
||||
{ name = "markdownify" },
|
||||
{ name = "mcp" },
|
||||
{ name = "numpy" },
|
||||
{ name = "pandas" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "readabilipy" },
|
||||
{ name = "socksio" },
|
||||
{ name = "sse-starlette" },
|
||||
{ name = "uvicorn" },
|
||||
{ name = "yfinance" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
dev = [
|
||||
{ name = "black" },
|
||||
]
|
||||
test = [
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-cov" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "arxiv", specifier = ">=2.2.0" },
|
||||
{ name = "black", marker = "extra == 'dev'", specifier = ">=24.2.0" },
|
||||
{ name = "duckduckgo-search", specifier = ">=8.0.0" },
|
||||
{ name = "fastapi", specifier = ">=0.110.0" },
|
||||
{ name = "httpx", specifier = ">=0.28.1" },
|
||||
{ name = "inquirerpy", specifier = ">=0.3.4" },
|
||||
{ name = "jinja2", specifier = ">=3.1.3" },
|
||||
{ name = "json-repair", specifier = ">=0.7.0" },
|
||||
{ name = "langchain-community", specifier = ">=0.3.19" },
|
||||
{ name = "langchain-experimental", specifier = ">=0.3.4" },
|
||||
{ name = "langchain-openai", specifier = ">=0.3.8" },
|
||||
{ name = "langgraph", specifier = ">=0.3.5" },
|
||||
{ name = "litellm", specifier = ">=1.63.11" },
|
||||
{ name = "markdownify", specifier = ">=1.1.0" },
|
||||
{ name = "mcp", specifier = ">=1.6.0" },
|
||||
{ name = "numpy", specifier = ">=2.2.3" },
|
||||
{ name = "pandas", specifier = ">=2.2.3" },
|
||||
{ name = "pytest", marker = "extra == 'test'", specifier = ">=7.4.0" },
|
||||
{ name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.1.0" },
|
||||
{ name = "python-dotenv", specifier = ">=1.0.1" },
|
||||
{ name = "readabilipy", specifier = ">=0.3.0" },
|
||||
{ name = "socksio", specifier = ">=1.0.0" },
|
||||
{ name = "sse-starlette", specifier = ">=1.6.5" },
|
||||
{ name = "uvicorn", specifier = ">=0.27.1" },
|
||||
{ name = "yfinance", specifier = ">=0.2.54" },
|
||||
]
|
||||
provides-extras = ["dev", "test"]
|
||||
|
||||
[[package]]
|
||||
name = "distro"
|
||||
version = "1.9.0"
|
||||
@ -853,72 +921,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e0/09/3f909694aa0b104a611444959227832206864d92703e191a0f4b2a27d55b/langsmith-0.3.13-py3-none-any.whl", hash = "sha256:73aaf52bbc293b9415fff4f6dad68df40658081eb26c9cb2c7bd1ff57cedd695", size = 339683 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deer-flow"
|
||||
version = "0.1.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "arxiv" },
|
||||
{ name = "duckduckgo-search" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "httpx" },
|
||||
{ name = "inquirerpy" },
|
||||
{ name = "jinja2" },
|
||||
{ name = "json-repair" },
|
||||
{ name = "langchain-community" },
|
||||
{ name = "langchain-experimental" },
|
||||
{ name = "langchain-openai" },
|
||||
{ name = "langgraph" },
|
||||
{ name = "litellm" },
|
||||
{ name = "markdownify" },
|
||||
{ name = "numpy" },
|
||||
{ name = "pandas" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "readabilipy" },
|
||||
{ name = "socksio" },
|
||||
{ name = "sse-starlette" },
|
||||
{ name = "uvicorn" },
|
||||
{ name = "yfinance" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
dev = [
|
||||
{ name = "black" },
|
||||
]
|
||||
test = [
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-cov" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "arxiv", specifier = ">=2.2.0" },
|
||||
{ name = "black", marker = "extra == 'dev'", specifier = ">=24.2.0" },
|
||||
{ name = "duckduckgo-search", specifier = ">=8.0.0" },
|
||||
{ name = "fastapi", specifier = ">=0.110.0" },
|
||||
{ name = "httpx", specifier = ">=0.28.1" },
|
||||
{ name = "inquirerpy", specifier = ">=0.3.4" },
|
||||
{ name = "jinja2", specifier = ">=3.1.3" },
|
||||
{ name = "json-repair", specifier = ">=0.7.0" },
|
||||
{ name = "langchain-community", specifier = ">=0.3.19" },
|
||||
{ name = "langchain-experimental", specifier = ">=0.3.4" },
|
||||
{ name = "langchain-openai", specifier = ">=0.3.8" },
|
||||
{ name = "langgraph", specifier = ">=0.3.5" },
|
||||
{ name = "litellm", specifier = ">=1.63.11" },
|
||||
{ name = "markdownify", specifier = ">=1.1.0" },
|
||||
{ name = "numpy", specifier = ">=2.2.3" },
|
||||
{ name = "pandas", specifier = ">=2.2.3" },
|
||||
{ name = "pytest", marker = "extra == 'test'", specifier = ">=7.4.0" },
|
||||
{ name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.1.0" },
|
||||
{ name = "python-dotenv", specifier = ">=1.0.1" },
|
||||
{ name = "readabilipy", specifier = ">=0.3.0" },
|
||||
{ name = "socksio", specifier = ">=1.0.0" },
|
||||
{ name = "sse-starlette", specifier = ">=1.6.5" },
|
||||
{ name = "uvicorn", specifier = ">=0.27.1" },
|
||||
{ name = "yfinance", specifier = ">=0.2.54" },
|
||||
]
|
||||
provides-extras = ["dev", "test"]
|
||||
|
||||
[[package]]
|
||||
name = "litellm"
|
||||
version = "1.63.11"
|
||||
@ -1046,6 +1048,25 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/34/75/51952c7b2d3873b44a0028b1bd26a25078c18f92f256608e8d1dc61b39fd/marshmallow-3.26.1-py3-none-any.whl", hash = "sha256:3350409f20a70a7e4e11a27661187b77cdcaeb20abca41c1454fe33636bea09c", size = 50878 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mcp"
|
||||
version = "1.6.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
{ name = "httpx" },
|
||||
{ name = "httpx-sse" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "sse-starlette" },
|
||||
{ name = "starlette" },
|
||||
{ name = "uvicorn" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/95/d2/f587cb965a56e992634bebc8611c5b579af912b74e04eb9164bd49527d21/mcp-1.6.0.tar.gz", hash = "sha256:d9324876de2c5637369f43161cd71eebfd803df5a95e46225cab8d280e366723", size = 200031 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/10/30/20a7f33b0b884a9d14dd3aa94ff1ac9da1479fe2ad66dd9e2736075d2506/mcp-1.6.0-py3-none-any.whl", hash = "sha256:7bd24c6ea042dbec44c754f100984d186620d8b841ec30f1b19eda9b93a634d0", size = 76077 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "msgpack"
|
||||
version = "1.1.0"
|
||||
|
Loading…
x
Reference in New Issue
Block a user