mirror of
https://git.mirrors.martin98.com/https://github.com/bytedance/deer-flow
synced 2025-08-16 17:06:01 +08:00
feat: config max_search_results for search engine (#192)
* feat: implement UI * feat: config max_search_results for search engine via api --------- Co-authored-by: Henry Li <henry1943@163.com>
This commit is contained in:
parent
c6bbc595c3
commit
8bbcdbe4de
@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from .agents import research_agent, coder_agent
|
||||
from .agents import create_agent
|
||||
|
||||
__all__ = ["research_agent", "coder_agent"]
|
||||
__all__ = ["create_agent"]
|
||||
|
@ -4,12 +4,6 @@
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
from src.prompts import apply_prompt_template
|
||||
from src.tools import (
|
||||
crawl_tool,
|
||||
python_repl_tool,
|
||||
web_search_tool,
|
||||
)
|
||||
|
||||
from src.llms.llm import get_llm_by_type
|
||||
from src.config.agents import AGENT_LLM_MAP
|
||||
|
||||
@ -23,10 +17,3 @@ def create_agent(agent_name: str, agent_type: str, tools: list, prompt_template:
|
||||
tools=tools,
|
||||
prompt=lambda state: apply_prompt_template(prompt_template, state),
|
||||
)
|
||||
|
||||
|
||||
# Create agents using the factory function
|
||||
research_agent = create_agent(
|
||||
"researcher", "researcher", [web_search_tool, crawl_tool], "researcher"
|
||||
)
|
||||
coder_agent = create_agent("coder", "coder", [python_repl_tool], "coder")
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from .tools import SEARCH_MAX_RESULTS, SELECTED_SEARCH_ENGINE, SearchEngine
|
||||
from .tools import SELECTED_SEARCH_ENGINE, SearchEngine
|
||||
from .loader import load_yaml_config
|
||||
from .questions import BUILT_IN_QUESTIONS, BUILT_IN_QUESTIONS_ZH_CN
|
||||
|
||||
@ -42,7 +42,6 @@ __all__ = [
|
||||
# Other configurations
|
||||
"TEAM_MEMBERS",
|
||||
"TEAM_MEMBER_CONFIGRATIONS",
|
||||
"SEARCH_MAX_RESULTS",
|
||||
"SELECTED_SEARCH_ENGINE",
|
||||
"SearchEngine",
|
||||
"BUILT_IN_QUESTIONS",
|
||||
|
@ -14,6 +14,7 @@ class Configuration:
|
||||
|
||||
max_plan_iterations: int = 1 # Maximum number of plan iterations
|
||||
max_step_num: int = 3 # Maximum number of steps in a plan
|
||||
max_search_results: int = 3 # Maximum number of search results
|
||||
mcp_settings: dict = None # MCP settings, including dynamic loaded tools
|
||||
|
||||
@classmethod
|
||||
|
@ -17,4 +17,3 @@ class SearchEngine(enum.Enum):
|
||||
|
||||
# Tool configuration
|
||||
SELECTED_SEARCH_ENGINE = os.getenv("SEARCH_API", SearchEngine.TAVILY.value)
|
||||
SEARCH_MAX_RESULTS = 3
|
||||
|
@ -12,12 +12,11 @@ from langchain_core.tools import tool
|
||||
from langgraph.types import Command, interrupt
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
|
||||
from src.agents.agents import coder_agent, research_agent, create_agent
|
||||
|
||||
from src.agents import create_agent
|
||||
from src.tools.search import LoggedTavilySearch
|
||||
from src.tools import (
|
||||
crawl_tool,
|
||||
web_search_tool,
|
||||
get_web_search_tool,
|
||||
python_repl_tool,
|
||||
)
|
||||
|
||||
@ -29,7 +28,7 @@ from src.prompts.template import apply_prompt_template
|
||||
from src.utils.json_utils import repair_json_output
|
||||
|
||||
from .types import State
|
||||
from ..config import SEARCH_MAX_RESULTS, SELECTED_SEARCH_ENGINE, SearchEngine
|
||||
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -45,13 +44,16 @@ def handoff_to_planner(
|
||||
return
|
||||
|
||||
|
||||
def background_investigation_node(state: State) -> Command[Literal["planner"]]:
|
||||
def background_investigation_node(
|
||||
state: State, config: RunnableConfig
|
||||
) -> Command[Literal["planner"]]:
|
||||
logger.info("background investigation node is running.")
|
||||
configurable = Configuration.from_runnable_config(config)
|
||||
query = state["messages"][-1].content
|
||||
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY:
|
||||
searched_content = LoggedTavilySearch(max_results=SEARCH_MAX_RESULTS).invoke(
|
||||
{"query": query}
|
||||
)
|
||||
searched_content = LoggedTavilySearch(
|
||||
max_results=configurable.max_search_results
|
||||
).invoke({"query": query})
|
||||
background_investigation_results = None
|
||||
if isinstance(searched_content, list):
|
||||
background_investigation_results = [
|
||||
@ -63,7 +65,9 @@ def background_investigation_node(state: State) -> Command[Literal["planner"]]:
|
||||
f"Tavily search returned malformed response: {searched_content}"
|
||||
)
|
||||
else:
|
||||
background_investigation_results = web_search_tool.invoke(query)
|
||||
background_investigation_results = get_web_search_tool(
|
||||
configurable.max_search_results
|
||||
).invoke(query)
|
||||
return Command(
|
||||
update={
|
||||
"background_investigation_results": json.dumps(
|
||||
@ -403,7 +407,6 @@ async def _setup_and_execute_agent_step(
|
||||
state: State,
|
||||
config: RunnableConfig,
|
||||
agent_type: str,
|
||||
default_agent,
|
||||
default_tools: list,
|
||||
) -> Command[Literal["research_team"]]:
|
||||
"""Helper function to set up an agent with appropriate tools and execute a step.
|
||||
@ -417,7 +420,6 @@ async def _setup_and_execute_agent_step(
|
||||
state: The current state
|
||||
config: The runnable config
|
||||
agent_type: The type of agent ("researcher" or "coder")
|
||||
default_agent: The default agent to use if no MCP servers are configured
|
||||
default_tools: The default tools to add to the agent
|
||||
|
||||
Returns:
|
||||
@ -455,8 +457,9 @@ async def _setup_and_execute_agent_step(
|
||||
agent = create_agent(agent_type, agent_type, loaded_tools, agent_type)
|
||||
return await _execute_agent_step(state, agent, agent_type)
|
||||
else:
|
||||
# Use default agent if no MCP servers are configured
|
||||
return await _execute_agent_step(state, default_agent, agent_type)
|
||||
# Use default tools if no MCP servers are configured
|
||||
agent = create_agent(agent_type, agent_type, default_tools, agent_type)
|
||||
return await _execute_agent_step(state, agent, agent_type)
|
||||
|
||||
|
||||
async def researcher_node(
|
||||
@ -464,12 +467,12 @@ async def researcher_node(
|
||||
) -> Command[Literal["research_team"]]:
|
||||
"""Researcher node that do research"""
|
||||
logger.info("Researcher node is researching.")
|
||||
configurable = Configuration.from_runnable_config(config)
|
||||
return await _setup_and_execute_agent_step(
|
||||
state,
|
||||
config,
|
||||
"researcher",
|
||||
research_agent,
|
||||
[web_search_tool, crawl_tool],
|
||||
[get_web_search_tool(configurable.max_search_results), crawl_tool],
|
||||
)
|
||||
|
||||
|
||||
@ -482,6 +485,5 @@ async def coder_node(
|
||||
state,
|
||||
config,
|
||||
"coder",
|
||||
coder_agent,
|
||||
[python_repl_tool],
|
||||
)
|
||||
|
@ -61,6 +61,7 @@ async def chat_stream(request: ChatRequest):
|
||||
thread_id,
|
||||
request.max_plan_iterations,
|
||||
request.max_step_num,
|
||||
request.max_search_results,
|
||||
request.auto_accepted_plan,
|
||||
request.interrupt_feedback,
|
||||
request.mcp_settings,
|
||||
@ -75,6 +76,7 @@ async def _astream_workflow_generator(
|
||||
thread_id: str,
|
||||
max_plan_iterations: int,
|
||||
max_step_num: int,
|
||||
max_search_results: int,
|
||||
auto_accepted_plan: bool,
|
||||
interrupt_feedback: str,
|
||||
mcp_settings: dict,
|
||||
@ -101,6 +103,7 @@ async def _astream_workflow_generator(
|
||||
"thread_id": thread_id,
|
||||
"max_plan_iterations": max_plan_iterations,
|
||||
"max_step_num": max_step_num,
|
||||
"max_search_results": max_search_results,
|
||||
"mcp_settings": mcp_settings,
|
||||
},
|
||||
stream_mode=["messages", "updates"],
|
||||
|
@ -38,6 +38,9 @@ class ChatRequest(BaseModel):
|
||||
max_step_num: Optional[int] = Field(
|
||||
3, description="The maximum number of steps in a plan"
|
||||
)
|
||||
max_search_results: Optional[int] = Field(
|
||||
3, description="The maximum number of search results"
|
||||
)
|
||||
auto_accepted_plan: Optional[bool] = Field(
|
||||
False, description="Whether to automatically accept the plan"
|
||||
)
|
||||
|
@ -5,28 +5,12 @@ import os
|
||||
|
||||
from .crawl import crawl_tool
|
||||
from .python_repl import python_repl_tool
|
||||
from .search import (
|
||||
tavily_search_tool,
|
||||
duckduckgo_search_tool,
|
||||
brave_search_tool,
|
||||
arxiv_search_tool,
|
||||
)
|
||||
from .search import get_web_search_tool
|
||||
from .tts import VolcengineTTS
|
||||
from src.config import SELECTED_SEARCH_ENGINE, SearchEngine
|
||||
|
||||
# Map search engine names to their respective tools
|
||||
search_tool_mappings = {
|
||||
SearchEngine.TAVILY.value: tavily_search_tool,
|
||||
SearchEngine.DUCKDUCKGO.value: duckduckgo_search_tool,
|
||||
SearchEngine.BRAVE_SEARCH.value: brave_search_tool,
|
||||
SearchEngine.ARXIV.value: arxiv_search_tool,
|
||||
}
|
||||
|
||||
web_search_tool = search_tool_mappings.get(SELECTED_SEARCH_ENGINE, tavily_search_tool)
|
||||
|
||||
__all__ = [
|
||||
"crawl_tool",
|
||||
"web_search_tool",
|
||||
"python_repl_tool",
|
||||
"get_web_search_tool",
|
||||
"VolcengineTTS",
|
||||
]
|
||||
|
@ -9,7 +9,7 @@ from langchain_community.tools import BraveSearch, DuckDuckGoSearchResults
|
||||
from langchain_community.tools.arxiv import ArxivQueryRun
|
||||
from langchain_community.utilities import ArxivAPIWrapper, BraveSearchWrapper
|
||||
|
||||
from src.config import SEARCH_MAX_RESULTS, SearchEngine
|
||||
from src.config import SearchEngine, SELECTED_SEARCH_ENGINE
|
||||
from src.tools.tavily_search.tavily_search_results_with_images import (
|
||||
TavilySearchResultsWithImages,
|
||||
)
|
||||
@ -18,44 +18,48 @@ from src.tools.decorators import create_logged_tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create logged versions of the search tools
|
||||
LoggedTavilySearch = create_logged_tool(TavilySearchResultsWithImages)
|
||||
if os.getenv("SEARCH_API", "") == SearchEngine.TAVILY.value:
|
||||
tavily_search_tool = LoggedTavilySearch(
|
||||
name="web_search",
|
||||
max_results=SEARCH_MAX_RESULTS,
|
||||
include_raw_content=True,
|
||||
include_images=True,
|
||||
include_image_descriptions=True,
|
||||
)
|
||||
else:
|
||||
tavily_search_tool = None
|
||||
|
||||
LoggedDuckDuckGoSearch = create_logged_tool(DuckDuckGoSearchResults)
|
||||
duckduckgo_search_tool = LoggedDuckDuckGoSearch(
|
||||
name="web_search", max_results=SEARCH_MAX_RESULTS
|
||||
)
|
||||
|
||||
LoggedBraveSearch = create_logged_tool(BraveSearch)
|
||||
brave_search_tool = LoggedBraveSearch(
|
||||
name="web_search",
|
||||
search_wrapper=BraveSearchWrapper(
|
||||
api_key=os.getenv("BRAVE_SEARCH_API_KEY", ""),
|
||||
search_kwargs={"count": SEARCH_MAX_RESULTS},
|
||||
),
|
||||
)
|
||||
|
||||
LoggedArxivSearch = create_logged_tool(ArxivQueryRun)
|
||||
arxiv_search_tool = LoggedArxivSearch(
|
||||
name="web_search",
|
||||
api_wrapper=ArxivAPIWrapper(
|
||||
top_k_results=SEARCH_MAX_RESULTS,
|
||||
load_max_docs=SEARCH_MAX_RESULTS,
|
||||
load_all_available_meta=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Get the selected search tool
def get_web_search_tool(max_search_results: int):
    """Build the web search tool for the currently configured engine.

    The engine is selected by the module-level ``SELECTED_SEARCH_ENGINE``
    setting; ``max_search_results`` caps how many results each query may
    return (it is threaded into whichever engine-specific knob applies).

    Raises:
        ValueError: if ``SELECTED_SEARCH_ENGINE`` names no supported engine.
    """
    engine = SELECTED_SEARCH_ENGINE
    if engine == SearchEngine.TAVILY.value:
        # Tavily supports the richest output: raw page content plus images.
        return LoggedTavilySearch(
            name="web_search",
            max_results=max_search_results,
            include_raw_content=True,
            include_images=True,
            include_image_descriptions=True,
        )
    if engine == SearchEngine.DUCKDUCKGO.value:
        return LoggedDuckDuckGoSearch(
            name="web_search", max_results=max_search_results
        )
    if engine == SearchEngine.BRAVE_SEARCH.value:
        wrapper = BraveSearchWrapper(
            api_key=os.getenv("BRAVE_SEARCH_API_KEY", ""),
            search_kwargs={"count": max_search_results},
        )
        return LoggedBraveSearch(name="web_search", search_wrapper=wrapper)
    if engine == SearchEngine.ARXIV.value:
        # arXiv has no single "max results" knob; apply the cap to both
        # the top-k ranking and the document-load limit.
        api_wrapper = ArxivAPIWrapper(
            top_k_results=max_search_results,
            load_max_docs=max_search_results,
            load_all_available_meta=True,
        )
        return LoggedArxivSearch(name="web_search", api_wrapper=api_wrapper)
    raise ValueError(f"Unsupported search engine: {SELECTED_SEARCH_ENGINE}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Quick manual smoke test: run a DuckDuckGo search and pretty-print
    # the results. Uses a literal result cap of 3 because the old
    # SEARCH_MAX_RESULTS module constant has been removed in favor of the
    # per-request max_search_results configuration.
    results = LoggedDuckDuckGoSearch(
        name="web_search", max_results=3, output_format="list"
    ).invoke("cute panda")
    print(json.dumps(results, indent=2, ensure_ascii=False))
|
||||
|
@ -32,6 +32,9 @@ const generalFormSchema = z.object({
|
||||
maxStepNum: z.number().min(1, {
|
||||
message: "Max step number must be at least 1.",
|
||||
}),
|
||||
maxSearchResults: z.number().min(1, {
|
||||
message: "Max search results must be at least 1.",
|
||||
}),
|
||||
});
|
||||
|
||||
export const GeneralTab: Tab = ({
|
||||
@ -143,6 +146,30 @@ export const GeneralTab: Tab = ({
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="maxSearchResults"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Max search results</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
className="w-60"
|
||||
type="number"
|
||||
defaultValue={field.value}
|
||||
min={1}
|
||||
onChange={(event) =>
|
||||
field.onChange(parseInt(event.target.value || "0"))
|
||||
}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormDescription>
|
||||
By default, each search step has 3 results.
|
||||
</FormDescription>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
</form>
|
||||
</Form>
|
||||
</main>
|
||||
|
@ -18,6 +18,7 @@ export async function* chatStream(
|
||||
auto_accepted_plan: boolean;
|
||||
max_plan_iterations: number;
|
||||
max_step_num: number;
|
||||
max_search_results?: number;
|
||||
interrupt_feedback?: string;
|
||||
enable_background_investigation: boolean;
|
||||
mcp_settings?: {
|
||||
@ -61,12 +62,14 @@ async function* chatReplayStream(
|
||||
auto_accepted_plan: boolean;
|
||||
max_plan_iterations: number;
|
||||
max_step_num: number;
|
||||
max_search_results?: number;
|
||||
interrupt_feedback?: string;
|
||||
} = {
|
||||
thread_id: "__mock__",
|
||||
auto_accepted_plan: false,
|
||||
max_plan_iterations: 3,
|
||||
max_step_num: 1,
|
||||
max_search_results: 3,
|
||||
interrupt_feedback: undefined,
|
||||
},
|
||||
options: { abortSignal?: AbortSignal } = {},
|
||||
@ -157,6 +160,7 @@ export async function fetchReplayTitle() {
|
||||
auto_accepted_plan: false,
|
||||
max_plan_iterations: 3,
|
||||
max_step_num: 1,
|
||||
max_search_results: 3,
|
||||
},
|
||||
{},
|
||||
);
|
||||
|
@ -13,6 +13,7 @@ const DEFAULT_SETTINGS: SettingsState = {
|
||||
enableBackgroundInvestigation: false,
|
||||
maxPlanIterations: 1,
|
||||
maxStepNum: 3,
|
||||
maxSearchResults: 3,
|
||||
},
|
||||
mcp: {
|
||||
servers: [],
|
||||
@ -25,6 +26,7 @@ export type SettingsState = {
|
||||
enableBackgroundInvestigation: boolean;
|
||||
maxPlanIterations: number;
|
||||
maxStepNum: number;
|
||||
maxSearchResults: number;
|
||||
};
|
||||
mcp: {
|
||||
servers: MCPServerMetadata[];
|
||||
|
@ -104,6 +104,7 @@ export async function sendMessage(
|
||||
settings.enableBackgroundInvestigation ?? true,
|
||||
max_plan_iterations: settings.maxPlanIterations,
|
||||
max_step_num: settings.maxStepNum,
|
||||
max_search_results: settings.maxSearchResults,
|
||||
mcp_settings: settings.mcpSettings,
|
||||
},
|
||||
options,
|
||||
|
Loading…
x
Reference in New Issue
Block a user