From 0cf859b39354b9c5f3074a33b2058c674affe928 Mon Sep 17 00:00:00 2001 From: He Tao Date: Wed, 23 Apr 2025 16:00:01 +0800 Subject: [PATCH] feat: support mcp settings --- docs/mcp_integrations.md | 5 +- main.py | 15 +++--- pyproject.toml | 1 + src/config/configuration.py | 1 + src/graph/nodes.py | 95 +++++++++++++++++++++++++++++++++---- src/prompts/researcher.md | 4 +- src/server/app.py | 7 ++- src/server/chat_request.py | 3 ++ src/server/mcp_request.py | 4 +- src/workflow.py | 24 ++++++++-- uv.lock | 15 ++++++ 11 files changed, 147 insertions(+), 27 deletions(-) diff --git a/docs/mcp_integrations.md b/docs/mcp_integrations.md index 59213d6..5148108 100644 --- a/docs/mcp_integrations.md +++ b/docs/mcp_integrations.md @@ -9,7 +9,7 @@ For stdio type: ```json { - "type": "stdio", + "transport": "stdio", "command": "npx", "args": ["-y", "tavily-mcp@0.1.3"], "env": {"TAVILY_API_KEY": "tvly-dev-xxx"} @@ -19,7 +19,7 @@ For stdio type: For SSE type: ```json { - "type": "sse", + "transport": "sse", "url": "http://localhost:3000/sse", "env": { "API_KEY": "value" @@ -37,6 +37,7 @@ For SSE type: "mcp_settings": { "servers": { "mcp-github-trending": { + "transport": "stdio", "command": "uvx", "args": ["mcp-github-trending"], "env": { diff --git a/main.py b/main.py index ed1356a..dc00c48 100644 --- a/main.py +++ b/main.py @@ -6,9 +6,10 @@ Entry point script for the Deer project. """ import argparse +import asyncio from InquirerPy import inquirer -from src.workflow import run_agent_workflow +from src.workflow import run_agent_workflow_async from src.config.questions import BUILT_IN_QUESTIONS, BUILT_IN_QUESTIONS_ZH_CN @@ -21,11 +22,13 @@ def ask(question, debug=False, max_plan_iterations=1, max_step_num=3): max_plan_iterations: Maximum number of plan iterations max_step_num: Maximum number of steps in a plan """ - run_agent_workflow( - user_input=question, - debug=debug, - max_plan_iterations=max_plan_iterations, - max_step_num=max_step_num, + asyncio.run( + run_agent_workflow_async( + user_input=question, + debug=debug, + max_plan_iterations=max_plan_iterations, + max_step_num=max_step_num, + ) ) diff --git a/pyproject.toml b/pyproject.toml index 8ede212..08e32a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "inquirerpy>=0.3.4", "arxiv>=2.2.0", "mcp>=1.6.0", + "langchain-mcp-adapters>=0.0.9", ] [project.optional-dependencies] diff --git a/src/config/configuration.py b/src/config/configuration.py index 8df7a7f..dd20949 100644 --- a/src/config/configuration.py +++ b/src/config/configuration.py @@ -14,6 +14,7 @@ class Configuration: max_plan_iterations: int = 2 # Maximum number of plan iterations max_step_num: int = 5 # Maximum number of steps in a plan + mcp_settings: dict = None # MCP settings @classmethod def from_runnable_config( diff --git a/src/graph/nodes.py b/src/graph/nodes.py index c585b1e..12e2db6 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -9,8 +9,16 @@ from langchain_core.messages import AIMessage, HumanMessage from langchain_core.runnables import RunnableConfig from langchain_core.tools import tool from langgraph.types import Command, interrupt +from langchain_mcp_adapters.client import MultiServerMCPClient + +from src.agents.agents import coder_agent, research_agent, create_agent + +from src.tools import ( + crawl_tool, + web_search_tool, + python_repl_tool, +) -from src.agents.agents import coder_agent, research_agent from src.config.agents import AGENT_LLM_MAP from src.config.configuration import Configuration from src.llms.llm import get_llm_by_type @@ -171,7 +179,7 @@ def reporter_node(state: State): for observation in observations: invoke_messages.append( HumanMessage( - content=f"Below is some observations for the user query:\n\n{observation}", + content=f"Below are some observations for the research task:\n\n{observation}", name="observation", ) ) @@ -203,7 +211,7 @@ def research_team_node( return Command(goto="planner") -def _execute_agent_step( +async def _execute_agent_step( state: State, agent, agent_name: str ) -> Command[Literal["research_team"]]: """Helper function to execute a step using the specified agent.""" @@ -236,7 +244,7 @@ def _execute_agent_step( ) # Invoke the agent - result = agent.invoke(input=agent_input) + result = await agent.ainvoke(input=agent_input) # Process the result response_content = result["messages"][-1].content @@ -260,13 +268,84 @@ def _execute_agent_step( ) -def researcher_node(state: State) -> Command[Literal["research_team"]]: +async def _setup_and_execute_agent_step( + state: State, + config: RunnableConfig, + agent_type: str, + default_agent, + default_tools: list, +) -> Command[Literal["research_team"]]: + """Helper function to set up an agent with appropriate tools and execute a step. + + This function handles the common logic for both researcher_node and coder_node: + 1. Configures MCP servers and tools based on agent type + 2. Creates an agent with the appropriate tools or uses the default agent + 3. Executes the agent on the current step + + Args: + state: The current state + config: The runnable config + agent_type: The type of agent ("researcher" or "coder") + default_agent: The default agent to use if no MCP servers are configured + default_tools: The default tools to add to the agent + + Returns: + Command to update state and go to research_team + """ + configurable = Configuration.from_runnable_config(config) + mcp_servers = {} + enabled_tools = set() + + # Extract MCP server configuration for this agent type + if configurable.mcp_settings: + for server_name, server_config in configurable.mcp_settings["servers"].items(): + if ( + server_config["enabled_tools"] + and agent_type in server_config["add_to_agents"] + ): + mcp_servers[server_name] = { + k: v + for k, v in server_config.items() + if k in ("transport", "command", "args", "url", "env") + } + enabled_tools.update(server_config["enabled_tools"]) + + # Create and execute agent with MCP tools if available + if mcp_servers: + async with MultiServerMCPClient(mcp_servers) as client: + loaded_tools = [ + tool for tool in client.get_tools() if tool.name in enabled_tools + ] + default_tools + agent = create_agent(agent_type, agent_type, loaded_tools, agent_type) + return await _execute_agent_step(state, agent, agent_type) + else: + # Use default agent if no MCP servers are configured + return await _execute_agent_step(state, default_agent, agent_type) + + +async def researcher_node( + state: State, config: RunnableConfig +) -> Command[Literal["research_team"]]: """Researcher node that do research""" logger.info("Researcher node is researching.") - return _execute_agent_step(state, research_agent, "researcher") + return await _setup_and_execute_agent_step( + state, + config, + "researcher", + research_agent, + [web_search_tool, crawl_tool], + ) -def coder_node(state: State) -> Command[Literal["research_team"]]: +async def coder_node( + state: State, config: RunnableConfig +) -> Command[Literal["research_team"]]: """Coder node that do code analysis.""" logger.info("Coder node is coding.") - return _execute_agent_step(state, coder_agent, "coder") + return await _setup_and_execute_agent_step( + state, + config, + "coder", + coder_agent, + [python_repl_tool], + ) diff --git a/src/prompts/researcher.md b/src/prompts/researcher.md index 208a6cf..0c2be69 100644 --- a/src/prompts/researcher.md +++ b/src/prompts/researcher.md @@ -11,7 +11,7 @@ You are dedicated to conducting thorough investigations and providing comprehens 1. **Understand the Problem**: Carefully read the problem statement to identify the key information needed. 2. **Plan the Solution**: Determine the best approach to solve the problem using the available tools. 3. **Execute the Solution**: - - Use the **web_search_tool** to perform a search with the provided SEO keywords. + - Use the **web_search_tool** or other suitable tools to perform a search with the provided SEO keywords. - (Optional) Then use the **crawl_tool** to read markdown content from the necessary URLs. Only use the URLs from the search results or provided by the user. 4. **Synthesize Information**: - Combine the information gathered from the search results and the crawled content. @@ -24,7 +24,7 @@ You are dedicated to conducting thorough investigations and providing comprehens - Provide a structured response in markdown format. - Include the following sections: - **Problem Statement**: Restate the problem for clarity. - - **Search Results**: Summarize the key findings from the **web_search_tool** search. Track the sources of information but DO NOT include inline citations in the text. Include images if relevant. + - **Search Results**: Summarize the key findings from performed search. Track the sources of information but DO NOT include inline citations in the text. Include images if relevant. - **Crawled Content**: Summarize the key findings from the **crawl_tool**. Track the sources of information but DO NOT include inline citations in the text. Include images if relevant. - **Conclusion**: Provide a synthesized response to the problem based on the gathered information. - **References**: List all sources used with their complete URLs in link reference format at the end of the document. Make sure to include an empty line between each reference for better readability. Use this format for each reference: diff --git a/src/server/app.py b/src/server/app.py index 1482915..9b4b0bd 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -62,6 +62,7 @@ async def chat_stream(request: ChatRequest): request.max_step_num, request.auto_accepted_plan, request.interrupt_feedback, + request.mcp_settings, ), media_type="text/event-stream", ) @@ -74,6 +75,7 @@ async def _astream_workflow_generator( max_step_num: int, auto_accepted_plan: bool, interrupt_feedback: str, + mcp_settings: dict, ): input_ = { "messages": messages, @@ -95,6 +97,7 @@ async def _astream_workflow_generator( "thread_id": thread_id, "max_plan_iterations": max_plan_iterations, "max_step_num": max_step_num, + "mcp_settings": mcp_settings, }, stream_mode=["messages", "updates"], subgraphs=True, @@ -255,7 +258,7 @@ async def mcp_server_metadata(request: MCPServerMetadataRequest): try: # Load tools from the MCP server using the utility function tools = await load_mcp_tools( - server_type=request.type, + server_type=request.transport, command=request.command, args=request.args, url=request.url, @@ -264,7 +267,7 @@ async def mcp_server_metadata(request: MCPServerMetadataRequest): # Create the response with tools response = MCPServerMetadataResponse( - type=request.type, + transport=request.transport, command=request.command, args=request.args, url=request.url, diff --git a/src/server/chat_request.py b/src/server/chat_request.py index 970f52d..c8850f5 100644 --- a/src/server/chat_request.py +++ b/src/server/chat_request.py @@ -44,6 +44,9 @@ class ChatRequest(BaseModel): interrupt_feedback: Optional[str] = Field( None, description="Interrupt feedback from the user on the plan" ) + mcp_settings: Optional[dict] = Field( + None, description="MCP settings for the chat request" + ) class TTSRequest(BaseModel): diff --git a/src/server/mcp_request.py b/src/server/mcp_request.py index e82315f..ee02acb 100644 --- a/src/server/mcp_request.py +++ b/src/server/mcp_request.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, Field class MCPServerMetadataRequest(BaseModel): """Request model for MCP server metadata.""" - type: str = Field( + transport: str = Field( ..., description="The type of MCP server connection (stdio or sse)" ) command: Optional[str] = Field( @@ -27,7 +27,7 @@ class MCPServerMetadataRequest(BaseModel): class MCPServerMetadataResponse(BaseModel): """Response model for MCP server metadata.""" - type: str = Field( + transport: str = Field( ..., description="The type of MCP server connection (stdio or sse)" ) command: Optional[str] = Field( diff --git a/src/workflow.py b/src/workflow.py index 46437e4..ce88e65 100644 --- a/src/workflow.py +++ b/src/workflow.py @@ -1,6 +1,7 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT +import asyncio import logging from src.graph import build_graph @@ -22,13 +23,13 @@ logger = logging.getLogger(__name__) graph = build_graph() -def run_agent_workflow( +async def run_agent_workflow_async( user_input: str, debug: bool = False, max_plan_iterations: int = 1, max_step_num: int = 3, ): - """Run the agent workflow with the given user input. + """Run the agent workflow asynchronously with the given user input. Args: user_input: The user's query or request @@ -45,7 +46,7 @@ def run_agent_workflow( if debug: enable_debug_logging() - logger.info(f"Starting workflow with user input: {user_input}") + logger.info(f"Starting async workflow with user input: {user_input}") initial_state = { # Runtime Variables "messages": [{"role": "user", "content": user_input}], @@ -56,11 +57,24 @@ def run_agent_workflow( "thread_id": "default", "max_plan_iterations": max_plan_iterations, "max_step_num": max_step_num, + "mcp_settings": { + "servers": { + "mcp-github-trending": { + "transport": "stdio", + "command": "uvx", + "args": ["mcp-github-trending"], + "enabled_tools": ["get_github_trending_repositories"], + "add_to_agents": ["researcher"], + } + } + }, }, "recursion_limit": 100, } last_message_cnt = 0 - for s in graph.stream(input=initial_state, config=config, stream_mode="values"): + async for s in graph.astream( + input=initial_state, config=config, stream_mode="values" + ): try: if isinstance(s, dict) and "messages" in s: if len(s["messages"]) <= last_message_cnt: @@ -78,7 +92,7 @@ def run_agent_workflow( logger.error(f"Error processing stream output: {e}") print(f"Error processing output: {str(e)}") - logger.info("Workflow completed successfully") + logger.info("Async workflow completed successfully") if __name__ == "__main__": diff --git a/uv.lock b/uv.lock index cf247f5..c608abe 100644 --- a/uv.lock +++ b/uv.lock @@ -323,6 +323,7 @@ dependencies = [ { name = "json-repair" }, { name = "langchain-community" }, { name = "langchain-experimental" }, + { name = "langchain-mcp-adapters" }, { name = "langchain-openai" }, { name = "langgraph" }, { name = "litellm" }, @@ -359,6 +360,7 @@ requires-dist = [ { name = "json-repair", specifier = ">=0.7.0" }, { name = "langchain-community", specifier = ">=0.3.19" }, { name = "langchain-experimental", specifier = ">=0.3.4" }, + { name = "langchain-mcp-adapters", specifier = ">=0.0.9" }, { name = "langchain-openai", specifier = ">=0.3.8" }, { name = "langgraph", specifier = ">=0.3.5" }, { name = "litellm", specifier = ">=1.63.11" }, @@ -823,6 +825,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b2/27/fe8caa4884611286b1f7d6c5cfd76e1fef188faaa946db4fde6daa1cd2cd/langchain_experimental-0.3.4-py3-none-any.whl", hash = "sha256:2e587306aea36b60fa5e5fc05dc7281bee9f60a806f0bf9d30916e0ee096af80", size = 209154 }, ] +[[package]] +name = "langchain-mcp-adapters" +version = "0.0.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-core" }, + { name = "mcp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1a/48/dc5544f5b919b4ff9e736ec8db71217431c585c5c87acd3ab7558cc06cee/langchain_mcp_adapters-0.0.9.tar.gz", hash = "sha256:9ecd10fc420d98b3c14115bbca3174575e0a4ea29bd125ef39d11191a72ff1a1", size = 14827 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/24/3a4be149e8db15936533357f987b4b89c74c7f039427d6229679dbcc53b9/langchain_mcp_adapters-0.0.9-py3-none-any.whl", hash = "sha256:7c3dedd7830de826f418706c8a2fe388afcf8daf2037a1b39d1e065a5eacb082", size = 10065 }, +] + [[package]] name = "langchain-openai" version = "0.3.8"