diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 937626764..33d4f3695 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -876,12 +876,16 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LL
"Once you have all the information you need, continue your reasoning.\n\n"
"-- Example --\n"
"Question: \"Find the minimum number of vertices in a Steiner tree that includes all specified vertices in a given tree.\"\n"
- "Assistant thinking steps:\n"
- "- I need to understand what a Steiner tree is and how to compute the minimum number of vertices required to include all specified vertices in a given tree.\n\n"
"Assistant:\n"
- f"{BEGIN_SEARCH_QUERY}Minimum Steiner Tree problem in trees{END_SEARCH_QUERY}\n\n"
- "(System returns processed information from relevant web pages)\n\n"
- "Assistant continues reasoning with the new information...\n\n"
+ " - I need to understand what a Steiner tree is.\n\n"
+ f" {BEGIN_SEARCH_QUERY}What's a Steiner tree{END_SEARCH_QUERY}\n\n"
+ f" {BEGIN_SEARCH_RESULT}\n(System returns processed information from relevant web pages)\n{END_SEARCH_RESULT}\n\n"
+ "User:\nContinue reasoning with the new information.\n\n"
+ "Assistant:\n"
+ " - I need to understand the difference between the minimum number of vertices and edges in a Steiner tree.\n\n"
+ f" {BEGIN_SEARCH_QUERY}What's the difference between the minimum number of vertices and edges in a Steiner tree{END_SEARCH_QUERY}\n\n"
+ f" {BEGIN_SEARCH_RESULT}\n(System returns processed information from relevant web pages)\n{END_SEARCH_RESULT}\n\n"
+ "User:\nContinue reasoning with the new information...\n\n"
"**Remember**:\n"
f"- You have a dataset to search, so you just provide a proper search query.\n"
f"- Use {BEGIN_SEARCH_QUERY} to request a dataset search and end with {END_SEARCH_QUERY}.\n"
@@ -943,7 +947,7 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LL
query_think = ""
if msg_hisotry[-1]["role"] != "user":
- msg_hisotry.append({"role": "user", "content": "Continues reasoning with the new information...\n"})
+ msg_hisotry.append({"role": "user", "content": "Continue reasoning with the new information.\n"})
for ans in chat_mdl.chat_streamly(reason_prompt, msg_hisotry, {"temperature": 0.7}):
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
if not ans:
@@ -954,86 +958,84 @@ def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LL
think += rm_query_tags(query_think)
all_reasoning_steps.append(query_think)
msg_hisotry.append({"role": "assistant", "content": query_think})
- search_query = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
- if not search_query:
+ queries = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
+ if not queries:
if ii > 0:
break
- search_query = question
- txt = f"\n{BEGIN_SEARCH_QUERY}{question}{END_SEARCH_QUERY}\n\n"
- think += txt
- msg_hisotry[-1]["content"] += txt
+ queries = [question]
- logging.info(f"[THINK]Query: {ii}. {search_query}")
- think += f"\n\n> {ii+1}. {search_query}\n\n"
- yield {"answer": think + "</think>", "reference": {}, "audio_binary": None}
+ for search_query in queries:
+ logging.info(f"[THINK]Query: {ii}. {search_query}")
+ think += f"\n\n> {ii+1}. {search_query}\n\n"
+ yield {"answer": think + "</think>", "reference": {}, "audio_binary": None}
- summary_think = ""
- # The search query has been searched in previous steps.
- if search_query in executed_search_queries:
- summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n"
- yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
- all_reasoning_steps.append(summary_think)
- msg_hisotry.append({"role": "assistant", "content": summary_think})
- think += summary_think
- continue
-
- truncated_prev_reasoning = ""
- for i, step in enumerate(all_reasoning_steps):
- truncated_prev_reasoning += f"Step {i + 1}: {step}\n\n"
-
- prev_steps = truncated_prev_reasoning.split('\n\n')
- if len(prev_steps) <= 5:
- truncated_prev_reasoning = '\n\n'.join(prev_steps)
- else:
- truncated_prev_reasoning = ''
- for i, step in enumerate(prev_steps):
- if i == 0 or i >= len(prev_steps) - 4 or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step:
- truncated_prev_reasoning += step + '\n\n'
- else:
- if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
- truncated_prev_reasoning += '...\n\n'
- truncated_prev_reasoning = truncated_prev_reasoning.strip('\n')
-
- kbinfos = settings.retrievaler.retrieval(search_query, embd_mdl, tenant_ids, kb_ids, 1, top_n,
- similarity_threshold,
- vector_similarity_weight
- )
- # Merge chunk info for citations
- if not chunk_info["chunks"]:
- for k in chunk_info.keys():
- chunk_info[k] = kbinfos[k]
- else:
- cids = [c["chunk_id"] for c in chunk_info["chunks"]]
- for c in kbinfos["chunks"]:
- if c["chunk_id"] in cids:
- continue
- chunk_info["chunks"].append(c)
- dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
- for d in kbinfos["doc_aggs"]:
- if d["doc_id"] in dids:
- continue
- chunk_info["doc_aggs"].append(d)
-
- think += "\n\n"
- for ans in chat_mdl.chat_streamly(
- relevant_extraction_prompt.format(
- prev_reasoning=truncated_prev_reasoning,
- search_query=search_query,
- document="\n".join(kb_prompt(kbinfos, 512))
- ),
- [{"role": "user",
- "content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
- {"temperature": 0.7}):
- ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
- if not ans:
+ summary_think = ""
+ # The search query has been searched in previous steps.
+ if search_query in executed_search_queries:
+ summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n"
+ yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
+ all_reasoning_steps.append(summary_think)
+ msg_hisotry.append({"role": "assistant", "content": summary_think})
+ think += summary_think
continue
- summary_think = ans
- yield {"answer": think + rm_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
- all_reasoning_steps.append(summary_think)
- msg_hisotry.append(
- {"role": "assistant", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"})
- think += rm_result_tags(summary_think)
- logging.info(f"[THINK]Summary: {ii}. {summary_think}")
+ truncated_prev_reasoning = ""
+ for i, step in enumerate(all_reasoning_steps):
+ truncated_prev_reasoning += f"Step {i + 1}: {step}\n\n"
+
+ prev_steps = truncated_prev_reasoning.split('\n\n')
+ if len(prev_steps) <= 5:
+ truncated_prev_reasoning = '\n\n'.join(prev_steps)
+ else:
+ truncated_prev_reasoning = ''
+ for i, step in enumerate(prev_steps):
+ if i == 0 or i >= len(prev_steps) - 4 or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step:
+ truncated_prev_reasoning += step + '\n\n'
+ else:
+ if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
+ truncated_prev_reasoning += '...\n\n'
+ truncated_prev_reasoning = truncated_prev_reasoning.strip('\n')
+
+ kbinfos = settings.retrievaler.retrieval(search_query, embd_mdl, tenant_ids, kb_ids, 1, top_n,
+ similarity_threshold,
+ vector_similarity_weight
+ )
+ # Merge chunk info for citations
+ if not chunk_info["chunks"]:
+ for k in chunk_info.keys():
+ chunk_info[k] = kbinfos[k]
+ else:
+ cids = [c["chunk_id"] for c in chunk_info["chunks"]]
+ for c in kbinfos["chunks"]:
+ if c["chunk_id"] in cids:
+ continue
+ chunk_info["chunks"].append(c)
+ dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
+ for d in kbinfos["doc_aggs"]:
+ if d["doc_id"] in dids:
+ continue
+ chunk_info["doc_aggs"].append(d)
+
+ think += "\n\n"
+ for ans in chat_mdl.chat_streamly(
+ relevant_extraction_prompt.format(
+ prev_reasoning=truncated_prev_reasoning,
+ search_query=search_query,
+ document="\n".join(kb_prompt(kbinfos, 512))
+ ),
+ [{"role": "user",
+ "content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
+ {"temperature": 0.7}):
+ ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
+ if not ans:
+ continue
+ summary_think = ans
+ yield {"answer": think + rm_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
+
+ all_reasoning_steps.append(summary_think)
+ msg_hisotry.append(
+ {"role": "assistant", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"})
+ think += rm_result_tags(summary_think)
+ logging.info(f"[THINK]Summary: {ii}. {summary_think}")
yield think + "</think>"
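
The heart of this change: extract_between() now hands back every query the model
emitted in a round, and the retrieval/summarization body runs once per query
instead of once per round, with repeated queries short-circuited. A minimal,
self-contained sketch of that control flow (illustrative only, not part of the
patch: run_round() is a hypothetical helper, retrieve() stands in for
settings.retrievaler.retrieval(), and the tag values are assumptions):

    import re

    # Assumed tag values, for the sake of a runnable example.
    BEGIN_SEARCH_QUERY, END_SEARCH_QUERY = "<|begin_search_query|>", "<|end_search_query|>"
    BEGIN_SEARCH_RESULT, END_SEARCH_RESULT = "<|begin_search_result|>", "<|end_search_result|>"

    def extract_between(text: str, start_tag: str, end_tag: str) -> list[str]:
        pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
        return re.findall(pattern, text, flags=re.DOTALL)

    def run_round(ii, query_think, question, executed, retrieve):
        queries = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
        if not queries:
            if ii > 0:
                return []                # no new queries after round 0: reasoning stops
            queries = [question]         # round 0 fallback: search the question itself
        summaries = []
        for search_query in queries:     # fan out over every emitted query
            if search_query in executed:
                summaries.append(f"{BEGIN_SEARCH_RESULT}\nYou have searched this query. "
                                 f"Please refer to previous results.\n{END_SEARCH_RESULT}")
                continue
            executed.add(search_query)
            summaries.append(retrieve(search_query))
        return summaries

    # Two queries in one round, the second a repeat of the first.
    thinking = (f"{BEGIN_SEARCH_QUERY}What's a Steiner tree{END_SEARCH_QUERY} ... "
                f"{BEGIN_SEARCH_QUERY}What's a Steiner tree{END_SEARCH_QUERY}")
    print(run_round(0, thinking, "fallback question", set(), lambda q: f"chunks for: {q}"))

Before this patch only the last query in a round survived, because
extract_between returned a single string; the earlier ones were silently dropped.
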
diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py
index beabe61eb..c98597bce 100644
--- a/rag/nlp/__init__.py
+++ b/rag/nlp/__init__.py
@@ -17,7 +17,6 @@
import logging
import random
from collections import Counter
-from typing import Optional
from rag.utils import num_tokens_from_string
from . import rag_tokenizer
@@ -604,9 +603,6 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
return cks, images
-def extract_between(text: str, start_tag: str, end_tag: str) -> Optional[str]:
+def extract_between(text: str, start_tag: str, end_tag: str) -> list[str]:
pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
- matches = re.findall(pattern, text, flags=re.DOTALL)
- if matches:
- return matches[-1].strip()
- return None
\ No newline at end of file
+ return re.findall(pattern, text, flags=re.DOTALL)
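
For reference, the contract change in extract_between, demonstrated side by side
(illustrative sketch; the tag strings here are stand-ins for brevity):

    import re

    def extract_between_old(text, start_tag, end_tag):
        # Old contract: only the last match, stripped, or None.
        matches = re.findall(re.escape(start_tag) + r"(.*?)" + re.escape(end_tag),
                             text, flags=re.DOTALL)
        return matches[-1].strip() if matches else None

    def extract_between_new(text, start_tag, end_tag):
        # New contract: every match, in document order, unstripped.
        return re.findall(re.escape(start_tag) + r"(.*?)" + re.escape(end_tag),
                          text, flags=re.DOTALL)

    text = "<q> first </q> then <q> second </q>"
    print(extract_between_old(text, "<q>", "</q>"))   # 'second'
    print(extract_between_new(text, "<q>", "</q>"))   # [' first ', ' second ']

Note that the new version no longer strips whitespace from the captured text, so
callers comparing queries against executed_search_queries now see the raw match.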