mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-12 09:09:00 +08:00
Add tavily as web searh tool. (#5349)
### What problem does this PR solve? #5198 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
e5e9ca0015
commit
53b9e7b52f
@ -64,6 +64,7 @@ def structure_answer(conv, ans, message_id, session_id):
|
||||
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
|
||||
"image_id": get_value(chunk, "image_id", "img_id"),
|
||||
"positions": get_value(chunk, "positions", "position_int"),
|
||||
"url": chunk.get("url")
|
||||
} for chunk in reference.get("chunks", [])]
|
||||
|
||||
reference["chunks"] = chunk_list
|
||||
|
@ -40,6 +40,7 @@ from rag.nlp.search import index_name
|
||||
from rag.settings import TAG_FLD
|
||||
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from rag.utils.tavily_conn import Tavily
|
||||
|
||||
|
||||
class DialogService(CommonService):
|
||||
@ -125,6 +126,7 @@ def kb_prompt(kbinfos, max_tokens):
|
||||
chunks_num += 1
|
||||
if max_tokens * 0.97 < used_token_count:
|
||||
knowledges = knowledges[:i]
|
||||
logging.warning(f"Not all the retrieval into prompt: {i+1}/{len(knowledges)}")
|
||||
break
|
||||
|
||||
docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
|
||||
@ -132,7 +134,7 @@ def kb_prompt(kbinfos, max_tokens):
|
||||
|
||||
doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
|
||||
for ck in kbinfos["chunks"][:chunks_num]:
|
||||
doc2chunks[ck["docnm_kwd"]]["chunks"].append(ck["content_with_weight"])
|
||||
doc2chunks[ck["docnm_kwd"]]["chunks"].append((f"URL: {ck['url']}\n" if "url" in ck else "") + ck["content_with_weight"])
|
||||
doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})
|
||||
|
||||
knowledges = []
|
||||
@ -295,7 +297,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
knowledges = []
|
||||
if prompt_config.get("reasoning", False):
|
||||
for think in reasoning(kbinfos, " ".join(questions), chat_mdl, embd_mdl, tenant_ids, dialog.kb_ids, MAX_SEARCH_LIMIT=3):
|
||||
for think in reasoning(kbinfos, " ".join(questions), chat_mdl, embd_mdl, tenant_ids, dialog.kb_ids, prompt_config, MAX_SEARCH_LIMIT=3):
|
||||
if isinstance(think, str):
|
||||
thought = think
|
||||
knowledges = [t for t in think.split("\n") if t]
|
||||
@ -309,6 +311,11 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(" ".join(questions), kbs)
|
||||
)
|
||||
if prompt_config.get("tavily_api_key"):
|
||||
tav = Tavily(prompt_config["tavily_api_key"])
|
||||
tav_res = tav.retrieve_chunks(" ".join(questions))
|
||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||
if prompt_config.get("use_kg"):
|
||||
ck = settings.kg_retrievaler.retrieval(" ".join(questions),
|
||||
tenant_ids,
|
||||
@ -852,7 +859,7 @@ Output:
|
||||
|
||||
|
||||
def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LLMBundle,
|
||||
tenant_ids: list[str], kb_ids: list[str], MAX_SEARCH_LIMIT: int = 3,
|
||||
tenant_ids: list[str], kb_ids: list[str], prompt_config, MAX_SEARCH_LIMIT: int = 3,
|
||||
top_n: int = 5, similarity_threshold: float = 0.4, vector_similarity_weight: float = 0.3):
|
||||
BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
|
||||
END_SEARCH_QUERY = "<|end_search_query|>"
|
||||
@ -1023,10 +1030,28 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LL
|
||||
truncated_prev_reasoning += '...\n\n'
|
||||
truncated_prev_reasoning = truncated_prev_reasoning.strip('\n')
|
||||
|
||||
# Retrieval procedure:
|
||||
# 1. KB search
|
||||
# 2. Web search (optional)
|
||||
# 3. KG search (optional)
|
||||
kbinfos = settings.retrievaler.retrieval(search_query, embd_mdl, tenant_ids, kb_ids, 1, top_n,
|
||||
similarity_threshold,
|
||||
vector_similarity_weight
|
||||
)
|
||||
if prompt_config.get("tavily_api_key", "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"):
|
||||
tav = Tavily(prompt_config["tavily_api_key"])
|
||||
tav_res = tav.retrieve_chunks(" ".join(search_query))
|
||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||
if prompt_config.get("use_kg"):
|
||||
ck = settings.kg_retrievaler.retrieval(search_query,
|
||||
tenant_ids,
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
chat_mdl)
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
|
||||
# Merge chunk info for citations
|
||||
if not chunk_info["chunks"]:
|
||||
for k in chunk_info.keys():
|
||||
@ -1048,7 +1073,7 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LL
|
||||
relevant_extraction_prompt.format(
|
||||
prev_reasoning=truncated_prev_reasoning,
|
||||
search_query=search_query,
|
||||
document="\n".join(kb_prompt(kbinfos, 512))
|
||||
document="\n".join(kb_prompt(kbinfos, 4096))
|
||||
),
|
||||
[{"role": "user",
|
||||
"content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
|
||||
|
@ -97,6 +97,7 @@ dependencies = [
|
||||
"six==1.16.0",
|
||||
"strenum==0.4.15",
|
||||
"tabulate==0.9.0",
|
||||
"tavily-python==0.5.1",
|
||||
"tencentcloud-sdk-python==3.0.1215",
|
||||
"tika==2.6.0",
|
||||
"tiktoken==0.7.0",
|
||||
|
@ -206,6 +206,8 @@ class FulltextQueryer:
|
||||
|
||||
sims = CosineSimilarity([avec], bvecs)
|
||||
tksim = self.token_similarity(atks, btkss)
|
||||
if np.sum(sims[0]) == 0:
|
||||
return np.array(tksim), tksim, sims[0]
|
||||
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
|
||||
|
||||
def token_similarity(self, atks, btkss):
|
||||
|
66
rag/utils/tavily_conn.py
Normal file
66
rag/utils/tavily_conn.py
Normal file
@ -0,0 +1,66 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
from tavily import TavilyClient
|
||||
from api.utils import get_uuid
|
||||
from rag.nlp import rag_tokenizer
|
||||
|
||||
|
||||
class Tavily:
|
||||
def __init__(self, api_key: str):
|
||||
self.tavily_client = TavilyClient(api_key=api_key)
|
||||
|
||||
def search(self, query):
|
||||
try:
|
||||
response = self.tavily_client.search(
|
||||
query=query,
|
||||
search_depth="advanced"
|
||||
)
|
||||
return [{"url": res["url"], "title": res["title"], "content": res["content"], "score": res["score"]} for res in response["results"]]
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
return []
|
||||
|
||||
def retrieve_chunks(self, question):
|
||||
chunks = []
|
||||
aggs = []
|
||||
for r in self.search(question):
|
||||
id = get_uuid()
|
||||
chunks.append({
|
||||
"chunk_id": id,
|
||||
"content_ltks": rag_tokenizer.tokenize(r["content"]),
|
||||
"content_with_weight": r["content"],
|
||||
"doc_id": id,
|
||||
"docnm_kwd": r["title"],
|
||||
"kb_id": [],
|
||||
"important_kwd": [],
|
||||
"image_id": "",
|
||||
"similarity": r["score"],
|
||||
"vector_similarity": 1.,
|
||||
"term_similarity": 0,
|
||||
"vector": [],
|
||||
"positions": [],
|
||||
"url": r["url"]
|
||||
})
|
||||
aggs.append({
|
||||
"doc_name": r["title"],
|
||||
"doc_id": id,
|
||||
"count": 1,
|
||||
"url": r["url"]
|
||||
})
|
||||
logging.info("[Tavily]: "+r["content"][:128]+"...")
|
||||
return {"chunks": chunks, "doc_aggs": aggs}
|
Loading…
x
Reference in New Issue
Block a user