diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 835ed80ca..3b68ca1b4 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -17,6 +17,8 @@ import logging
import binascii
import os
import json
+import time
+
import json_repair
import re
from collections import defaultdict
@@ -33,6 +35,7 @@ from api.db.services.llm_service import TenantLLMService, LLMBundle
from api import settings
from graphrag.utils import get_tags_from_cache, set_tags_to_cache
from rag.app.resume import forbidden_select_fields4resume
+from rag.nlp import extract_between
from rag.nlp.search import index_name
from rag.settings import TAG_FLD
from rag.utils import rmSpace, num_tokens_from_string, encoder
@@ -135,7 +138,7 @@ def kb_prompt(kbinfos, max_tokens):
knowledges = []
for nm, cks_meta in doc2chunks.items():
txt = f"Document: {nm} \n"
- for k,v in cks_meta["meta"].items():
+ for k, v in cks_meta["meta"].items():
txt += f"{k}: {v}\n"
txt += "Relevant fragments as following:\n"
for i, chunk in enumerate(cks_meta["chunks"], 1):
@@ -246,9 +249,11 @@ def chat(dialog, messages, stream=True, **kwargs):
bind_reranker_ts = timer()
generate_keyword_ts = bind_reranker_ts
+ thought = ""
+ kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
- kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
+ knowledges = []
else:
if prompt_config.get("keyword", False):
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
@@ -256,28 +261,37 @@ def chat(dialog, messages, stream=True, **kwargs):
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
- kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
- dialog.similarity_threshold,
- dialog.vector_similarity_weight,
- doc_ids=attachments,
- top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
- rank_feature=label_question(" ".join(questions), kbs)
- )
- if prompt_config.get("use_kg"):
- ck = settings.kg_retrievaler.retrieval(" ".join(questions),
- tenant_ids,
- dialog.kb_ids,
- embd_mdl,
- LLMBundle(dialog.tenant_id, LLMType.CHAT))
- if ck["content_with_weight"]:
- kbinfos["chunks"].insert(0, ck)
+ knowledges = []
+ if prompt_config.get("reasoning", False):
+ for think in reasoning(kbinfos, " ".join(questions), chat_mdl, embd_mdl, tenant_ids, dialog.kb_ids, MAX_SEARCH_LIMIT=3):
+ if isinstance(think, str):
+ thought = think
+ knowledges = [t for t in think.split("\n") if t]
+ else:
+ yield think
+ else:
+ kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
+ dialog.similarity_threshold,
+ dialog.vector_similarity_weight,
+ doc_ids=attachments,
+ top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
+ rank_feature=label_question(" ".join(questions), kbs)
+ )
+ if prompt_config.get("use_kg"):
+ ck = settings.kg_retrievaler.retrieval(" ".join(questions),
+ tenant_ids,
+ dialog.kb_ids,
+ embd_mdl,
+ LLMBundle(dialog.tenant_id, LLMType.CHAT))
+ if ck["content_with_weight"]:
+ kbinfos["chunks"].insert(0, ck)
- retrieval_ts = timer()
+ knowledges = kb_prompt(kbinfos, max_tokens)
- knowledges = kb_prompt(kbinfos, max_tokens)
logging.debug(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
+ retrieval_ts = timer()
if not knowledges and prompt_config.get("empty_response"):
empty_res = prompt_config["empty_response"]
yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
@@ -302,9 +316,12 @@ def chat(dialog, messages, stream=True, **kwargs):
def decorate_answer(answer):
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts
- finish_chat_ts = timer()
-
refs = []
+        ans = answer.split("</think>")
+        think = ""
+        if len(ans) == 2:
+            think = ans[0] + "</think>"
+            answer = ans[1]
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
answer, idx = retriever.insert_citations(answer,
[ck["content_ltks"]
@@ -342,22 +359,24 @@ def chat(dialog, messages, stream=True, **kwargs):
generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
- return {"answer": answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt)}
+ return {"answer": think+answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
if stream:
last_ans = ""
answer = ""
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
+ if thought:
+                ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
continue
last_ans = answer
- yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+ yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
delta_ans = answer[len(last_ans):]
if delta_ans:
- yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
- yield decorate_answer(answer)
+ yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+ yield decorate_answer(thought+answer)
else:
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
user_content = msg[-1].get("content", "[content not available]")
@@ -798,3 +817,191 @@ Output:
except Exception as e:
logging.exception(f"JSON parsing error: {result} -> {e}")
raise e
+
+
+def reasoning(chunk_info: dict, question: str, chat_mdl: LLMBundle, embd_mdl: LLMBundle,
+ tenant_ids: list[str], kb_ids: list[str], MAX_SEARCH_LIMIT: int = 3,
+ top_n: int = 5, similarity_threshold: float = 0.4, vector_similarity_weight: float = 0.3):
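+    # Interleaves LLM reasoning with dataset retrieval: the model emits
+    # <|begin_search_query|>...<|end_search_query|> spans, each extracted query
+    # is run against the knowledge bases, and a summarized result is appended
+    # to the conversation until the model stops searching or the limit is hit.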
+ BEGIN_SEARCH_QUERY = "<|begin_search_query|>"
+ END_SEARCH_QUERY = "<|end_search_query|>"
+ BEGIN_SEARCH_RESULT = "<|begin_search_result|>"
+ END_SEARCH_RESULT = "<|end_search_result|>"
+
+ def rm_query_tags(line):
+ pattern = re.escape(BEGIN_SEARCH_QUERY) + r"(.*?)" + re.escape(END_SEARCH_QUERY)
+ return re.sub(pattern, "", line)
+
+ def rm_result_tags(line):
+ pattern = re.escape(BEGIN_SEARCH_RESULT) + r"(.*?)" + re.escape(END_SEARCH_RESULT)
+ return re.sub(pattern, "", line)
+
+ reason_prompt = (
+ "You are a reasoning assistant with the ability to perform dataset searches to help "
+ "you answer the user's question accurately. You have special tools:\n\n"
+ f"- To perform a search: write {BEGIN_SEARCH_QUERY} your query here {END_SEARCH_QUERY}.\n"
+ f"Then, the system will search and analyze relevant content, then provide you with helpful information in the format {BEGIN_SEARCH_RESULT} ...search results... {END_SEARCH_RESULT}.\n\n"
+ f"You can repeat the search process multiple times if necessary. The maximum number of search attempts is limited to {MAX_SEARCH_LIMIT}.\n\n"
+ "Once you have all the information you need, continue your reasoning.\n\n"
+ "-- Example --\n"
+ "Question: \"Find the minimum number of vertices in a Steiner tree that includes all specified vertices in a given tree.\"\n"
+ "Assistant thinking steps:\n"
+ "- I need to understand what a Steiner tree is and how to compute the minimum number of vertices required to include all specified vertices in a given tree.\n\n"
+ "Assistant:\n"
+ f"{BEGIN_SEARCH_QUERY}Minimum Steiner Tree problem in trees{END_SEARCH_QUERY}\n\n"
+ "(System returns processed information from relevant web pages)\n\n"
+ "Assistant continues reasoning with the new information...\n\n"
+ "**Remember**:\n"
+ f"- You have a dataset to search, so you just provide a proper search query.\n"
+ f"- Use {BEGIN_SEARCH_QUERY} to request a dataset search and end with {END_SEARCH_QUERY}.\n"
+ "- The language of query MUST be as the same as 'Question' or 'search result'.\n"
+ "- When done searching, continue your reasoning.\n\n"
+ 'Please answer the following question. You should think step by step to solve it.\n\n'
+ )
+
+ relevant_extraction_prompt = """**Task Instruction:**
+
+ You are tasked with reading and analyzing web pages based on the following inputs: **Previous Reasoning Steps**, **Current Search Query**, and **Searched Web Pages**. Your objective is to extract relevant and helpful information for **Current Search Query** from the **Searched Web Pages** and seamlessly integrate this information into the **Previous Reasoning Steps** to continue reasoning for the original question.
+
+ **Guidelines:**
+
+ 1. **Analyze the Searched Web Pages:**
+ - Carefully review the content of each searched web page.
+ - Identify factual information that is relevant to the **Current Search Query** and can aid in the reasoning process for the original question.
+
+ 2. **Extract Relevant Information:**
+ - Select the information from the Searched Web Pages that directly contributes to advancing the **Previous Reasoning Steps**.
+ - Ensure that the extracted information is accurate and relevant.
+
+ 3. **Output Format:**
+ - **If the web pages provide helpful information for current search query:** Present the information beginning with `**Final Information**` as shown below.
+    - The language of the extracted information **MUST BE** the same as that of the 'Search Query' or 'Web Pages'.
+ **Final Information**
+
+ [Helpful information]
+
+ - **If the web pages do not provide any helpful information for current search query:** Output the following text.
+
+ **Final Information**
+
+ No helpful information found.
+
+ **Inputs:**
+ - **Previous Reasoning Steps:**
+ {prev_reasoning}
+
+ - **Current Search Query:**
+ {search_query}
+
+ - **Searched Web Pages:**
+ {document}
+
+ """
+
+ executed_search_queries = []
+    msg_history = [{"role": "user", "content": f'Question:\n{question}\n\n'}]
+ all_reasoning_steps = []
+ think = ""
+ for ii in range(MAX_SEARCH_LIMIT + 1):
+ if ii == MAX_SEARCH_LIMIT - 1:
+ summary_think = f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. You are not allowed to search.\n{END_SEARCH_RESULT}\n"
+ yield {"answer": think + summary_think + "", "reference": {}, "audio_binary": None}
+ all_reasoning_steps.append(summary_think)
+ msg_hisotry.append({"role": "assistant", "content": summary_think})
+ break
+
+ query_think = ""
+        if msg_history[-1]["role"] != "user":
+            msg_history.append({"role": "user", "content": "Continue reasoning with the new information...\n"})
+        for ans in chat_mdl.chat_streamly(reason_prompt, msg_history, {"temperature": 0.7}):
+            ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
+ if not ans:
+ continue
+ query_think = ans
+ yield {"answer": think + rm_query_tags(query_think) + "", "reference": {}, "audio_binary": None}
+
+ think += rm_query_tags(query_think)
+ all_reasoning_steps.append(query_think)
+ msg_hisotry.append({"role": "assistant", "content": query_think})
+ search_query = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
+ if not search_query:
+ if ii > 0:
+ break
+ search_query = question
+ txt = f"\n{BEGIN_SEARCH_QUERY}{question}{END_SEARCH_QUERY}\n\n"
+ think += txt
+            msg_history[-1]["content"] += txt
+
+ logging.info(f"[THINK]Query: {ii}. {search_query}")
+ think += f"\n\n> {ii+1}. {search_query}\n\n"
+ yield {"answer": think + "", "reference": {}, "audio_binary": None}
+
+ summary_think = ""
+ # The search query has been searched in previous steps.
+ if search_query in executed_search_queries:
+ summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n"
+ yield {"answer": think + summary_think + "", "reference": {}, "audio_binary": None}
+ all_reasoning_steps.append(summary_think)
+ msg_hisotry.append({"role": "assistant", "content": summary_think})
+ think += summary_think
+ continue
+
+ truncated_prev_reasoning = ""
+ for i, step in enumerate(all_reasoning_steps):
+ truncated_prev_reasoning += f"Step {i + 1}: {step}\n\n"
+
+ prev_steps = truncated_prev_reasoning.split('\n\n')
+ if len(prev_steps) <= 5:
+ truncated_prev_reasoning = '\n\n'.join(prev_steps)
+ else:
+ truncated_prev_reasoning = ''
+ for i, step in enumerate(prev_steps):
+ if i == 0 or i >= len(prev_steps) - 4 or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step:
+ truncated_prev_reasoning += step + '\n\n'
+ else:
+ if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
+ truncated_prev_reasoning += '...\n\n'
+ truncated_prev_reasoning = truncated_prev_reasoning.strip('\n')
+
+ kbinfos = settings.retrievaler.retrieval(search_query, embd_mdl, tenant_ids, kb_ids, 1, top_n,
+ similarity_threshold,
+ vector_similarity_weight
+ )
+ # Merge chunk info for citations
+ if not chunk_info["chunks"]:
+ for k in chunk_info.keys():
+ chunk_info[k] = kbinfos[k]
+ else:
+ cids = [c["chunk_id"] for c in chunk_info["chunks"]]
+ for c in kbinfos["chunks"]:
+ if c["chunk_id"] in cids:
+ continue
+ chunk_info["chunks"].append(c)
+ dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
+ for d in kbinfos["doc_aggs"]:
+ if d["doc_id"] in dids:
+ continue
+ chunk_info["doc_aggs"].append(d)
+
+ think += "\n\n"
+ for ans in chat_mdl.chat_streamly(
+ relevant_extraction_prompt.format(
+ prev_reasoning=truncated_prev_reasoning,
+ search_query=search_query,
+ document="\n".join(kb_prompt(kbinfos, 512))
+ ),
+ [{"role": "user",
+ "content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
+ {"temperature": 0.7}):
+            ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
+ if not ans:
+ continue
+ summary_think = ans
+ yield {"answer": think + rm_result_tags(summary_think) + "", "reference": {}, "audio_binary": None}
+
+ all_reasoning_steps.append(summary_think)
+        msg_history.append(
+            {"role": "assistant", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"})
+ think += rm_result_tags(summary_think)
+ logging.info(f"[THINK]Summary: {ii}. {summary_think}")
+
+ yield think + ""
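
Note: reasoning() is a generator with a dual-yield contract, which chat() above
handles via isinstance(think, str): dict yields are streaming partial answers,
and a str yield is the accumulated "<think>...</think>" trace. A minimal
consumption sketch (consume_reasoning is a hypothetical helper, not part of this
patch):

    def consume_reasoning(gen):
        """Drain a reasoning-style generator, returning the final thought."""
        thought = ""
        for item in gen:
            if isinstance(item, str):
                thought = item  # accumulated "<think>...</think>" trace
            else:
                print(item["answer"])  # streamed partial answer with references
        return thought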
diff --git a/rag/nlp/__init__.py b/rag/nlp/__init__.py
index 5e288a9b1..beabe61eb 100644
--- a/rag/nlp/__init__.py
+++ b/rag/nlp/__init__.py
@@ -17,6 +17,7 @@
import logging
import random
from collections import Counter
+from typing import Optional
from rag.utils import num_tokens_from_string
from . import rag_tokenizer
@@ -601,3 +602,11 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
add_chunk(sec, image, '')
return cks, images
+
+
+def extract_between(text: str, start_tag: str, end_tag: str) -> Optional[str]:
+ pattern = re.escape(start_tag) + r"(.*?)" + re.escape(end_tag)
+ matches = re.findall(pattern, text, flags=re.DOTALL)
+ if matches:
+ return matches[-1].strip()
+ return None
\ No newline at end of file
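
Note: extract_between() returns the last match (matches[-1]), stripped, and
re.DOTALL lets a query span newlines, so repeated search tags resolve to the
most recent query. A quick illustration (literal tag strings only):

    from rag.nlp import extract_between

    text = ("<|begin_search_query|>first<|end_search_query|>\n"
            "<|begin_search_query|> second\nquery <|end_search_query|>")
    assert extract_between(text, "<|begin_search_query|>",
                           "<|end_search_query|>") == "second\nquery"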
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index 86b02d78d..86416cd49 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -15,7 +15,6 @@
#
import logging
import re
-import json
from dataclasses import dataclass
from rag.settings import TAG_FLD, PAGERANK_FLD
@@ -259,7 +258,7 @@ class Dealer:
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD]))
for i in search_res.ids:
nor, denor = 0, 0
- for t, sc in json.loads(search_res.field[i].get(TAG_FLD, "{}")).items():
+ for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items():
if t in query_rfea:
nor += query_rfea[t] * sc
denor += sc * sc
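
Note: the json.loads -> eval switch in search.py presumably tolerates TAG_FLD
values stored as Python dict reprs (single-quoted keys), which json.loads
rejects; eval on stored strings executes arbitrary code, though. A safer sketch
under the same assumption is ast.literal_eval, which parses literals without
executing anything:

    import ast

    tag_fld = "{'news': 3, 'sports': 1}"  # Python repr, not valid JSON
    # json.loads(tag_fld) raises json.JSONDecodeError on the single quotes
    tags = ast.literal_eval(tag_fld)  # parses literals only; no code execution
    assert tags == {"news": 3, "sports": 1}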