mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-13 15:59:01 +08:00
Apply agentic searching. (#5196)
### What problem does this PR solve? #5173 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
744ff55c62
commit
7b3d700d5f
@ -17,6 +17,8 @@ import logging
|
||||
import binascii
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
|
||||
import json_repair
|
||||
import re
|
||||
from collections import defaultdict
|
||||
@ -33,6 +35,7 @@ from api.db.services.llm_service import TenantLLMService, LLMBundle
|
||||
from api import settings
|
||||
from graphrag.utils import get_tags_from_cache, set_tags_to_cache
|
||||
from rag.app.resume import forbidden_select_fields4resume
|
||||
from rag.nlp import extract_between
|
||||
from rag.nlp.search import index_name
|
||||
from rag.settings import TAG_FLD
|
||||
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
||||
@ -135,7 +138,7 @@ def kb_prompt(kbinfos, max_tokens):
|
||||
knowledges = []
|
||||
for nm, cks_meta in doc2chunks.items():
|
||||
txt = f"Document: {nm} \n"
|
||||
for k,v in cks_meta["meta"].items():
|
||||
for k, v in cks_meta["meta"].items():
|
||||
txt += f"{k}: {v}\n"
|
||||
txt += "Relevant fragments as following:\n"
|
||||
for i, chunk in enumerate(cks_meta["chunks"], 1):
|
||||
@ -246,9 +249,11 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
bind_reranker_ts = timer()
|
||||
generate_keyword_ts = bind_reranker_ts
|
||||
thought = ""
|
||||
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
||||
|
||||
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
|
||||
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
||||
knowledges = []
|
||||
else:
|
||||
if prompt_config.get("keyword", False):
|
||||
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
||||
@ -256,28 +261,37 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
|
||||
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight,
|
||||
doc_ids=attachments,
|
||||
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(" ".join(questions), kbs)
|
||||
)
|
||||
if prompt_config.get("use_kg"):
|
||||
ck = settings.kg_retrievaler.retrieval(" ".join(questions),
|
||||
tenant_ids,
|
||||
dialog.kb_ids,
|
||||
embd_mdl,
|
||||
LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
knowledges = []
|
||||
if prompt_config.get("reasoning", False):
|
||||
for think in reasoning(kbinfos, " ".join(questions), chat_mdl, embd_mdl, tenant_ids, dialog.kb_ids, MAX_SEARCH_LIMIT=3):
|
||||
if isinstance(think, str):
|
||||
thought = think
|
||||
knowledges = [t for t in think.split("\n") if t]
|
||||
else:
|
||||
yield think
|
||||
else:
|
||||
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight,
|
||||
doc_ids=attachments,
|
||||
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(" ".join(questions), kbs)
|
||||
)
|
||||
if prompt_config.get("use_kg"):
|
||||
ck = settings.kg_retrievaler.retrieval(" ".join(questions),
|
||||
tenant_ids,
|
||||
dialog.kb_ids,
|
||||
embd_mdl,
|
||||
LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
|
||||
retrieval_ts = timer()
|
||||
knowledges = kb_prompt(kbinfos, max_tokens)
|
||||
|
||||
knowledges = kb_prompt(kbinfos, max_tokens)
|
||||
logging.debug(
|
||||
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||
|
||||
retrieval_ts = timer()
|
||||
if not knowledges and prompt_config.get("empty_response"):
|
||||
empty_res = prompt_config["empty_response"]
|
||||
yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
|
||||
@ -302,9 +316,12 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
def decorate_answer(answer):
|
||||
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts
|
||||
|
||||
finish_chat_ts = timer()
|
||||
|
||||
refs = []
|
||||
ans = answer.split("</think>")
|
||||
think = ""
|
||||
if len(ans) == 2:
|
||||
think = ans[0] + "</think>"
|
||||
answer = ans[1]
|
||||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||
answer, idx = retriever.insert_citations(answer,
|
||||
[ck["content_ltks"]
|
||||
@ -342,22 +359,24 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
|
||||
|
||||
prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
|
||||
return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt)}
|
||||
return {"answer": think+answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
|
||||
|
||||
if stream:
|
||||
last_ans = ""
|
||||
answer = ""
|
||||
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
|
||||
if thought:
|
||||
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
continue
|
||||
last_ans = answer
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
delta_ans = answer[len(last_ans):]
|
||||
if delta_ans:
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
yield decorate_answer(answer)
|
||||
yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
yield decorate_answer(thought+answer)
|
||||
else:
|
||||
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
@ -798,3 +817,191 @@ Output:
|
||||
except Exception as e:
|
||||
logging.exception(f"JSON parsing error: {result} -> {e}")
|
||||
raise e
|
||||
|
||||
|
||||
def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LLMBundle,
|
||||
tenant_ids: list[str], kb_ids: list[str], MAX_SEARCH_LIMIT: int = 3,
|
||||
top_n: int = 5, similarity_threshold: float = 0.4, vector_similarity_weight: float = 0.3):
|
||||
BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
|
||||
END_SEARCH_QUERY = "<|end_search_query|>"
|
||||
BEGIN_SEARCH_RESULT = "<|begin_search_result|>"
|
||||
END_SEARCH_RESULT = "<|end_search_result|>"
|
||||
|
||||
def rm_query_tags(line):
|
||||
pattern = re.escape(BEGIN_SEARCH_QUERY) + r"(.*?)" + re.escape(END_SEARCH_QUERY)
|
||||
return re.sub(pattern, "", line)
|
||||
|
||||
def rm_result_tags(line):
|
||||
pattern = re.escape(BEGIN_SEARCH_RESULT) + r"(.*?)" + re.escape(END_SEARCH_RESULT)
|
||||
return re.sub(pattern, "", line)
|
||||
|
||||
reason_prompt = (
|
||||
"You are a reasoning assistant with the ability to perform dataset searches to help "
|
||||
"you answer the user's question accurately. You have special tools:\n\n"
|
||||
f"- To perform a search: write {BEGIN_SEARCH_QUERY} your query here {END_SEARCH_QUERY}.\n"
|
||||
f"Then, the system will search and analyze relevant content, then provide you with helpful information in the format {BEGIN_SEARCH_RESULT} ...search results... {END_SEARCH_RESULT}.\n\n"
|
||||
f"You can repeat the search process multiple times if necessary. The maximum number of search attempts is limited to {MAX_SEARCH_LIMIT}.\n\n"
|
||||
"Once you have all the information you need, continue your reasoning.\n\n"
|
||||
"-- Example --\n"
|
||||
"Question: \"Find the minimum number of vertices in a Steiner tree that includes all specified vertices in a given tree.\"\n"
|
||||
"Assistant thinking steps:\n"
|
||||
"- I need to understand what a Steiner tree is and how to compute the minimum number of vertices required to include all specified vertices in a given tree.\n\n"
|
||||
"Assistant:\n"
|
||||
f"{BEGIN_SEARCH_QUERY}Minimum Steiner Tree problem in trees{END_SEARCH_QUERY}\n\n"
|
||||
"(System returns processed information from relevant web pages)\n\n"
|
||||
"Assistant continues reasoning with the new information...\n\n"
|
||||
"**Remember**:\n"
|
||||
f"- You have a dataset to search, so you just provide a proper search query.\n"
|
||||
f"- Use {BEGIN_SEARCH_QUERY} to request a dataset search and end with {END_SEARCH_QUERY}.\n"
|
||||
"- The language of query MUST be as the same as 'Question' or 'search result'.\n"
|
||||
"- When done searching, continue your reasoning.\n\n"
|
||||
'Please answer the following question. You should think step by step to solve it.\n\n'
|
||||
)
|
||||
|
||||
relevant_extraction_prompt = """**Task Instruction:**
|
||||
|
||||
You are tasked with reading and analyzing web pages based on the following inputs: **Previous Reasoning Steps**, **Current Search Query**, and **Searched Web Pages**. Your objective is to extract relevant and helpful information for **Current Search Query** from the **Searched Web Pages** and seamlessly integrate this information into the **Previous Reasoning Steps** to continue reasoning for the original question.
|
||||
|
||||
**Guidelines:**
|
||||
|
||||
1. **Analyze the Searched Web Pages:**
|
||||
- Carefully review the content of each searched web page.
|
||||
- Identify factual information that is relevant to the **Current Search Query** and can aid in the reasoning process for the original question.
|
||||
|
||||
2. **Extract Relevant Information:**
|
||||
- Select the information from the Searched Web Pages that directly contributes to advancing the **Previous Reasoning Steps**.
|
||||
- Ensure that the extracted information is accurate and relevant.
|
||||
|
||||
3. **Output Format:**
|
||||
- **If the web pages provide helpful information for current search query:** Present the information beginning with `**Final Information**` as shown below.
|
||||
- The language of query **MUST BE** as the same as 'Search Query' or 'Web Pages'.\n"
|
||||
**Final Information**
|
||||
|
||||
[Helpful information]
|
||||
|
||||
- **If the web pages do not provide any helpful information for current search query:** Output the following text.
|
||||
|
||||
**Final Information**
|
||||
|
||||
No helpful information found.
|
||||
|
||||
**Inputs:**
|
||||
- **Previous Reasoning Steps:**
|
||||
{prev_reasoning}
|
||||
|
||||
- **Current Search Query:**
|
||||
{search_query}
|
||||
|
||||
- **Searched Web Pages:**
|
||||
{document}
|
||||
|
||||
"""
|
||||
|
||||
executed_search_queries = []
|
||||
msg_hisotry = [{"role": "user", "content": f'Question:\n{question}\n\n'}]
|
||||
all_reasoning_steps = []
|
||||
think = "<think>"
|
||||
for ii in range(MAX_SEARCH_LIMIT + 1):
|
||||
if ii == MAX_SEARCH_LIMIT - 1:
|
||||
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. You are not allowed to search.\n{END_SEARCH_RESULT}\n"
|
||||
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
|
||||
all_reasoning_steps.append(summary_think)
|
||||
msg_hisotry.append({"role": "assistant", "content": summary_think})
|
||||
break
|
||||
|
||||
query_think = ""
|
||||
if msg_hisotry[-1]["role"] != "user":
|
||||
msg_hisotry.append({"role": "user", "content": "Continues reasoning with the new information...\n"})
|
||||
for ans in chat_mdl.chat_streamly(reason_prompt, msg_hisotry, {"temperature": 0.7}):
|
||||
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
|
||||
if not ans:
|
||||
continue
|
||||
query_think = ans
|
||||
yield {"answer": think + rm_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None}
|
||||
|
||||
think += rm_query_tags(query_think)
|
||||
all_reasoning_steps.append(query_think)
|
||||
msg_hisotry.append({"role": "assistant", "content": query_think})
|
||||
search_query = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
|
||||
if not search_query:
|
||||
if ii > 0:
|
||||
break
|
||||
search_query = question
|
||||
txt = f"\n{BEGIN_SEARCH_QUERY}{question}{END_SEARCH_QUERY}\n\n"
|
||||
think += txt
|
||||
msg_hisotry[-1]["content"] += txt
|
||||
|
||||
logging.info(f"[THINK]Query: {ii}. {search_query}")
|
||||
think += f"\n\n> {ii+1}. {search_query}\n\n"
|
||||
yield {"answer": think + "</think>", "reference": {}, "audio_binary": None}
|
||||
|
||||
summary_think = ""
|
||||
# The search query has been searched in previous steps.
|
||||
if search_query in executed_search_queries:
|
||||
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n"
|
||||
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
|
||||
all_reasoning_steps.append(summary_think)
|
||||
msg_hisotry.append({"role": "assistant", "content": summary_think})
|
||||
think += summary_think
|
||||
continue
|
||||
|
||||
truncated_prev_reasoning = ""
|
||||
for i, step in enumerate(all_reasoning_steps):
|
||||
truncated_prev_reasoning += f"Step {i + 1}: {step}\n\n"
|
||||
|
||||
prev_steps = truncated_prev_reasoning.split('\n\n')
|
||||
if len(prev_steps) <= 5:
|
||||
truncated_prev_reasoning = '\n\n'.join(prev_steps)
|
||||
else:
|
||||
truncated_prev_reasoning = ''
|
||||
for i, step in enumerate(prev_steps):
|
||||
if i == 0 or i >= len(prev_steps) - 4 or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step:
|
||||
truncated_prev_reasoning += step + '\n\n'
|
||||
else:
|
||||
if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
|
||||
truncated_prev_reasoning += '...\n\n'
|
||||
truncated_prev_reasoning = truncated_prev_reasoning.strip('\n')
|
||||
|
||||
kbinfos = settings.retrievaler.retrieval(search_query, embd_mdl, tenant_ids, kb_ids, 1, top_n,
|
||||
similarity_threshold,
|
||||
vector_similarity_weight
|
||||
)
|
||||
# Merge chunk info for citations
|
||||
if not chunk_info["chunks"]:
|
||||
for k in chunk_info.keys():
|
||||
chunk_info[k] = kbinfos[k]
|
||||
else:
|
||||
cids = [c["chunk_id"] for c in chunk_info["chunks"]]
|
||||
for c in kbinfos["chunks"]:
|
||||
if c["chunk_id"] in cids:
|
||||
continue
|
||||
chunk_info["chunks"].append(c)
|
||||
dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
|
||||
for d in kbinfos["doc_aggs"]:
|
||||
if d["doc_id"] in dids:
|
||||
continue
|
||||
chunk_info["doc_aggs"].append(d)
|
||||
|
||||
think += "\n\n"
|
||||
for ans in chat_mdl.chat_streamly(
|
||||
relevant_extraction_prompt.format(
|
||||
prev_reasoning=truncated_prev_reasoning,
|
||||
search_query=search_query,
|
||||
document="\n".join(kb_prompt(kbinfos, 512))
|
||||
),
|
||||
[{"role": "user",
|
||||
"content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
|
||||
{"temperature": 0.7}):
|
||||
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
|
||||
if not ans:
|
||||
continue
|
||||
summary_think = ans
|
||||
yield {"answer": think + rm_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
|
||||
|
||||
all_reasoning_steps.append(summary_think)
|
||||
msg_hisotry.append(
|
||||
{"role": "assistant", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"})
|
||||
think += rm_result_tags(summary_think)
|
||||
logging.info(f"[THINK]Summary: {ii}. {summary_think}")
|
||||
|
||||
yield think + "</think>"
|
||||
|
@ -17,6 +17,7 @@
|
||||
import logging
|
||||
import random
|
||||
from collections import Counter
|
||||
from typing import Optional
|
||||
|
||||
from rag.utils import num_tokens_from_string
|
||||
from . import rag_tokenizer
|
||||
@ -601,3 +602,11 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
|
||||
add_chunk(sec, image, '')
|
||||
|
||||
return cks, images
|
||||
|
||||
|
||||
def extract_between(text: str, start_tag: str, end_tag: str) -> Optional[str]:
|
||||
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
|
||||
matches = re.findall(pattern, text, flags=re.DOTALL)
|
||||
if matches:
|
||||
return matches[-1].strip()
|
||||
return None
|
@ -15,7 +15,6 @@
|
||||
#
|
||||
import logging
|
||||
import re
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
from rag.settings import TAG_FLD, PAGERANK_FLD
|
||||
@ -259,7 +258,7 @@ class Dealer:
|
||||
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD]))
|
||||
for i in search_res.ids:
|
||||
nor, denor = 0, 0
|
||||
for t, sc in json.loads(search_res.field[i].get(TAG_FLD, "{}")).items():
|
||||
for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items():
|
||||
if t in query_rfea:
|
||||
nor += query_rfea[t] * sc
|
||||
denor += sc * sc
|
||||
|
Loading…
x
Reference in New Issue
Block a user