From 8bbcdbe4de85e18dd93b5d7355c594976bf6f6a8 Mon Sep 17 00:00:00 2001
From: DanielWalnut <45447813+hetaoBackend@users.noreply.github.com>
Date: Sat, 17 May 2025 22:23:52 -0700
Subject: [PATCH] feat: config max_search_results for search engine (#192)

* feat: implement UI

* feat: config max_search_results for search engine via api

---------

Co-authored-by: Henry Li
---
 src/agents/__init__.py                    |  4 +-
 src/agents/agents.py                      | 13 -----
 src/config/__init__.py                    |  3 +-
 src/config/configuration.py               |  1 +
 src/config/tools.py                       |  1 -
 src/graph/nodes.py                        | 34 +++++------
 src/server/app.py                         |  3 +
 src/server/chat_request.py                |  3 +
 src/tools/__init__.py                     | 20 +------
 src/tools/search.py                       | 70 ++++++++++++----------
 web/src/app/settings/tabs/general-tab.tsx | 27 +++++++++
 web/src/core/api/chat.ts                  |  4 ++
 web/src/core/store/settings-store.ts      |  2 +
 web/src/core/store/store.ts               |  1 +
 14 files changed, 101 insertions(+), 85 deletions(-)

diff --git a/src/agents/__init__.py b/src/agents/__init__.py
index c235a05..76ce56c 100644
--- a/src/agents/__init__.py
+++ b/src/agents/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
 # SPDX-License-Identifier: MIT
 
-from .agents import research_agent, coder_agent
+from .agents import create_agent
 
-__all__ = ["research_agent", "coder_agent"]
+__all__ = ["create_agent"]
diff --git a/src/agents/agents.py b/src/agents/agents.py
index 2a4d330..e8fb3a8 100644
--- a/src/agents/agents.py
+++ b/src/agents/agents.py
@@ -4,12 +4,6 @@
 from langgraph.prebuilt import create_react_agent
 
 from src.prompts import apply_prompt_template
-from src.tools import (
-    crawl_tool,
-    python_repl_tool,
-    web_search_tool,
-)
-
 from src.llms.llm import get_llm_by_type
 from src.config.agents import AGENT_LLM_MAP
 
@@ -23,10 +17,3 @@ def create_agent(agent_name: str, agent_type: str, tools: list, prompt_template:
         tools=tools,
         prompt=lambda state: apply_prompt_template(prompt_template, state),
     )
-
-
-# Create agents using the factory function
-research_agent = create_agent(
-    "researcher", "researcher", [web_search_tool, crawl_tool], "researcher"
-)
-coder_agent = create_agent("coder", "coder", [python_repl_tool], "coder")
diff --git a/src/config/__init__.py b/src/config/__init__.py
index c4639c1..4e6d178 100644
--- a/src/config/__init__.py
+++ b/src/config/__init__.py
@@ -1,7 +1,7 @@
 # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
 # SPDX-License-Identifier: MIT
 
-from .tools import SEARCH_MAX_RESULTS, SELECTED_SEARCH_ENGINE, SearchEngine
+from .tools import SELECTED_SEARCH_ENGINE, SearchEngine
 from .loader import load_yaml_config
 from .questions import BUILT_IN_QUESTIONS, BUILT_IN_QUESTIONS_ZH_CN
 
@@ -42,7 +42,6 @@ __all__ = [
     # Other configurations
     "TEAM_MEMBERS",
     "TEAM_MEMBER_CONFIGRATIONS",
-    "SEARCH_MAX_RESULTS",
     "SELECTED_SEARCH_ENGINE",
     "SearchEngine",
     "BUILT_IN_QUESTIONS",
diff --git a/src/config/configuration.py b/src/config/configuration.py
index 24cb53a..42e4af4 100644
--- a/src/config/configuration.py
+++ b/src/config/configuration.py
@@ -14,6 +14,7 @@ class Configuration:
 
     max_plan_iterations: int = 1  # Maximum number of plan iterations
     max_step_num: int = 3  # Maximum number of steps in a plan
+    max_search_results: int = 3  # Maximum number of search results
     mcp_settings: dict = None  # MCP settings, including dynamic loaded tools
 
     @classmethod
diff --git a/src/config/tools.py b/src/config/tools.py
index 37ee9d6..941de2b 100644
--- a/src/config/tools.py
+++ b/src/config/tools.py
@@ -17,4 +17,3 @@ class SearchEngine(enum.Enum):
 
 # Tool configuration
 SELECTED_SEARCH_ENGINE = os.getenv("SEARCH_API", SearchEngine.TAVILY.value)
-SEARCH_MAX_RESULTS = 3
diff --git a/src/graph/nodes.py b/src/graph/nodes.py
index 64d3536..892faef 100644
--- a/src/graph/nodes.py
+++ b/src/graph/nodes.py
@@ -12,12 +12,11 @@
 from langchain_core.tools import tool
 from langgraph.types import Command, interrupt
 from langchain_mcp_adapters.client import MultiServerMCPClient
 
-from src.agents.agents import coder_agent, research_agent, create_agent
-
+from src.agents import create_agent
 from src.tools.search import LoggedTavilySearch
 from src.tools import (
     crawl_tool,
-    web_search_tool,
+    get_web_search_tool,
     python_repl_tool,
 )
@@ -29,7 +28,7 @@
 from src.prompts.template import apply_prompt_template
 from src.utils.json_utils import repair_json_output
 from .types import State
-from ..config import SEARCH_MAX_RESULTS, SELECTED_SEARCH_ENGINE, SearchEngine
+from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
 
 logger = logging.getLogger(__name__)
 
@@ -45,13 +44,16 @@ def handoff_to_planner(
     return
 
 
-def background_investigation_node(state: State) -> Command[Literal["planner"]]:
+def background_investigation_node(
+    state: State, config: RunnableConfig
+) -> Command[Literal["planner"]]:
     logger.info("background investigation node is running.")
+    configurable = Configuration.from_runnable_config(config)
     query = state["messages"][-1].content
     if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY:
-        searched_content = LoggedTavilySearch(max_results=SEARCH_MAX_RESULTS).invoke(
-            {"query": query}
-        )
+        searched_content = LoggedTavilySearch(
+            max_results=configurable.max_search_results
+        ).invoke({"query": query})
         background_investigation_results = None
         if isinstance(searched_content, list):
             background_investigation_results = [
@@ -63,7 +65,9 @@ def background_investigation_node(state: State) -> Command[Literal["planner"]]:
                 f"Tavily search returned malformed response: {searched_content}"
             )
     else:
-        background_investigation_results = web_search_tool.invoke(query)
+        background_investigation_results = get_web_search_tool(
+            configurable.max_search_results
+        ).invoke(query)
     return Command(
         update={
             "background_investigation_results": json.dumps(
@@ -403,7 +407,6 @@ async def _setup_and_execute_agent_step(
     state: State,
     config: RunnableConfig,
     agent_type: str,
-    default_agent,
     default_tools: list,
 ) -> Command[Literal["research_team"]]:
"""Helper function to set up an agent with appropriate tools and execute a step. @@ -417,7 +420,6 @@ async def _setup_and_execute_agent_step( state: The current state config: The runnable config agent_type: The type of agent ("researcher" or "coder") - default_agent: The default agent to use if no MCP servers are configured default_tools: The default tools to add to the agent Returns: @@ -455,8 +457,9 @@ async def _setup_and_execute_agent_step( agent = create_agent(agent_type, agent_type, loaded_tools, agent_type) return await _execute_agent_step(state, agent, agent_type) else: - # Use default agent if no MCP servers are configured - return await _execute_agent_step(state, default_agent, agent_type) + # Use default tools if no MCP servers are configured + agent = create_agent(agent_type, agent_type, default_tools, agent_type) + return await _execute_agent_step(state, agent, agent_type) async def researcher_node( @@ -464,12 +467,12 @@ async def researcher_node( ) -> Command[Literal["research_team"]]: """Researcher node that do research""" logger.info("Researcher node is researching.") + configurable = Configuration.from_runnable_config(config) return await _setup_and_execute_agent_step( state, config, "researcher", - research_agent, - [web_search_tool, crawl_tool], + [get_web_search_tool(configurable.max_search_results), crawl_tool], ) @@ -482,6 +485,5 @@ async def coder_node( state, config, "coder", - coder_agent, [python_repl_tool], ) diff --git a/src/server/app.py b/src/server/app.py index 9b91685..0937a2f 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -61,6 +61,7 @@ async def chat_stream(request: ChatRequest): thread_id, request.max_plan_iterations, request.max_step_num, + request.max_search_results, request.auto_accepted_plan, request.interrupt_feedback, request.mcp_settings, @@ -75,6 +76,7 @@ async def _astream_workflow_generator( thread_id: str, max_plan_iterations: int, max_step_num: int, + max_search_results: int, auto_accepted_plan: bool, interrupt_feedback: str, mcp_settings: dict, @@ -101,6 +103,7 @@ async def _astream_workflow_generator( "thread_id": thread_id, "max_plan_iterations": max_plan_iterations, "max_step_num": max_step_num, + "max_search_results": max_search_results, "mcp_settings": mcp_settings, }, stream_mode=["messages", "updates"], diff --git a/src/server/chat_request.py b/src/server/chat_request.py index 6626fe1..8e6d786 100644 --- a/src/server/chat_request.py +++ b/src/server/chat_request.py @@ -38,6 +38,9 @@ class ChatRequest(BaseModel): max_step_num: Optional[int] = Field( 3, description="The maximum number of steps in a plan" ) + max_search_results: Optional[int] = Field( + 3, description="The maximum number of search results" + ) auto_accepted_plan: Optional[bool] = Field( False, description="Whether to automatically accept the plan" ) diff --git a/src/tools/__init__.py b/src/tools/__init__.py index 7854f94..fb89121 100644 --- a/src/tools/__init__.py +++ b/src/tools/__init__.py @@ -5,28 +5,12 @@ import os from .crawl import crawl_tool from .python_repl import python_repl_tool -from .search import ( - tavily_search_tool, - duckduckgo_search_tool, - brave_search_tool, - arxiv_search_tool, -) +from .search import get_web_search_tool from .tts import VolcengineTTS -from src.config import SELECTED_SEARCH_ENGINE, SearchEngine - -# Map search engine names to their respective tools -search_tool_mappings = { - SearchEngine.TAVILY.value: tavily_search_tool, - SearchEngine.DUCKDUCKGO.value: duckduckgo_search_tool, - SearchEngine.BRAVE_SEARCH.value: 
-    SearchEngine.ARXIV.value: arxiv_search_tool,
-}
-
-web_search_tool = search_tool_mappings.get(SELECTED_SEARCH_ENGINE, tavily_search_tool)
 
 __all__ = [
     "crawl_tool",
-    "web_search_tool",
     "python_repl_tool",
+    "get_web_search_tool",
     "VolcengineTTS",
 ]
diff --git a/src/tools/search.py b/src/tools/search.py
index 992a6b5..88dd5eb 100644
--- a/src/tools/search.py
+++ b/src/tools/search.py
@@ -9,7 +9,7 @@
 from langchain_community.tools import BraveSearch, DuckDuckGoSearchResults
 from langchain_community.tools.arxiv import ArxivQueryRun
 from langchain_community.utilities import ArxivAPIWrapper, BraveSearchWrapper
-from src.config import SEARCH_MAX_RESULTS, SearchEngine
+from src.config import SearchEngine, SELECTED_SEARCH_ENGINE
 from src.tools.tavily_search.tavily_search_results_with_images import (
     TavilySearchResultsWithImages,
 )
@@ -18,44 +18,48 @@ from src.tools.decorators import create_logged_tool
 logger = logging.getLogger(__name__)
 
+# Create logged versions of the search tools
 LoggedTavilySearch = create_logged_tool(TavilySearchResultsWithImages)
-if os.getenv("SEARCH_API", "") == SearchEngine.TAVILY.value:
-    tavily_search_tool = LoggedTavilySearch(
-        name="web_search",
-        max_results=SEARCH_MAX_RESULTS,
-        include_raw_content=True,
-        include_images=True,
-        include_image_descriptions=True,
-    )
-else:
-    tavily_search_tool = None
-
 LoggedDuckDuckGoSearch = create_logged_tool(DuckDuckGoSearchResults)
-duckduckgo_search_tool = LoggedDuckDuckGoSearch(
-    name="web_search", max_results=SEARCH_MAX_RESULTS
-)
-
 LoggedBraveSearch = create_logged_tool(BraveSearch)
-brave_search_tool = LoggedBraveSearch(
-    name="web_search",
-    search_wrapper=BraveSearchWrapper(
-        api_key=os.getenv("BRAVE_SEARCH_API_KEY", ""),
-        search_kwargs={"count": SEARCH_MAX_RESULTS},
-    ),
-)
-
 LoggedArxivSearch = create_logged_tool(ArxivQueryRun)
-arxiv_search_tool = LoggedArxivSearch(
-    name="web_search",
-    api_wrapper=ArxivAPIWrapper(
-        top_k_results=SEARCH_MAX_RESULTS,
-        load_max_docs=SEARCH_MAX_RESULTS,
-        load_all_available_meta=True,
-    ),
-)
+
+
+# Get the selected search tool
+def get_web_search_tool(max_search_results: int):
+    if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
+        return LoggedTavilySearch(
+            name="web_search",
+            max_results=max_search_results,
+            include_raw_content=True,
+            include_images=True,
+            include_image_descriptions=True,
+        )
+    elif SELECTED_SEARCH_ENGINE == SearchEngine.DUCKDUCKGO.value:
+        return LoggedDuckDuckGoSearch(name="web_search", max_results=max_search_results)
+    elif SELECTED_SEARCH_ENGINE == SearchEngine.BRAVE_SEARCH.value:
+        return LoggedBraveSearch(
+            name="web_search",
+            search_wrapper=BraveSearchWrapper(
+                api_key=os.getenv("BRAVE_SEARCH_API_KEY", ""),
+                search_kwargs={"count": max_search_results},
+            ),
+        )
+    elif SELECTED_SEARCH_ENGINE == SearchEngine.ARXIV.value:
+        return LoggedArxivSearch(
+            name="web_search",
+            api_wrapper=ArxivAPIWrapper(
+                top_k_results=max_search_results,
+                load_max_docs=max_search_results,
+                load_all_available_meta=True,
+            ),
+        )
+    else:
+        raise ValueError(f"Unsupported search engine: {SELECTED_SEARCH_ENGINE}")
+
 
 
 if __name__ == "__main__":
     results = LoggedDuckDuckGoSearch(
-        name="web_search", max_results=SEARCH_MAX_RESULTS, output_format="list"
+        name="web_search", max_results=3, output_format="list"
     ).invoke("cute panda")
     print(json.dumps(results, indent=2, ensure_ascii=False))
diff --git a/web/src/app/settings/tabs/general-tab.tsx b/web/src/app/settings/tabs/general-tab.tsx
index ece0191..8a4fad5 100644
--- a/web/src/app/settings/tabs/general-tab.tsx
+++ b/web/src/app/settings/tabs/general-tab.tsx
@@ -32,6 +32,9 @@ const generalFormSchema = z.object({
   maxStepNum: z.number().min(1, {
     message: "Max step number must be at least 1.",
   }),
+  maxSearchResults: z.number().min(1, {
+    message: "Max search results must be at least 1.",
+  }),
 });
 
 export const GeneralTab: Tab = ({
@@ -143,6 +146,30 @@
             </FormItem>
           )}
         />
+        <FormField
+          control={form.control}
+          name="maxSearchResults"
+          render={({ field }) => (
+            <FormItem>
+              <FormLabel>Max search results</FormLabel>
+              <FormControl>
+                <Input
+                  className="w-60"
+                  type="number"
+                  defaultValue={field.value}
+                  min={1}
+                  onChange={(event) =>
+                    field.onChange(parseInt(event.target.value || "0"))
+                  }
+                />
+              </FormControl>
+              <FormDescription>
+                By default, each search step has 3 results.
+              </FormDescription>
+              <FormMessage />
+            </FormItem>
+          )}
+        />
       </form>
     </Form>
   );
diff --git a/web/src/core/api/chat.ts b/web/src/core/api/chat.ts
index d8dca10..422f964 100644
--- a/web/src/core/api/chat.ts
+++ b/web/src/core/api/chat.ts
@@ -18,6 +18,7 @@ export async function* chatStream(
     auto_accepted_plan: boolean;
     max_plan_iterations: number;
     max_step_num: number;
+    max_search_results?: number;
     interrupt_feedback?: string;
     enable_background_investigation: boolean;
     mcp_settings?: {
@@ -61,12 +62,14 @@ async function* chatReplayStream(
     auto_accepted_plan: boolean;
     max_plan_iterations: number;
     max_step_num: number;
+    max_search_results?: number;
     interrupt_feedback?: string;
   } = {
     thread_id: "__mock__",
     auto_accepted_plan: false,
     max_plan_iterations: 3,
     max_step_num: 1,
+    max_search_results: 3,
     interrupt_feedback: undefined,
   },
   options: { abortSignal?: AbortSignal } = {},
@@ -157,6 +160,7 @@ export async function fetchReplayTitle() {
       auto_accepted_plan: false,
       max_plan_iterations: 3,
       max_step_num: 1,
+      max_search_results: 3,
     },
     {},
   );
diff --git a/web/src/core/store/settings-store.ts b/web/src/core/store/settings-store.ts
index d568f9c..66c76e9 100644
--- a/web/src/core/store/settings-store.ts
+++ b/web/src/core/store/settings-store.ts
@@ -13,6 +13,7 @@ const DEFAULT_SETTINGS: SettingsState = {
     enableBackgroundInvestigation: false,
     maxPlanIterations: 1,
     maxStepNum: 3,
+    maxSearchResults: 3,
   },
   mcp: {
     servers: [],
@@ -25,6 +26,7 @@ export type SettingsState = {
     enableBackgroundInvestigation: boolean;
     maxPlanIterations: number;
     maxStepNum: number;
+    maxSearchResults: number;
   };
   mcp: {
     servers: MCPServerMetadata[];
diff --git a/web/src/core/store/store.ts b/web/src/core/store/store.ts
index ff4e7b1..21a97cc 100644
--- a/web/src/core/store/store.ts
+++ b/web/src/core/store/store.ts
@@ -104,6 +104,7 @@ export async function sendMessage(
           settings.enableBackgroundInvestigation ?? true,
         max_plan_iterations: settings.maxPlanIterations,
         max_step_num: settings.maxStepNum,
+        max_search_results: settings.maxSearchResults,
         mcp_settings: settings.mcpSettings,
       },
       options,
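
For reviewers, a minimal sketch (not part of the patch) of how the new value travels: the web UI stores `maxSearchResults`, the client sends `max_search_results` on `ChatRequest`, the server copies it into the graph's `configurable` dict, and nodes resolve it through `Configuration.from_runnable_config()` before calling `get_web_search_tool()`. It assumes the patched modules are importable and that the engine selected by `SEARCH_API` has its credentials configured; the thread id, the value 5, and the query string are made up for illustration.

```python
# Illustrative only -- mirrors the flow introduced by this patch:
# ChatRequest.max_search_results -> graph config -> Configuration -> get_web_search_tool()
from src.config.configuration import Configuration
from src.tools import get_web_search_tool

# The server builds a runnable config like this from the incoming ChatRequest.
config = {"configurable": {"thread_id": "demo", "max_search_results": 5}}

# Nodes read the value back through the Configuration dataclass.
configurable = Configuration.from_runnable_config(config)

# The factory returns the "web_search" tool for the engine chosen via SEARCH_API,
# capped at the configured number of results.
search_tool = get_web_search_tool(configurable.max_search_results)
print(search_tool.invoke("LangGraph checkpointing"))
```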