fix: fix unittes & background investigation search logic (#247)

This commit is contained in:
DanielWalnut 2025-05-27 23:05:34 -07:00 committed by GitHub
parent 29be360954
commit 0565ab6d27
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 9 deletions

View File

@ -50,9 +50,24 @@ def background_investigation_node(
logger.info("background investigation node is running.") logger.info("background investigation node is running.")
configurable = Configuration.from_runnable_config(config) configurable = Configuration.from_runnable_config(config)
query = state["messages"][-1].content query = state["messages"][-1].content
background_investigation_results = get_web_search_tool( if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
configurable.max_search_results searched_content = LoggedTavilySearch(
).invoke(query) max_results=configurable.max_search_results
).invoke(query)
background_investigation_results = None
if isinstance(searched_content, list):
background_investigation_results = [
{"title": elem["title"], "content": elem["content"]}
for elem in searched_content
]
else:
logger.error(
f"Tavily search returned malformed response: {searched_content}"
)
else:
background_investigation_results = get_web_search_tool(
configurable.max_search_results
).invoke(query)
return Command( return Command(
update={ update={
"background_investigation_results": json.dumps( "background_investigation_results": json.dumps(

View File

@ -68,7 +68,7 @@ def mock_web_search_tool():
yield mock yield mock
@pytest.mark.parametrize("search_engine", [SearchEngine.TAVILY, "other"]) @pytest.mark.parametrize("search_engine", [SearchEngine.TAVILY.value, "other"])
def test_background_investigation_node_tavily( def test_background_investigation_node_tavily(
mock_state, mock_state,
mock_tavily_search, mock_tavily_search,
@ -93,10 +93,8 @@ def test_background_investigation_node_tavily(
results = json.loads(update["background_investigation_results"]) results = json.loads(update["background_investigation_results"])
assert isinstance(results, list) assert isinstance(results, list)
if search_engine == SearchEngine.TAVILY: if search_engine == SearchEngine.TAVILY.value:
mock_tavily_search.return_value.invoke.assert_called_once_with( mock_tavily_search.return_value.invoke.assert_called_once_with("test query")
{"query": "test query"}
)
assert len(results) == 2 assert len(results) == 2
assert results[0]["title"] == "Test Title 1" assert results[0]["title"] == "Test Title 1"
assert results[0]["content"] == "Test Content 1" assert results[0]["content"] == "Test Content 1"
@ -111,7 +109,7 @@ def test_background_investigation_node_malformed_response(
mock_state, mock_tavily_search, patch_config_from_runnable_config, mock_config mock_state, mock_tavily_search, patch_config_from_runnable_config, mock_config
): ):
"""Test background_investigation_node with malformed Tavily response""" """Test background_investigation_node with malformed Tavily response"""
with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY): with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value):
# Mock a malformed response # Mock a malformed response
mock_tavily_search.return_value.invoke.return_value = "invalid response" mock_tavily_search.return_value.invoke.return_value = "invalid response"