Add Tavily as web search tool. (#5349)

### What problem does this PR solve?

#5198

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Commit 53b9e7b52f (parent e5e9ca0015)
Authored by Kevin Hu on 2025-02-26 10:21:04 +08:00, committed by GitHub
6 changed files with 3248 additions and 3080 deletions

api/db/services/conversation_service.py

@@ -64,6 +64,7 @@ def structure_answer(conv, ans, message_id, session_id):
         "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
         "image_id": get_value(chunk, "image_id", "img_id"),
         "positions": get_value(chunk, "positions", "position_int"),
+        "url": chunk.get("url")
     } for chunk in reference.get("chunks", [])]
     reference["chunks"] = chunk_list

api/db/services/dialog_service.py

@@ -40,6 +40,7 @@ from rag.nlp.search import index_name
 from rag.settings import TAG_FLD
 from rag.utils import rmSpace, num_tokens_from_string, encoder
 from api.utils.file_utils import get_project_base_directory
+from rag.utils.tavily_conn import Tavily

 class DialogService(CommonService):
@@ -125,6 +126,7 @@ def kb_prompt(kbinfos, max_tokens):
         chunks_num += 1
         if max_tokens * 0.97 < used_token_count:
             knowledges = knowledges[:i]
+            logging.warning(f"Not all the retrieval into prompt: {i+1}/{len(knowledges)}")
             break

     docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
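
The added warning makes silent prompt truncation visible. A self-contained sketch of the loop it lives in, with a crude stand-in for the `num_tokens_from_string` counter imported at the top of this file; note that `len(knowledges)` is read after the slice, so the denominator counts kept chunks:

```python
# Stand-in token counter so the snippet runs on its own (assumption:
# roughly 4 characters per token).
def num_tokens_from_string(s: str) -> int:
    return max(1, len(s) // 4)

knowledges = ["chunk one text", "chunk two text", "chunk three text"]
max_tokens, used_token_count = 8, 0
for i, k in enumerate(knowledges):
    used_token_count += num_tokens_from_string(k)
    if max_tokens * 0.97 < used_token_count:
        knowledges = knowledges[:i]   # drop everything past the budget
        print(f"Not all the retrieval into prompt: {i+1}/{len(knowledges)}")
        break
```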
@@ -132,7 +134,7 @@ def kb_prompt(kbinfos, max_tokens):
     doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
     for ck in kbinfos["chunks"][:chunks_num]:
-        doc2chunks[ck["docnm_kwd"]]["chunks"].append(ck["content_with_weight"])
+        doc2chunks[ck["docnm_kwd"]]["chunks"].append((f"URL: {ck['url']}\n" if "url" in ck else "") + ck["content_with_weight"])
         doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})

     knowledges = []
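
The conditional prefix is what surfaces web sources to the model: chunks that carry a `url` (only the Tavily ones do) get a URL header prepended before being packed into the prompt. A two-line sketch:

```python
ck = {"docnm_kwd": "Some page", "content_with_weight": "snippet text", "url": "https://example.com"}
text = (f"URL: {ck['url']}\n" if "url" in ck else "") + ck["content_with_weight"]
print(text)  # URL: https://example.com (newline) snippet text
```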
@@ -295,7 +297,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         knowledges = []
         if prompt_config.get("reasoning", False):
-            for think in reasoning(kbinfos, " ".join(questions), chat_mdl, embd_mdl, tenant_ids, dialog.kb_ids, MAX_SEARCH_LIMIT=3):
+            for think in reasoning(kbinfos, " ".join(questions), chat_mdl, embd_mdl, tenant_ids, dialog.kb_ids, prompt_config, MAX_SEARCH_LIMIT=3):
                 if isinstance(think, str):
                     thought = think
                     knowledges = [t for t in think.split("\n") if t]
@@ -309,6 +311,11 @@ def chat(dialog, messages, stream=True, **kwargs):
                 top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
                 rank_feature=label_question(" ".join(questions), kbs)
             )
+            if prompt_config.get("tavily_api_key"):
+                tav = Tavily(prompt_config["tavily_api_key"])
+                tav_res = tav.retrieve_chunks(" ".join(questions))
+                kbinfos["chunks"].extend(tav_res["chunks"])
+                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
             if prompt_config.get("use_kg"):
                 ck = settings.kg_retrievaler.retrieval(" ".join(questions),
                                                        tenant_ids,
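
For reference, a sketch of a dialog `prompt_config` that switches the new web-search step on; only keys visible in this diff are shown, and presence of the key (not a separate flag) is the toggle:

```python
# Hypothetical dialog prompt_config; "tvly-..." stands in for a real key.
prompt_config = {
    "reasoning": False,            # take the plain retrieval path above
    "use_kg": False,               # knowledge-graph retrieval stays off
    "tavily_api_key": "tvly-...",  # a truthy key enables the Tavily step
}
```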
@@ -852,7 +859,7 @@ Output:
 def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LLMBundle,
-              tenant_ids: list[str], kb_ids: list[str], MAX_SEARCH_LIMIT: int = 3,
+              tenant_ids: list[str], kb_ids: list[str], prompt_config, MAX_SEARCH_LIMIT: int = 3,
               top_n: int = 5, similarity_threshold: float = 0.4, vector_similarity_weight: float = 0.3):
     BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
     END_SEARCH_QUERY = "<|end_search_query|>"
@@ -1023,10 +1030,28 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LLMBundle,
                     truncated_prev_reasoning += '...\n\n'
             truncated_prev_reasoning = truncated_prev_reasoning.strip('\n')

+            # Retrieval procedure:
+            # 1. KB search
+            # 2. Web search (optional)
+            # 3. KG search (optional)
             kbinfos = settings.retrievaler.retrieval(search_query, embd_mdl, tenant_ids, kb_ids, 1, top_n,
                                                      similarity_threshold,
                                                      vector_similarity_weight
                                                      )
+            if prompt_config.get("tavily_api_key"):
+                tav = Tavily(prompt_config["tavily_api_key"])
+                tav_res = tav.retrieve_chunks(search_query)
+                kbinfos["chunks"].extend(tav_res["chunks"])
+                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
+            if prompt_config.get("use_kg"):
+                ck = settings.kg_retrievaler.retrieval(search_query,
+                                                       tenant_ids,
+                                                       kb_ids,
+                                                       embd_mdl,
+                                                       chat_mdl)
+                if ck["content_with_weight"]:
+                    kbinfos["chunks"].insert(0, ck)
+
             # Merge chunk info for citations
             if not chunk_info["chunks"]:
                 for k in chunk_info.keys():
@@ -1048,7 +1073,7 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LLMBundle,
                 relevant_extraction_prompt.format(
                     prev_reasoning=truncated_prev_reasoning,
                     search_query=search_query,
-                    document="\n".join(kb_prompt(kbinfos, 512))
+                    document="\n".join(kb_prompt(kbinfos, 4096))
                 ),
                 [{"role": "user",
                   "content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],

pyproject.toml

@@ -97,6 +97,7 @@ dependencies = [
     "six==1.16.0",
     "strenum==0.4.15",
     "tabulate==0.9.0",
+    "tavily-python==0.5.1",
     "tencentcloud-sdk-python==3.0.1215",
     "tika==2.6.0",
     "tiktoken==0.7.0",

rag/nlp/query.py

@@ -206,6 +206,8 @@ class FulltextQueryer:
         sims = CosineSimilarity([avec], bvecs)
         tksim = self.token_similarity(atks, btkss)
+        if np.sum(sims[0]) == 0:
+            return np.array(tksim), tksim, sims[0]
         return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]

     def token_similarity(self, atks, btkss):
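
This guard matters for the new web chunks: tavily_conn.py stores them with `"vector": []`, so their cosine similarities against the query embedding come back all zero and the weighted blend would bury them at the bottom of the ranking. A small numpy sketch of the fallback:

```python
import numpy as np

# Hybrid scoring as in FulltextQueryer: cosine similarity blended with token
# similarity, except when the cosine row is all zeros (e.g. vector-less
# Tavily chunks), in which case token similarity alone is used.
tkweight, vtweight = 0.3, 0.7
tksim = [0.6, 0.2]       # token-overlap scores for two chunks
sims = [np.zeros(2)]     # cosine sims: all zero for chunks with empty vectors
if np.sum(sims[0]) == 0:
    hybrid = np.array(tksim)  # fallback: rank by token similarity only
else:
    hybrid = np.array(sims[0]) * vtweight + np.array(tksim) * tkweight
print(hybrid)  # [0.6 0.2]
```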

rag/utils/tavily_conn.py (new file, 66 lines)

@@ -0,0 +1,66 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from tavily import TavilyClient
from api.utils import get_uuid
from rag.nlp import rag_tokenizer


class Tavily:
    def __init__(self, api_key: str):
        self.tavily_client = TavilyClient(api_key=api_key)

    def search(self, query):
        try:
            response = self.tavily_client.search(
                query=query,
                search_depth="advanced"
            )
            return [{"url": res["url"], "title": res["title"], "content": res["content"], "score": res["score"]} for res in response["results"]]
        except Exception as e:
            logging.exception(e)
        return []

    def retrieve_chunks(self, question):
        chunks = []
        aggs = []
        for r in self.search(question):
            id = get_uuid()
            chunks.append({
                "chunk_id": id,
                "content_ltks": rag_tokenizer.tokenize(r["content"]),
                "content_with_weight": r["content"],
                "doc_id": id,
                "docnm_kwd": r["title"],
                "kb_id": [],
                "important_kwd": [],
                "image_id": "",
                "similarity": r["score"],
                "vector_similarity": 1.,
                "term_similarity": 0,
                "vector": [],
                "positions": [],
                "url": r["url"]
            })
            aggs.append({
                "doc_name": r["title"],
                "doc_id": id,
                "count": 1,
                "url": r["url"]
            })
            logging.info("[Tavily]: " + r["content"][:128] + "...")
        return {"chunks": chunks, "doc_aggs": aggs}

uv.lock (generated, 6225 lines changed)

File diff suppressed because it is too large.