diff --git a/.env.example b/.env.example index c6015e9..cacd6a1 100644 --- a/.env.example +++ b/.env.example @@ -2,9 +2,10 @@ DEBUG=True APP_ENV=development -# Search Engine +# Search Engine, Supported values: tavily (recommended), duckduckgo, brave_search, arxiv SEARCH_API=tavily TAVILY_API_KEY=tvly-xxx +# BRAVE_SEARCH_API_KEY=xxx # Required only if SEARCH_API is brave_search # JINA_API_KEY=jina_xxx # Optional, default is None # Optional, volcengine TTS for generating podcast diff --git a/src/graph/nodes.py b/src/graph/nodes.py index a87bf56..b0b8cad 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -28,7 +28,7 @@ from src.prompts.template import apply_prompt_template from src.utils.json_utils import repair_json_output from .types import State -from ..config import SEARCH_MAX_RESULTS +from ..config import SEARCH_MAX_RESULTS, SELECTED_SEARCH_ENGINE, SearchEngine logger = logging.getLogger(__name__) @@ -45,19 +45,24 @@ def handoff_to_planner( def background_investigation_node(state: State) -> Command[Literal["planner"]]: - logger.info("background investigation node is running.") - searched_content = LoggedTavilySearch(max_results=SEARCH_MAX_RESULTS).invoke( - {"query": state["messages"][-1].content} - ) - background_investigation_results = None - if isinstance(searched_content, list): - background_investigation_results = [ - {"title": elem["title"], "content": elem["content"]} - for elem in searched_content - ] + query = state["messages"][-1].content + if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY: + searched_content = LoggedTavilySearch(max_results=SEARCH_MAX_RESULTS).invoke( + {"query": query} + ) + background_investigation_results = None + if isinstance(searched_content, list): + background_investigation_results = [ + {"title": elem["title"], "content": elem["content"]} + for elem in searched_content + ] + else: + logger.error( + f"Tavily search returned malformed response: {searched_content}" + ) else: - logger.error(f"Tavily search returned malformed response: {searched_content}") + background_investigation_results = web_search_tool.invoke(query) return Command( update={ "background_investigation_results": json.dumps( diff --git a/src/tools/search.py b/src/tools/search.py index caa70a0..e4c92ad 100644 --- a/src/tools/search.py +++ b/src/tools/search.py @@ -14,11 +14,10 @@ from src.tools.tavily_search.tavily_search_results_with_images import ( TavilySearchResultsWithImages, ) -from .decorators import create_logged_tool +from src.tools.decorators import create_logged_tool logger = logging.getLogger(__name__) - LoggedTavilySearch = create_logged_tool(TavilySearchResultsWithImages) tavily_search_tool = LoggedTavilySearch( name="web_search", @@ -53,5 +52,7 @@ arxiv_search_tool = LoggedArxivSearch( ) if __name__ == "__main__": - results = tavily_search_tool.invoke("cute panda") + results = LoggedDuckDuckGoSearch( + name="web_search", max_results=SEARCH_MAX_RESULTS, output_format="list" + ).invoke("cute panda") print(json.dumps(results, indent=2, ensure_ascii=False))