mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-12 20:19:11 +08:00
Add tavily as web search tool. (#5349)
### What problem does this PR solve? #5198 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
e5e9ca0015
commit
53b9e7b52f
@ -64,6 +64,7 @@ def structure_answer(conv, ans, message_id, session_id):
|
|||||||
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
|
"dataset_id": get_value(chunk, "kb_id", "dataset_id"),
|
||||||
"image_id": get_value(chunk, "image_id", "img_id"),
|
"image_id": get_value(chunk, "image_id", "img_id"),
|
||||||
"positions": get_value(chunk, "positions", "position_int"),
|
"positions": get_value(chunk, "positions", "position_int"),
|
||||||
|
"url": chunk.get("url")
|
||||||
} for chunk in reference.get("chunks", [])]
|
} for chunk in reference.get("chunks", [])]
|
||||||
|
|
||||||
reference["chunks"] = chunk_list
|
reference["chunks"] = chunk_list
|
||||||
|
@ -40,6 +40,7 @@ from rag.nlp.search import index_name
|
|||||||
from rag.settings import TAG_FLD
|
from rag.settings import TAG_FLD
|
||||||
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils.file_utils import get_project_base_directory
|
||||||
|
from rag.utils.tavily_conn import Tavily
|
||||||
|
|
||||||
|
|
||||||
class DialogService(CommonService):
|
class DialogService(CommonService):
|
||||||
@ -125,6 +126,7 @@ def kb_prompt(kbinfos, max_tokens):
|
|||||||
chunks_num += 1
|
chunks_num += 1
|
||||||
if max_tokens * 0.97 < used_token_count:
|
if max_tokens * 0.97 < used_token_count:
|
||||||
knowledges = knowledges[:i]
|
knowledges = knowledges[:i]
|
||||||
|
logging.warning(f"Not all the retrieval into prompt: {i+1}/{len(knowledges)}")
|
||||||
break
|
break
|
||||||
|
|
||||||
docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
|
docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
|
||||||
@ -132,7 +134,7 @@ def kb_prompt(kbinfos, max_tokens):
|
|||||||
|
|
||||||
doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
|
doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
|
||||||
for ck in kbinfos["chunks"][:chunks_num]:
|
for ck in kbinfos["chunks"][:chunks_num]:
|
||||||
doc2chunks[ck["docnm_kwd"]]["chunks"].append(ck["content_with_weight"])
|
doc2chunks[ck["docnm_kwd"]]["chunks"].append((f"URL: {ck['url']}\n" if "url" in ck else "") + ck["content_with_weight"])
|
||||||
doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})
|
doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})
|
||||||
|
|
||||||
knowledges = []
|
knowledges = []
|
||||||
@ -295,7 +297,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
|
|
||||||
knowledges = []
|
knowledges = []
|
||||||
if prompt_config.get("reasoning", False):
|
if prompt_config.get("reasoning", False):
|
||||||
for think in reasoning(kbinfos, " ".join(questions), chat_mdl, embd_mdl, tenant_ids, dialog.kb_ids, MAX_SEARCH_LIMIT=3):
|
for think in reasoning(kbinfos, " ".join(questions), chat_mdl, embd_mdl, tenant_ids, dialog.kb_ids, prompt_config, MAX_SEARCH_LIMIT=3):
|
||||||
if isinstance(think, str):
|
if isinstance(think, str):
|
||||||
thought = think
|
thought = think
|
||||||
knowledges = [t for t in think.split("\n") if t]
|
knowledges = [t for t in think.split("\n") if t]
|
||||||
@ -309,6 +311,11 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
|
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
|
||||||
rank_feature=label_question(" ".join(questions), kbs)
|
rank_feature=label_question(" ".join(questions), kbs)
|
||||||
)
|
)
|
||||||
|
if prompt_config.get("tavily_api_key"):
|
||||||
|
tav = Tavily(prompt_config["tavily_api_key"])
|
||||||
|
tav_res = tav.retrieve_chunks(" ".join(questions))
|
||||||
|
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||||
|
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||||
if prompt_config.get("use_kg"):
|
if prompt_config.get("use_kg"):
|
||||||
ck = settings.kg_retrievaler.retrieval(" ".join(questions),
|
ck = settings.kg_retrievaler.retrieval(" ".join(questions),
|
||||||
tenant_ids,
|
tenant_ids,
|
||||||
@ -852,7 +859,7 @@ Output:
|
|||||||
|
|
||||||
|
|
||||||
def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LLMBundle,
|
def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LLMBundle,
|
||||||
tenant_ids: list[str], kb_ids: list[str], MAX_SEARCH_LIMIT: int = 3,
|
tenant_ids: list[str], kb_ids: list[str], prompt_config, MAX_SEARCH_LIMIT: int = 3,
|
||||||
top_n: int = 5, similarity_threshold: float = 0.4, vector_similarity_weight: float = 0.3):
|
top_n: int = 5, similarity_threshold: float = 0.4, vector_similarity_weight: float = 0.3):
|
||||||
BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
|
BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
|
||||||
END_SEARCH_QUERY = "<|end_search_query|>"
|
END_SEARCH_QUERY = "<|end_search_query|>"
|
||||||
@ -1023,10 +1030,28 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LL
|
|||||||
truncated_prev_reasoning += '...\n\n'
|
truncated_prev_reasoning += '...\n\n'
|
||||||
truncated_prev_reasoning = truncated_prev_reasoning.strip('\n')
|
truncated_prev_reasoning = truncated_prev_reasoning.strip('\n')
|
||||||
|
|
||||||
|
# Retrieval procedure:
|
||||||
|
# 1. KB search
|
||||||
|
# 2. Web search (optional)
|
||||||
|
# 3. KG search (optional)
|
||||||
kbinfos = settings.retrievaler.retrieval(search_query, embd_mdl, tenant_ids, kb_ids, 1, top_n,
|
kbinfos = settings.retrievaler.retrieval(search_query, embd_mdl, tenant_ids, kb_ids, 1, top_n,
|
||||||
similarity_threshold,
|
similarity_threshold,
|
||||||
vector_similarity_weight
|
vector_similarity_weight
|
||||||
)
|
)
|
||||||
|
if prompt_config.get("tavily_api_key"):
|
||||||
|
tav = Tavily(prompt_config["tavily_api_key"])
|
||||||
|
tav_res = tav.retrieve_chunks(search_query)
|
||||||
|
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||||
|
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||||
|
if prompt_config.get("use_kg"):
|
||||||
|
ck = settings.kg_retrievaler.retrieval(search_query,
|
||||||
|
tenant_ids,
|
||||||
|
kb_ids,
|
||||||
|
embd_mdl,
|
||||||
|
chat_mdl)
|
||||||
|
if ck["content_with_weight"]:
|
||||||
|
kbinfos["chunks"].insert(0, ck)
|
||||||
|
|
||||||
# Merge chunk info for citations
|
# Merge chunk info for citations
|
||||||
if not chunk_info["chunks"]:
|
if not chunk_info["chunks"]:
|
||||||
for k in chunk_info.keys():
|
for k in chunk_info.keys():
|
||||||
@ -1048,7 +1073,7 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LL
|
|||||||
relevant_extraction_prompt.format(
|
relevant_extraction_prompt.format(
|
||||||
prev_reasoning=truncated_prev_reasoning,
|
prev_reasoning=truncated_prev_reasoning,
|
||||||
search_query=search_query,
|
search_query=search_query,
|
||||||
document="\n".join(kb_prompt(kbinfos, 512))
|
document="\n".join(kb_prompt(kbinfos, 4096))
|
||||||
),
|
),
|
||||||
[{"role": "user",
|
[{"role": "user",
|
||||||
"content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
|
"content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
|
||||||
|
@ -97,6 +97,7 @@ dependencies = [
|
|||||||
"six==1.16.0",
|
"six==1.16.0",
|
||||||
"strenum==0.4.15",
|
"strenum==0.4.15",
|
||||||
"tabulate==0.9.0",
|
"tabulate==0.9.0",
|
||||||
|
"tavily-python==0.5.1",
|
||||||
"tencentcloud-sdk-python==3.0.1215",
|
"tencentcloud-sdk-python==3.0.1215",
|
||||||
"tika==2.6.0",
|
"tika==2.6.0",
|
||||||
"tiktoken==0.7.0",
|
"tiktoken==0.7.0",
|
||||||
|
@ -206,6 +206,8 @@ class FulltextQueryer:
|
|||||||
|
|
||||||
sims = CosineSimilarity([avec], bvecs)
|
sims = CosineSimilarity([avec], bvecs)
|
||||||
tksim = self.token_similarity(atks, btkss)
|
tksim = self.token_similarity(atks, btkss)
|
||||||
|
if np.sum(sims[0]) == 0:
|
||||||
|
return np.array(tksim), tksim, sims[0]
|
||||||
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
|
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
|
||||||
|
|
||||||
def token_similarity(self, atks, btkss):
|
def token_similarity(self, atks, btkss):
|
||||||
|
66
rag/utils/tavily_conn.py
Normal file
66
rag/utils/tavily_conn.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import logging
|
||||||
|
from tavily import TavilyClient
|
||||||
|
from api.utils import get_uuid
|
||||||
|
from rag.nlp import rag_tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class Tavily:
    """Thin wrapper around the Tavily web-search API that adapts search hits
    into RAGFlow's retrieval format (``chunks`` + ``doc_aggs``), so web
    results can be merged with knowledge-base retrieval results.
    """

    def __init__(self, api_key: str):
        self.tavily_client = TavilyClient(api_key=api_key)

    def search(self, query):
        """Run an advanced-depth Tavily search for *query*.

        Returns a list of dicts with keys ``url``, ``title``, ``content``
        and ``score``. Best-effort: any API error is logged and an empty
        list is returned, so a web-search failure never breaks the chat
        pipeline.
        """
        try:
            response = self.tavily_client.search(
                query=query,
                search_depth="advanced"
            )
            return [
                {"url": res["url"], "title": res["title"], "content": res["content"], "score": res["score"]}
                for res in response["results"]
            ]
        except Exception:
            # logging.exception already records the active traceback.
            logging.exception("Tavily search failed for query: %s", query)

        return []

    def retrieve_chunks(self, question):
        """Search the web for *question* and convert each hit into a
        retrieval chunk plus a matching per-document aggregation entry.

        Returns ``{"chunks": [...], "doc_aggs": [...]}``; both lists are
        empty when the search fails or yields no results.
        """
        chunks = []
        aggs = []
        for r in self.search(question):
            # A web hit has no backing KB document, so one fresh UUID
            # serves as both chunk_id and doc_id.
            # (renamed from `id`, which shadowed the builtin)
            chunk_id = get_uuid()
            chunks.append({
                "chunk_id": chunk_id,
                "content_ltks": rag_tokenizer.tokenize(r["content"]),
                "content_with_weight": r["content"],
                "doc_id": chunk_id,
                "docnm_kwd": r["title"],
                "kb_id": [],
                "important_kwd": [],
                "image_id": "",
                # Tavily's relevance score stands in for the combined
                # similarity; term_similarity is 0 since no KB match exists.
                "similarity": r["score"],
                "vector_similarity": 1.,
                "term_similarity": 0,
                "vector": [],
                "positions": [],
                "url": r["url"]
            })
            aggs.append({
                "doc_name": r["title"],
                "doc_id": chunk_id,
                "count": 1,
                "url": r["url"]
            })
            logging.info("[Tavily]: "+r["content"][:128]+"...")
        return {"chunks": chunks, "doc_aggs": aggs}
|
Loading…
x
Reference in New Issue
Block a user