diff --git a/src/graph/nodes.py b/src/graph/nodes.py index 6f95f10..9613ee6 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -50,9 +50,24 @@ def background_investigation_node( logger.info("background investigation node is running.") configurable = Configuration.from_runnable_config(config) query = state["messages"][-1].content - background_investigation_results = get_web_search_tool( - configurable.max_search_results - ).invoke(query) + if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value: + searched_content = LoggedTavilySearch( + max_results=configurable.max_search_results + ).invoke(query) + background_investigation_results = None + if isinstance(searched_content, list): + background_investigation_results = [ + {"title": elem["title"], "content": elem["content"]} + for elem in searched_content + ] + else: + logger.error( + f"Tavily search returned malformed response: {searched_content}" + ) + else: + background_investigation_results = get_web_search_tool( + configurable.max_search_results + ).invoke(query) return Command( update={ "background_investigation_results": json.dumps( diff --git a/tests/integration/test_nodes.py b/tests/integration/test_nodes.py index 978acaa..14c14a8 100644 --- a/tests/integration/test_nodes.py +++ b/tests/integration/test_nodes.py @@ -68,7 +68,7 @@ def mock_web_search_tool(): yield mock -@pytest.mark.parametrize("search_engine", [SearchEngine.TAVILY, "other"]) +@pytest.mark.parametrize("search_engine", [SearchEngine.TAVILY.value, "other"]) def test_background_investigation_node_tavily( mock_state, mock_tavily_search, @@ -93,10 +93,8 @@ def test_background_investigation_node_tavily( results = json.loads(update["background_investigation_results"]) assert isinstance(results, list) - if search_engine == SearchEngine.TAVILY: - mock_tavily_search.return_value.invoke.assert_called_once_with( - {"query": "test query"} - ) + if search_engine == SearchEngine.TAVILY.value: + mock_tavily_search.return_value.invoke.assert_called_once_with("test query") assert len(results) == 2 assert results[0]["title"] == "Test Title 1" assert results[0]["content"] == "Test Content 1" @@ -111,7 +109,7 @@ def test_background_investigation_node_malformed_response( mock_state, mock_tavily_search, patch_config_from_runnable_config, mock_config ): """Test background_investigation_node with malformed Tavily response""" - with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY): + with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value): # Mock a malformed response mock_tavily_search.return_value.invoke.return_value = "invalid response"