Refa: more fallbacks for bad citation format (#7710)

### What problem does this PR solve?

More fallbacks for bad citation format

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
This commit is contained in:
Yongteng Lei 2025-05-19 19:34:05 +08:00 committed by GitHub
parent b908c33464
commit e8e2a95165
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -14,11 +14,11 @@
# limitations under the License. # limitations under the License.
# #
import binascii import binascii
from datetime import datetime
import logging import logging
import re import re
import time import time
from copy import deepcopy from copy import deepcopy
from datetime import datetime
from functools import partial from functools import partial
from timeit import default_timer as timer from timeit import default_timer as timer
@ -36,8 +36,7 @@ from api.utils import current_timestamp, datetime_format
from rag.app.resume import forbidden_select_fields4resume from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp.search import index_name from rag.nlp.search import index_name
from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in, \ from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
cross_languages
from rag.utils import num_tokens_from_string, rmSpace from rag.utils import num_tokens_from_string, rmSpace
from rag.utils.tavily_conn import Tavily from rag.utils.tavily_conn import Tavily
@ -303,6 +302,39 @@ def chat(dialog, messages, stream=True, **kwargs):
if "max_tokens" in gen_conf: if "max_tokens" in gen_conf:
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count) gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)
def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
    """Rewrite malformed citation markers in ``answer`` to the ``##<i>$$`` form.

    The answer is scanned once, left to right, for citation-like markers such
    as ``(ID: 12)``, ``ID 12``, ``$$12$$``, ``$[12]$``, ``$12$``, ``#12$$``,
    ``##12$``, ``##12###``, ``【12】`` and ``ref 12``. Every marker whose number
    is a valid position in ``kbinfos["chunks"]`` is replaced by ``##<i>$$`` and
    its number is added to ``idx``. Markers pointing outside the chunk list are
    left exactly as written.

    Args:
        answer: Text that may contain citation markers.
        kbinfos: Mapping with a ``"chunks"`` sequence; only its length is used,
            as the exclusive upper bound for valid citation numbers.
        idx: Set of cited chunk positions; updated in place.

    Returns:
        A ``(answer, idx)`` tuple: the repaired text and the same ``idx`` object.
    """
    max_index = len(kbinfos["chunks"])

    # A single alternation gives one pass over the text, so every character
    # belongs to at most one marker. Applying the patterns one after another
    # (and substituting with str.replace) let later patterns re-match text that
    # was already correct (``##12$$`` -> ``###12$$$``) and let a short match
    # clobber the prefix of a longer one (``ID: 1`` inside ``ID: 12``).
    # re.VERBOSE is deliberately not used: ``#`` is a literal here.
    citation_like = re.compile(
        r"\(\s*ID:\s*(\d+)\s*\)"  # (ID: 12)
        r"|(?<![A-Za-z])ID[: ]+(\d+)"  # ID: 12, ID 12 -- but not "VALID 12"
        r"|\$\[(\d+)\]\$"  # $[12]$
        r"|【(\d+)】"  # 【12】
        r"|(?<![A-Za-z])(?i:ref)\s*(\d+)"  # ref12, REF 12 -- but not "href 12"
        r"|##+(\d+)(?:\$+|#{2,})"  # ##12$$ (already valid), ##12$, ##12###
        r"|#(\d+)\${2,}"  # #12$$
        r"|\$+(\d+)\$+"  # $12$, $$12$$, $$12$$$$
    )

    def normalize(match):
        # Each alternative has exactly one capturing group, so ``lastindex``
        # is the group that holds the digits of whichever alternative matched.
        try:
            i = int(match.group(match.lastindex))
        except ValueError:
            # int() can reject absurdly long digit runs; keep the text as is.
            return match.group(0)
        if not 0 <= i < max_index:
            return match.group(0)
        idx.add(i)
        return f"##{i}$$"

    return citation_like.sub(normalize, answer), idx
def decorate_answer(answer): def decorate_answer(answer):
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer
@ -331,15 +363,7 @@ def chat(dialog, messages, stream=True, **kwargs):
if i < len(kbinfos["chunks"]): if i < len(kbinfos["chunks"]):
idx.add(i) idx.add(i)
# handle (ID: 1), ID: 2 etc. answer, idx = repair_bad_citation_formats(answer, kbinfos, idx)
for match in re.finditer(r"\(\s*ID:\s*(\d+)\s*\)|ID[: ]+\s*(\d+)", answer):
full_match = match.group(0)
id = match.group(1) or match.group(2)
if id:
i = int(id)
if i < len(kbinfos["chunks"]):
idx.add(i)
answer = answer.replace(full_match, f"##{i}$$")
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
@ -502,7 +526,7 @@ Please write the SQL, only SQL, without any other explanations or text.
# compose Markdown table # compose Markdown table
columns = ( columns = (
"|" + "|".join([re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|") "|" + "|".join([re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
) )
line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "") line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "")
@ -598,4 +622,5 @@ def ask(question, kb_ids, tenant_id):
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}): for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
answer = ans answer = ans
yield {"answer": answer, "reference": {}} yield {"answer": answer, "reference": {}}
yield decorate_answer(answer) yield decorate_answer(answer)