From 7b3d700d5f92d3fa0549fce4d31390dbc3b0e7a2 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Thu, 20 Feb 2025 17:41:01 +0800 Subject: [PATCH] Apply agentic searching. (#5196) ### What problem does this PR solve? #5173 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/db/services/dialog_service.py | 257 +++++++++++++++++++++++++++--- rag/nlp/__init__.py | 9 ++ rag/nlp/search.py | 3 +- 3 files changed, 242 insertions(+), 27 deletions(-) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 835ed80ca..3b68ca1b4 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -17,6 +17,8 @@ import logging import binascii import os import json +import time + import json_repair import re from collections import defaultdict @@ -33,6 +35,7 @@ from api.db.services.llm_service import TenantLLMService, LLMBundle from api import settings from graphrag.utils import get_tags_from_cache, set_tags_to_cache from rag.app.resume import forbidden_select_fields4resume +from rag.nlp import extract_between from rag.nlp.search import index_name from rag.settings import TAG_FLD from rag.utils import rmSpace, num_tokens_from_string, encoder @@ -135,7 +138,7 @@ def kb_prompt(kbinfos, max_tokens): knowledges = [] for nm, cks_meta in doc2chunks.items(): txt = f"Document: {nm} \n" - for k,v in cks_meta["meta"].items(): + for k, v in cks_meta["meta"].items(): txt += f"{k}: {v}\n" txt += "Relevant fragments as following:\n" for i, chunk in enumerate(cks_meta["chunks"], 1): @@ -246,9 +249,11 @@ def chat(dialog, messages, stream=True, **kwargs): bind_reranker_ts = timer() generate_keyword_ts = bind_reranker_ts + thought = "" + kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]: - kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} + knowledges = [] else: if prompt_config.get("keyword", False): questions[-1] += keyword_extraction(chat_mdl, questions[-1]) @@ -256,28 +261,37 @@ def chat(dialog, messages, stream=True, **kwargs): tenant_ids = list(set([kb.tenant_id for kb in kbs])) - kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, - dialog.similarity_threshold, - dialog.vector_similarity_weight, - doc_ids=attachments, - top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl, - rank_feature=label_question(" ".join(questions), kbs) - ) - if prompt_config.get("use_kg"): - ck = settings.kg_retrievaler.retrieval(" ".join(questions), - tenant_ids, - dialog.kb_ids, - embd_mdl, - LLMBundle(dialog.tenant_id, LLMType.CHAT)) - if ck["content_with_weight"]: - kbinfos["chunks"].insert(0, ck) + knowledges = [] + if prompt_config.get("reasoning", False): + for think in reasoning(kbinfos, " ".join(questions), chat_mdl, embd_mdl, tenant_ids, dialog.kb_ids, MAX_SEARCH_LIMIT=3): + if isinstance(think, str): + thought = think + knowledges = [t for t in think.split("\n") if t] + else: + yield think + else: + kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, + dialog.similarity_threshold, + dialog.vector_similarity_weight, + doc_ids=attachments, + top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl, + rank_feature=label_question(" ".join(questions), kbs) + ) + if prompt_config.get("use_kg"): + ck = settings.kg_retrievaler.retrieval(" ".join(questions), + tenant_ids, + dialog.kb_ids, + embd_mdl, + LLMBundle(dialog.tenant_id, LLMType.CHAT)) + if ck["content_with_weight"]: + 
kbinfos["chunks"].insert(0, ck) - retrieval_ts = timer() + knowledges = kb_prompt(kbinfos, max_tokens) - knowledges = kb_prompt(kbinfos, max_tokens) logging.debug( "{}->{}".format(" ".join(questions), "\n->".join(knowledges))) + retrieval_ts = timer() if not knowledges and prompt_config.get("empty_response"): empty_res = prompt_config["empty_response"] yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)} @@ -302,9 +316,12 @@ def chat(dialog, messages, stream=True, **kwargs): def decorate_answer(answer): nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts - finish_chat_ts = timer() - refs = [] + ans = answer.split("") + think = "" + if len(ans) == 2: + think = ans[0] + "" + answer = ans[1] if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] @@ -342,22 +359,24 @@ def chat(dialog, messages, stream=True, **kwargs): generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000 prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms" - return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt)} + return {"answer": think+answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()} if stream: last_ans = "" answer = "" for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf): + if thought: + ans = re.sub(r".*", "", ans, flags=re.DOTALL) answer = ans delta_ans = ans[len(last_ans):] if num_tokens_from_string(delta_ans) < 16: continue last_ans = answer - yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} + yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} delta_ans = answer[len(last_ans):] if delta_ans: - yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} - yield decorate_answer(answer) + yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} + yield decorate_answer(thought+answer) else: answer = chat_mdl.chat(prompt, msg[1:], gen_conf) user_content = msg[-1].get("content", "[content not available]") @@ -798,3 +817,191 @@ Output: except Exception as e: logging.exception(f"JSON parsing error: {result} -> {e}") raise e + + +def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LLMBundle, + tenant_ids: list[str], kb_ids: list[str], MAX_SEARCH_LIMIT: int = 3, + top_n: int = 5, similarity_threshold: float = 0.4, vector_similarity_weight: float = 0.3): + BEGIN_SEARCH_QUERY = "<|begin_search_query|>" + END_SEARCH_QUERY = "<|end_search_query|>" + BEGIN_SEARCH_RESULT = "<|begin_search_result|>" + END_SEARCH_RESULT = "<|end_search_result|>" + + def rm_query_tags(line): + pattern = re.escape(BEGIN_SEARCH_QUERY) + r"(.*?)" + re.escape(END_SEARCH_QUERY) + return re.sub(pattern, "", line) + + def rm_result_tags(line): + pattern = re.escape(BEGIN_SEARCH_RESULT) + r"(.*?)" + re.escape(END_SEARCH_RESULT) + return re.sub(pattern, "", line) + + reason_prompt = ( + "You are a reasoning 
assistant with the ability to perform dataset searches to help "
+        "you answer the user's question accurately. You have special tools:\n\n"
+        f"- To perform a search: write {BEGIN_SEARCH_QUERY} your query here {END_SEARCH_QUERY}.\n"
+        f"Then, the system will search and analyze relevant content, then provide you with helpful information in the format {BEGIN_SEARCH_RESULT} ...search results... {END_SEARCH_RESULT}.\n\n"
+        f"You can repeat the search process multiple times if necessary. The maximum number of search attempts is limited to {MAX_SEARCH_LIMIT}.\n\n"
+        "Once you have all the information you need, continue your reasoning.\n\n"
+        "-- Example --\n"
+        "Question: \"Find the minimum number of vertices in a Steiner tree that includes all specified vertices in a given tree.\"\n"
+        "Assistant thinking steps:\n"
+        "- I need to understand what a Steiner tree is and how to compute the minimum number of vertices required to include all specified vertices in a given tree.\n\n"
+        "Assistant:\n"
+        f"{BEGIN_SEARCH_QUERY}Minimum Steiner Tree problem in trees{END_SEARCH_QUERY}\n\n"
+        "(System returns processed information from relevant web pages)\n\n"
+        "Assistant continues reasoning with the new information...\n\n"
+        "**Remember**:\n"
+        f"- You have a dataset to search, so you just need to provide a proper search query.\n"
+        f"- Use {BEGIN_SEARCH_QUERY} to request a dataset search and end with {END_SEARCH_QUERY}.\n"
+        "- The language of the query MUST be the same as that of the 'Question' or the 'search result'.\n"
+        "- When done searching, continue your reasoning.\n\n"
+        'Please answer the following question. You should think step by step to solve it.\n\n'
+    )
+
+    relevant_extraction_prompt = """**Task Instruction:**
+
+    You are tasked with reading and analyzing web pages based on the following inputs: **Previous Reasoning Steps**, **Current Search Query**, and **Searched Web Pages**. Your objective is to extract relevant and helpful information for **Current Search Query** from the **Searched Web Pages** and seamlessly integrate this information into the **Previous Reasoning Steps** to continue reasoning for the original question.
+
+    **Guidelines:**
+
+    1. **Analyze the Searched Web Pages:**
+    - Carefully review the content of each searched web page.
+    - Identify factual information that is relevant to the **Current Search Query** and can aid in the reasoning process for the original question.
+
+    2. **Extract Relevant Information:**
+    - Select the information from the Searched Web Pages that directly contributes to advancing the **Previous Reasoning Steps**.
+    - Ensure that the extracted information is accurate and relevant.
+
+    3. **Output Format:**
+    - **If the web pages provide helpful information for the current search query:** Present the information beginning with `**Final Information**` as shown below. The language of the output MUST be the same as that of the 'Search Query' or the 'Web Pages'.
+    **Final Information**
+
+    [Helpful information]
+
+    - **If the web pages do not provide any helpful information for the current search query:** Output the following text.
+
+    **Final Information**
+
+    No helpful information found. 
+
+    **Inputs:**
+    - **Previous Reasoning Steps:**
+    {prev_reasoning}
+
+    - **Current Search Query:**
+    {search_query}
+
+    - **Searched Web Pages:**
+    {document}
+
+    """
+
+    executed_search_queries = []
+    msg_history = [{"role": "user", "content": f'Question:\n{question}\n\n'}]
+    all_reasoning_steps = []
+    think = "<think>"
+    for ii in range(MAX_SEARCH_LIMIT + 1):
+        if ii == MAX_SEARCH_LIMIT - 1:
+            summary_think = f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. You are not allowed to search.\n{END_SEARCH_RESULT}\n"
+            yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
+            all_reasoning_steps.append(summary_think)
+            msg_history.append({"role": "assistant", "content": summary_think})
+            break
+
+        query_think = ""
+        if msg_history[-1]["role"] != "user":
+            msg_history.append({"role": "user", "content": "Continue reasoning with the new information...\n"})
+        for ans in chat_mdl.chat_streamly(reason_prompt, msg_history, {"temperature": 0.7}):
+            ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
+            if not ans:
+                continue
+            query_think = ans
+            yield {"answer": think + rm_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None}
+
+        think += rm_query_tags(query_think)
+        all_reasoning_steps.append(query_think)
+        msg_history.append({"role": "assistant", "content": query_think})
+        search_query = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
+        if not search_query:
+            if ii > 0:
+                break
+            search_query = question
+            txt = f"\n{BEGIN_SEARCH_QUERY}{question}{END_SEARCH_QUERY}\n\n"
+            think += txt
+            msg_history[-1]["content"] += txt
+
+        logging.info(f"[THINK]Query: {ii}. {search_query}")
+        think += f"\n\n> {ii+1}. {search_query}\n\n"
+        yield {"answer": think + "</think>", "reference": {}, "audio_binary": None}
+
+        summary_think = ""
+        # The search query has been searched in previous steps.
+        if search_query in executed_search_queries:
+            summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. 
Please refer to previous results.\n{END_SEARCH_RESULT}\n"
+            yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
+            all_reasoning_steps.append(summary_think)
+            msg_history.append({"role": "assistant", "content": summary_think})
+            think += summary_think
+            continue
+
+        # Remember the query so repeats are caught by the check above.
+        executed_search_queries.append(search_query)
+
+        truncated_prev_reasoning = ""
+        for i, step in enumerate(all_reasoning_steps):
+            truncated_prev_reasoning += f"Step {i + 1}: {step}\n\n"
+
+        prev_steps = truncated_prev_reasoning.split('\n\n')
+        if len(prev_steps) <= 5:
+            truncated_prev_reasoning = '\n\n'.join(prev_steps)
+        else:
+            truncated_prev_reasoning = ''
+            for i, step in enumerate(prev_steps):
+                if i == 0 or i >= len(prev_steps) - 4 or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step:
+                    truncated_prev_reasoning += step + '\n\n'
+                else:
+                    if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
+                        truncated_prev_reasoning += '...\n\n'
+        truncated_prev_reasoning = truncated_prev_reasoning.strip('\n')
+
+        kbinfos = settings.retrievaler.retrieval(search_query, embd_mdl, tenant_ids, kb_ids, 1, top_n,
+                                                 similarity_threshold,
+                                                 vector_similarity_weight
+                                                 )
+        # Merge chunk info for citations
+        if not chunk_info["chunks"]:
+            for k in chunk_info.keys():
+                chunk_info[k] = kbinfos[k]
+        else:
+            cids = [c["chunk_id"] for c in chunk_info["chunks"]]
+            for c in kbinfos["chunks"]:
+                if c["chunk_id"] in cids:
+                    continue
+                chunk_info["chunks"].append(c)
+            dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
+            for d in kbinfos["doc_aggs"]:
+                if d["doc_id"] in dids:
+                    continue
+                chunk_info["doc_aggs"].append(d)
+
+        think += "\n\n"
+        for ans in chat_mdl.chat_streamly(
+                relevant_extraction_prompt.format(
+                    prev_reasoning=truncated_prev_reasoning,
+                    search_query=search_query,
+                    document="\n".join(kb_prompt(kbinfos, 512))
+                ),
+                [{"role": "user",
+                  "content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
+                {"temperature": 0.7}):
+            ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
+            if not ans:
+                continue
+            summary_think = ans
+            yield {"answer": think + rm_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
+
+        all_reasoning_steps.append(summary_think)
+        msg_history.append(
+            {"role": "assistant", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"})
+        think += rm_result_tags(summary_think)
+        logging.info(f"[THINK]Summary: {ii}. {summary_think}")
+
+    yield think + "</think>"
diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py
index 5e288a9b1..beabe61eb 100644
--- a/rag/nlp/__init__.py
+++ b/rag/nlp/__init__.py
@@ -17,6 +17,7 @@
 import logging
 import random
 from collections import Counter
+from typing import Optional
 
 from rag.utils import num_tokens_from_string
 from . 
import rag_tokenizer @@ -601,3 +602,11 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"): add_chunk(sec, image, '') return cks, images + + +def extract_between(text: str, start_tag: str, end_tag: str) -> Optional[str]: + pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag) + matches = re.findall(pattern, text, flags=re.DOTALL) + if matches: + return matches[-1].strip() + return None \ No newline at end of file diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 86b02d78d..86416cd49 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -15,7 +15,6 @@ # import logging import re -import json from dataclasses import dataclass from rag.settings import TAG_FLD, PAGERANK_FLD @@ -259,7 +258,7 @@ class Dealer: q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD])) for i in search_res.ids: nor, denor = 0, 0 - for t, sc in json.loads(search_res.field[i].get(TAG_FLD, "{}")).items(): + for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items(): if t in query_rfea: nor += query_rfea[t] * sc denor += sc * sc
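
### Usage sketches

The `extract_between` helper added to `rag/nlp/__init__.py` returns the contents of the *last* complete tag pair, which is what lets `reasoning` act on the model's most recent search request and treat an unterminated tag as "no request". A minimal sketch of that behavior, reusing the helper verbatim together with the tag constants defined in `reasoning` (the sample strings are illustrative):

```python
import re
from typing import Optional


def extract_between(text: str, start_tag: str, end_tag: str) -> Optional[str]:
    # Verbatim copy of the helper added in rag/nlp/__init__.py: return the
    # contents of the LAST complete tag pair, or None when no pair is closed.
    pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
    matches = re.findall(pattern, text, flags=re.DOTALL)
    if matches:
        return matches[-1].strip()
    return None


BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
END_SEARCH_QUERY = "<|end_search_query|>"

out = ("I need more background first.\n"
       f"{BEGIN_SEARCH_QUERY} Minimum Steiner Tree problem in trees {END_SEARCH_QUERY}")
assert extract_between(out, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY) == "Minimum Steiner Tree problem in trees"

# An unterminated tag yields None; on the first round reasoning() falls back
# to the original question, on later rounds it stops searching.
assert extract_between(BEGIN_SEARCH_QUERY + " dangling", BEGIN_SEARCH_QUERY, END_SEARCH_QUERY) is None
```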
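
The loop in `reasoning` interleaves three steps per round: let the model think until it emits a `<|begin_search_query|>...<|end_search_query|>` span, retrieve against that query, then feed a distilled summary back wrapped in `<|begin_search_result|>...<|end_search_result|>`. Below is a schematic of that control flow with the LLM and retriever stubbed out; `agentic_loop`, `llm`, and `search` are illustrative names, not part of the patch, and the real code additionally streams partial `<think>` text and enforces the duplicate-query and search-limit guards shown in the diff. It reuses `extract_between` from the sketch above:

```python
from typing import Callable

BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
END_SEARCH_QUERY = "<|end_search_query|>"
BEGIN_SEARCH_RESULT = "<|begin_search_result|>"
END_SEARCH_RESULT = "<|end_search_result|>"


def agentic_loop(question: str,
                 llm: Callable[[list[dict]], str],
                 search: Callable[[str], str],
                 max_rounds: int = 3) -> list[dict]:
    history = [{"role": "user", "content": f"Question:\n{question}\n\n"}]
    for _ in range(max_rounds):
        draft = llm(history)  # the model reasons and may embed a tagged query
        history.append({"role": "assistant", "content": draft})
        query = extract_between(draft, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
        if not query:  # no tagged query: the model considers itself done
            break
        # Retrieve, distill, and hand the result back inside result tags.
        summary = llm([{"role": "user",
                        "content": f"Extract what helps answer '{query}':\n{search(query)}"}])
        history.append({"role": "user",
                        "content": f"{BEGIN_SEARCH_RESULT}{summary}{END_SEARCH_RESULT}"})
    return history
```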
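
Because every round retrieves independently, `reasoning` folds each round's `kbinfos` into the caller's `chunk_info` dict so that `decorate_answer` can still resolve citations against one combined chunk list. A condensed, behavior-equivalent sketch of that in-loop merge (`merge_chunk_info` is an illustrative name; the patch inlines this logic):

```python
def merge_chunk_info(chunk_info: dict, kbinfos: dict) -> None:
    # First round: adopt the retrieval result wholesale.
    if not chunk_info["chunks"]:
        for k in chunk_info.keys():
            chunk_info[k] = kbinfos[k]
        return
    # Later rounds: append only chunks and doc aggregates not seen before,
    # keyed on chunk_id / doc_id, so earlier citation indices stay stable.
    cids = {c["chunk_id"] for c in chunk_info["chunks"]}
    chunk_info["chunks"].extend(c for c in kbinfos["chunks"] if c["chunk_id"] not in cids)
    dids = {d["doc_id"] for d in chunk_info["doc_aggs"]}
    chunk_info["doc_aggs"].extend(d for d in kbinfos["doc_aggs"] if d["doc_id"] not in dids)
```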