feat: config max_search_results for search engine (#192)

* feat: implement UI

* feat: config max_search_results for search engine via api

---------

Co-authored-by: Henry Li <henry1943@163.com>
This commit is contained in:
DanielWalnut 2025-05-17 22:23:52 -07:00 committed by GitHub
parent c6bbc595c3
commit 8bbcdbe4de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 101 additions and 85 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from .agents import research_agent, coder_agent
from .agents import create_agent
__all__ = ["research_agent", "coder_agent"]
__all__ = ["create_agent"]

View File

@ -4,12 +4,6 @@
from langgraph.prebuilt import create_react_agent
from src.prompts import apply_prompt_template
from src.tools import (
crawl_tool,
python_repl_tool,
web_search_tool,
)
from src.llms.llm import get_llm_by_type
from src.config.agents import AGENT_LLM_MAP
@ -23,10 +17,3 @@ def create_agent(agent_name: str, agent_type: str, tools: list, prompt_template:
tools=tools,
prompt=lambda state: apply_prompt_template(prompt_template, state),
)
# Create agents using the factory function
research_agent = create_agent(
"researcher", "researcher", [web_search_tool, crawl_tool], "researcher"
)
coder_agent = create_agent("coder", "coder", [python_repl_tool], "coder")

View File

@ -1,7 +1,7 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from .tools import SEARCH_MAX_RESULTS, SELECTED_SEARCH_ENGINE, SearchEngine
from .tools import SELECTED_SEARCH_ENGINE, SearchEngine
from .loader import load_yaml_config
from .questions import BUILT_IN_QUESTIONS, BUILT_IN_QUESTIONS_ZH_CN
@ -42,7 +42,6 @@ __all__ = [
# Other configurations
"TEAM_MEMBERS",
"TEAM_MEMBER_CONFIGRATIONS",
"SEARCH_MAX_RESULTS",
"SELECTED_SEARCH_ENGINE",
"SearchEngine",
"BUILT_IN_QUESTIONS",

View File

@ -14,6 +14,7 @@ class Configuration:
max_plan_iterations: int = 1 # Maximum number of plan iterations
max_step_num: int = 3 # Maximum number of steps in a plan
max_search_results: int = 3 # Maximum number of search results
mcp_settings: dict = None # MCP settings, including dynamic loaded tools
@classmethod

View File

@ -17,4 +17,3 @@ class SearchEngine(enum.Enum):
# Tool configuration
SELECTED_SEARCH_ENGINE = os.getenv("SEARCH_API", SearchEngine.TAVILY.value)
SEARCH_MAX_RESULTS = 3

View File

@ -12,12 +12,11 @@ from langchain_core.tools import tool
from langgraph.types import Command, interrupt
from langchain_mcp_adapters.client import MultiServerMCPClient
from src.agents.agents import coder_agent, research_agent, create_agent
from src.agents import create_agent
from src.tools.search import LoggedTavilySearch
from src.tools import (
crawl_tool,
web_search_tool,
get_web_search_tool,
python_repl_tool,
)
@ -29,7 +28,7 @@ from src.prompts.template import apply_prompt_template
from src.utils.json_utils import repair_json_output
from .types import State
from ..config import SEARCH_MAX_RESULTS, SELECTED_SEARCH_ENGINE, SearchEngine
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
logger = logging.getLogger(__name__)
@ -45,13 +44,16 @@ def handoff_to_planner(
return
def background_investigation_node(state: State) -> Command[Literal["planner"]]:
def background_investigation_node(
state: State, config: RunnableConfig
) -> Command[Literal["planner"]]:
logger.info("background investigation node is running.")
configurable = Configuration.from_runnable_config(config)
query = state["messages"][-1].content
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY:
searched_content = LoggedTavilySearch(max_results=SEARCH_MAX_RESULTS).invoke(
{"query": query}
)
searched_content = LoggedTavilySearch(
max_results=configurable.max_search_results
).invoke({"query": query})
background_investigation_results = None
if isinstance(searched_content, list):
background_investigation_results = [
@ -63,7 +65,9 @@ def background_investigation_node(state: State) -> Command[Literal["planner"]]:
f"Tavily search returned malformed response: {searched_content}"
)
else:
background_investigation_results = web_search_tool.invoke(query)
background_investigation_results = get_web_search_tool(
configurable.max_search_results
).invoke(query)
return Command(
update={
"background_investigation_results": json.dumps(
@ -403,7 +407,6 @@ async def _setup_and_execute_agent_step(
state: State,
config: RunnableConfig,
agent_type: str,
default_agent,
default_tools: list,
) -> Command[Literal["research_team"]]:
"""Helper function to set up an agent with appropriate tools and execute a step.
@ -417,7 +420,6 @@ async def _setup_and_execute_agent_step(
state: The current state
config: The runnable config
agent_type: The type of agent ("researcher" or "coder")
default_agent: The default agent to use if no MCP servers are configured
default_tools: The default tools to add to the agent
Returns:
@ -455,8 +457,9 @@ async def _setup_and_execute_agent_step(
agent = create_agent(agent_type, agent_type, loaded_tools, agent_type)
return await _execute_agent_step(state, agent, agent_type)
else:
# Use default agent if no MCP servers are configured
return await _execute_agent_step(state, default_agent, agent_type)
# Use default tools if no MCP servers are configured
agent = create_agent(agent_type, agent_type, default_tools, agent_type)
return await _execute_agent_step(state, agent, agent_type)
async def researcher_node(
@ -464,12 +467,12 @@ async def researcher_node(
) -> Command[Literal["research_team"]]:
"""Researcher node that do research"""
logger.info("Researcher node is researching.")
configurable = Configuration.from_runnable_config(config)
return await _setup_and_execute_agent_step(
state,
config,
"researcher",
research_agent,
[web_search_tool, crawl_tool],
[get_web_search_tool(configurable.max_search_results), crawl_tool],
)
@ -482,6 +485,5 @@ async def coder_node(
state,
config,
"coder",
coder_agent,
[python_repl_tool],
)

View File

@ -61,6 +61,7 @@ async def chat_stream(request: ChatRequest):
thread_id,
request.max_plan_iterations,
request.max_step_num,
request.max_search_results,
request.auto_accepted_plan,
request.interrupt_feedback,
request.mcp_settings,
@ -75,6 +76,7 @@ async def _astream_workflow_generator(
thread_id: str,
max_plan_iterations: int,
max_step_num: int,
max_search_results: int,
auto_accepted_plan: bool,
interrupt_feedback: str,
mcp_settings: dict,
@ -101,6 +103,7 @@ async def _astream_workflow_generator(
"thread_id": thread_id,
"max_plan_iterations": max_plan_iterations,
"max_step_num": max_step_num,
"max_search_results": max_search_results,
"mcp_settings": mcp_settings,
},
stream_mode=["messages", "updates"],

View File

@ -38,6 +38,9 @@ class ChatRequest(BaseModel):
max_step_num: Optional[int] = Field(
3, description="The maximum number of steps in a plan"
)
max_search_results: Optional[int] = Field(
3, description="The maximum number of search results"
)
auto_accepted_plan: Optional[bool] = Field(
False, description="Whether to automatically accept the plan"
)

View File

@ -5,28 +5,12 @@ import os
from .crawl import crawl_tool
from .python_repl import python_repl_tool
from .search import (
tavily_search_tool,
duckduckgo_search_tool,
brave_search_tool,
arxiv_search_tool,
)
from .search import get_web_search_tool
from .tts import VolcengineTTS
from src.config import SELECTED_SEARCH_ENGINE, SearchEngine
# Map search engine names to their respective tools
search_tool_mappings = {
SearchEngine.TAVILY.value: tavily_search_tool,
SearchEngine.DUCKDUCKGO.value: duckduckgo_search_tool,
SearchEngine.BRAVE_SEARCH.value: brave_search_tool,
SearchEngine.ARXIV.value: arxiv_search_tool,
}
web_search_tool = search_tool_mappings.get(SELECTED_SEARCH_ENGINE, tavily_search_tool)
__all__ = [
"crawl_tool",
"web_search_tool",
"python_repl_tool",
"get_web_search_tool",
"VolcengineTTS",
]

View File

@ -9,7 +9,7 @@ from langchain_community.tools import BraveSearch, DuckDuckGoSearchResults
from langchain_community.tools.arxiv import ArxivQueryRun
from langchain_community.utilities import ArxivAPIWrapper, BraveSearchWrapper
from src.config import SEARCH_MAX_RESULTS, SearchEngine
from src.config import SearchEngine, SELECTED_SEARCH_ENGINE
from src.tools.tavily_search.tavily_search_results_with_images import (
TavilySearchResultsWithImages,
)
@ -18,44 +18,48 @@ from src.tools.decorators import create_logged_tool
logger = logging.getLogger(__name__)
# Create logged versions of the search tools
LoggedTavilySearch = create_logged_tool(TavilySearchResultsWithImages)
if os.getenv("SEARCH_API", "") == SearchEngine.TAVILY.value:
tavily_search_tool = LoggedTavilySearch(
name="web_search",
max_results=SEARCH_MAX_RESULTS,
include_raw_content=True,
include_images=True,
include_image_descriptions=True,
)
else:
tavily_search_tool = None
LoggedDuckDuckGoSearch = create_logged_tool(DuckDuckGoSearchResults)
duckduckgo_search_tool = LoggedDuckDuckGoSearch(
name="web_search", max_results=SEARCH_MAX_RESULTS
)
LoggedBraveSearch = create_logged_tool(BraveSearch)
brave_search_tool = LoggedBraveSearch(
name="web_search",
search_wrapper=BraveSearchWrapper(
api_key=os.getenv("BRAVE_SEARCH_API_KEY", ""),
search_kwargs={"count": SEARCH_MAX_RESULTS},
),
)
LoggedArxivSearch = create_logged_tool(ArxivQueryRun)
arxiv_search_tool = LoggedArxivSearch(
name="web_search",
api_wrapper=ArxivAPIWrapper(
top_k_results=SEARCH_MAX_RESULTS,
load_max_docs=SEARCH_MAX_RESULTS,
load_all_available_meta=True,
),
)
# Get the selected search tool
def get_web_search_tool(max_search_results: int):
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
return LoggedTavilySearch(
name="web_search",
max_results=max_search_results,
include_raw_content=True,
include_images=True,
include_image_descriptions=True,
)
elif SELECTED_SEARCH_ENGINE == SearchEngine.DUCKDUCKGO.value:
return LoggedDuckDuckGoSearch(name="web_search", max_results=max_search_results)
elif SELECTED_SEARCH_ENGINE == SearchEngine.BRAVE_SEARCH.value:
return LoggedBraveSearch(
name="web_search",
search_wrapper=BraveSearchWrapper(
api_key=os.getenv("BRAVE_SEARCH_API_KEY", ""),
search_kwargs={"count": max_search_results},
),
)
elif SELECTED_SEARCH_ENGINE == SearchEngine.ARXIV.value:
return LoggedArxivSearch(
name="web_search",
api_wrapper=ArxivAPIWrapper(
top_k_results=max_search_results,
load_max_docs=max_search_results,
load_all_available_meta=True,
),
)
else:
raise ValueError(f"Unsupported search engine: {SELECTED_SEARCH_ENGINE}")
if __name__ == "__main__":
results = LoggedDuckDuckGoSearch(
name="web_search", max_results=SEARCH_MAX_RESULTS, output_format="list"
name="web_search", max_results=3, output_format="list"
).invoke("cute panda")
print(json.dumps(results, indent=2, ensure_ascii=False))

View File

@ -32,6 +32,9 @@ const generalFormSchema = z.object({
maxStepNum: z.number().min(1, {
message: "Max step number must be at least 1.",
}),
maxSearchResults: z.number().min(1, {
message: "Max search results must be at least 1.",
}),
});
export const GeneralTab: Tab = ({
@ -143,6 +146,30 @@ export const GeneralTab: Tab = ({
</FormItem>
)}
/>
<FormField
control={form.control}
name="maxSearchResults"
render={({ field }) => (
<FormItem>
<FormLabel>Max search results</FormLabel>
<FormControl>
<Input
className="w-60"
type="number"
defaultValue={field.value}
min={1}
onChange={(event) =>
field.onChange(parseInt(event.target.value || "0"))
}
/>
</FormControl>
<FormDescription>
By default, each search step has 3 results.
</FormDescription>
<FormMessage />
</FormItem>
)}
/>
</form>
</Form>
</main>

View File

@ -18,6 +18,7 @@ export async function* chatStream(
auto_accepted_plan: boolean;
max_plan_iterations: number;
max_step_num: number;
max_search_results?: number;
interrupt_feedback?: string;
enable_background_investigation: boolean;
mcp_settings?: {
@ -61,12 +62,14 @@ async function* chatReplayStream(
auto_accepted_plan: boolean;
max_plan_iterations: number;
max_step_num: number;
max_search_results?: number;
interrupt_feedback?: string;
} = {
thread_id: "__mock__",
auto_accepted_plan: false,
max_plan_iterations: 3,
max_step_num: 1,
max_search_results: 3,
interrupt_feedback: undefined,
},
options: { abortSignal?: AbortSignal } = {},
@ -157,6 +160,7 @@ export async function fetchReplayTitle() {
auto_accepted_plan: false,
max_plan_iterations: 3,
max_step_num: 1,
max_search_results: 3,
},
{},
);

View File

@ -13,6 +13,7 @@ const DEFAULT_SETTINGS: SettingsState = {
enableBackgroundInvestigation: false,
maxPlanIterations: 1,
maxStepNum: 3,
maxSearchResults: 3,
},
mcp: {
servers: [],
@ -25,6 +26,7 @@ export type SettingsState = {
enableBackgroundInvestigation: boolean;
maxPlanIterations: number;
maxStepNum: number;
maxSearchResults: number;
};
mcp: {
servers: MCPServerMetadata[];

View File

@ -104,6 +104,7 @@ export async function sendMessage(
settings.enableBackgroundInvestigation ?? true,
max_plan_iterations: settings.maxPlanIterations,
max_step_num: settings.maxStepNum,
max_search_results: settings.maxSearchResults,
mcp_settings: settings.mcpSettings,
},
options,