From 3444cb15e322561d16437df5ee43581b66dfbc73 Mon Sep 17 00:00:00 2001
From: Kevin Hu
Date: Fri, 21 Feb 2025 18:32:32 +0800
Subject: [PATCH] Refine search query. (#5235)

### What problem does this PR solve?

#5173 #5214

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
---
 api/db/services/dialog_service.py | 164 +++++++++++++++----------------
 rag/nlp/__init__.py               |   8 +-
 2 files changed, 85 insertions(+), 87 deletions(-)

diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 937626764..33d4f3695 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -876,12 +876,16 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LL
         "Once you have all the information you need, continue your reasoning.\n\n"
         "-- Example --\n"
         "Question: \"Find the minimum number of vertices in a Steiner tree that includes all specified vertices in a given tree.\"\n"
-        "Assistant thinking steps:\n"
-        "- I need to understand what a Steiner tree is and how to compute the minimum number of vertices required to include all specified vertices in a given tree.\n\n"
         "Assistant:\n"
-        f"{BEGIN_SEARCH_QUERY}Minimum Steiner Tree problem in trees{END_SEARCH_QUERY}\n\n"
-        "(System returns processed information from relevant web pages)\n\n"
-        "Assistant continues reasoning with the new information...\n\n"
+        "    - I need to understand what a Steiner tree is.\n\n"
+        f"    {BEGIN_SEARCH_QUERY}What's Steiner tree{END_SEARCH_QUERY}\n\n"
+        f"    {BEGIN_SEARCH_RESULT}\n(System returns processed information from relevant web pages)\n{END_SEARCH_RESULT}\n\n"
+        "User:\nContinues reasoning with the new information.\n\n"
+        "Assistant:\n"
+        "    - I need to understand what the difference between minimum number of vertices and edges in the Steiner tree is.\n\n"
+        f"    {BEGIN_SEARCH_QUERY}What's the difference between minimum number of vertices and edges in the Steiner tree{END_SEARCH_QUERY}\n\n"
+        f"    {BEGIN_SEARCH_RESULT}\n(System returns processed information from relevant web pages)\n{END_SEARCH_RESULT}\n\n"
+        "User:\nContinues reasoning with the new information...\n\n"
         "**Remember**:\n"
         f"- You have a dataset to search, so you just provide a proper search query.\n"
         f"- Use {BEGIN_SEARCH_QUERY} to request a dataset search and end with {END_SEARCH_QUERY}.\n"
@@ -943,7 +947,7 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LL
 
         query_think = ""
         if msg_hisotry[-1]["role"] != "user":
-            msg_hisotry.append({"role": "user", "content": "Continues reasoning with the new information...\n"})
+            msg_hisotry.append({"role": "user", "content": "Continues reasoning with the new information.\n"})
         for ans in chat_mdl.chat_streamly(reason_prompt, msg_hisotry, {"temperature": 0.7}):
             ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
             if not ans:
@@ -954,86 +958,84 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LL
         think += rm_query_tags(query_think)
         all_reasoning_steps.append(query_think)
         msg_hisotry.append({"role": "assistant", "content": query_think})
-        search_query = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
-        if not search_query:
+        queries = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
+        if not queries:
             if ii > 0:
                 break
-            search_query = question
-            txt = f"\n{BEGIN_SEARCH_QUERY}{question}{END_SEARCH_QUERY}\n\n"
-            think += txt
-            msg_hisotry[-1]["content"] += txt
+            queries = [question]
 
-        logging.info(f"[THINK]Query: {ii}. {search_query}")
{search_query}") - think += f"\n\n> {ii+1}. {search_query}\n\n" - yield {"answer": think + "", "reference": {}, "audio_binary": None} + for search_query in queries: + logging.info(f"[THINK]Query: {ii}. {search_query}") + think += f"\n\n> {ii+1}. {search_query}\n\n" + yield {"answer": think + "", "reference": {}, "audio_binary": None} - summary_think = "" - # The search query has been searched in previous steps. - if search_query in executed_search_queries: - summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n" - yield {"answer": think + summary_think + "", "reference": {}, "audio_binary": None} - all_reasoning_steps.append(summary_think) - msg_hisotry.append({"role": "assistant", "content": summary_think}) - think += summary_think - continue - - truncated_prev_reasoning = "" - for i, step in enumerate(all_reasoning_steps): - truncated_prev_reasoning += f"Step {i + 1}: {step}\n\n" - - prev_steps = truncated_prev_reasoning.split('\n\n') - if len(prev_steps) <= 5: - truncated_prev_reasoning = '\n\n'.join(prev_steps) - else: - truncated_prev_reasoning = '' - for i, step in enumerate(prev_steps): - if i == 0 or i >= len(prev_steps) - 4 or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step: - truncated_prev_reasoning += step + '\n\n' - else: - if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n': - truncated_prev_reasoning += '...\n\n' - truncated_prev_reasoning = truncated_prev_reasoning.strip('\n') - - kbinfos = settings.retrievaler.retrieval(search_query, embd_mdl, tenant_ids, kb_ids, 1, top_n, - similarity_threshold, - vector_similarity_weight - ) - # Merge chunk info for citations - if not chunk_info["chunks"]: - for k in chunk_info.keys(): - chunk_info[k] = kbinfos[k] - else: - cids = [c["chunk_id"] for c in chunk_info["chunks"]] - for c in kbinfos["chunks"]: - if c["chunk_id"] in cids: - continue - chunk_info["chunks"].append(c) - dids = [d["doc_id"] for d in chunk_info["doc_aggs"]] - for d in kbinfos["doc_aggs"]: - if d["doc_id"] in dids: - continue - chunk_info["doc_aggs"].append(d) - - think += "\n\n" - for ans in chat_mdl.chat_streamly( - relevant_extraction_prompt.format( - prev_reasoning=truncated_prev_reasoning, - search_query=search_query, - document="\n".join(kb_prompt(kbinfos, 512)) - ), - [{"role": "user", - "content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}], - {"temperature": 0.7}): - ans = re.sub(r".*", "", ans, flags=re.DOTALL) - if not ans: + summary_think = "" + # The search query has been searched in previous steps. + if search_query in executed_search_queries: + summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n" + yield {"answer": think + summary_think + "", "reference": {}, "audio_binary": None} + all_reasoning_steps.append(summary_think) + msg_hisotry.append({"role": "assistant", "content": summary_think}) + think += summary_think continue - summary_think = ans - yield {"answer": think + rm_result_tags(summary_think) + "", "reference": {}, "audio_binary": None} - all_reasoning_steps.append(summary_think) - msg_hisotry.append( - {"role": "assistant", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"}) - think += rm_result_tags(summary_think) - logging.info(f"[THINK]Summary: {ii}. 
{summary_think}") + truncated_prev_reasoning = "" + for i, step in enumerate(all_reasoning_steps): + truncated_prev_reasoning += f"Step {i + 1}: {step}\n\n" + + prev_steps = truncated_prev_reasoning.split('\n\n') + if len(prev_steps) <= 5: + truncated_prev_reasoning = '\n\n'.join(prev_steps) + else: + truncated_prev_reasoning = '' + for i, step in enumerate(prev_steps): + if i == 0 or i >= len(prev_steps) - 4 or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step: + truncated_prev_reasoning += step + '\n\n' + else: + if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n': + truncated_prev_reasoning += '...\n\n' + truncated_prev_reasoning = truncated_prev_reasoning.strip('\n') + + kbinfos = settings.retrievaler.retrieval(search_query, embd_mdl, tenant_ids, kb_ids, 1, top_n, + similarity_threshold, + vector_similarity_weight + ) + # Merge chunk info for citations + if not chunk_info["chunks"]: + for k in chunk_info.keys(): + chunk_info[k] = kbinfos[k] + else: + cids = [c["chunk_id"] for c in chunk_info["chunks"]] + for c in kbinfos["chunks"]: + if c["chunk_id"] in cids: + continue + chunk_info["chunks"].append(c) + dids = [d["doc_id"] for d in chunk_info["doc_aggs"]] + for d in kbinfos["doc_aggs"]: + if d["doc_id"] in dids: + continue + chunk_info["doc_aggs"].append(d) + + think += "\n\n" + for ans in chat_mdl.chat_streamly( + relevant_extraction_prompt.format( + prev_reasoning=truncated_prev_reasoning, + search_query=search_query, + document="\n".join(kb_prompt(kbinfos, 512)) + ), + [{"role": "user", + "content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}], + {"temperature": 0.7}): + ans = re.sub(r".*", "", ans, flags=re.DOTALL) + if not ans: + continue + summary_think = ans + yield {"answer": think + rm_result_tags(summary_think) + "", "reference": {}, "audio_binary": None} + + all_reasoning_steps.append(summary_think) + msg_hisotry.append( + {"role": "assistant", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"}) + think += rm_result_tags(summary_think) + logging.info(f"[THINK]Summary: {ii}. {summary_think}") yield think + "" diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py index beabe61eb..c98597bce 100644 --- a/rag/nlp/__init__.py +++ b/rag/nlp/__init__.py @@ -17,7 +17,6 @@ import logging import random from collections import Counter -from typing import Optional from rag.utils import num_tokens_from_string from . import rag_tokenizer @@ -604,9 +603,6 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"): return cks, images -def extract_between(text: str, start_tag: str, end_tag: str) -> Optional[str]: +def extract_between(text: str, start_tag: str, end_tag: str) -> list[str]: pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag) - matches = re.findall(pattern, text, flags=re.DOTALL) - if matches: - return matches[-1].strip() - return None \ No newline at end of file + return re.findall(pattern, text, flags=re.DOTALL)