Refa: more fallbacks for bad citation format (#7710)

### What problem does this PR solve?

More fallbacks for bad citation format

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
This commit is contained in:
Yongteng Lei 2025-05-19 19:34:05 +08:00 committed by GitHub
parent b908c33464
commit e8e2a95165
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -14,11 +14,11 @@
# limitations under the License.
#
import binascii
from datetime import datetime
import logging
import re
import time
from copy import deepcopy
from datetime import datetime
from functools import partial
from timeit import default_timer as timer
@ -36,8 +36,7 @@ from api.utils import current_timestamp, datetime_format
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in, \
cross_languages
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
from rag.utils import num_tokens_from_string, rmSpace
from rag.utils.tavily_conn import Tavily
@ -303,6 +302,39 @@ def chat(dialog, messages, stream=True, **kwargs):
if "max_tokens" in gen_conf:
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)
def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: dict):
max_index = len(kbinfos["chunks"])
def safe_add(i):
if 0 <= i < max_index:
idx.add(i)
return True
return False
def find_and_replace(pattern, group_index=1, repl=lambda i: f"##{i}$$", flags=0):
nonlocal answer
for match in re.finditer(pattern, answer, flags=flags):
try:
i = int(match.group(group_index))
if safe_add(i):
answer = answer.replace(match.group(0), repl(i))
except Exception:
continue
find_and_replace(r"\(\s*ID:\s*(\d+)\s*\)") # (ID: 12)
find_and_replace(r"ID[: ]+(\d+)") # ID: 12, ID 12
find_and_replace(r"\$\$(\d+)\$\$") # $$12$$
find_and_replace(r"\$\[(\d+)\]\$") # $[12]$
find_and_replace(r"\$\$(\d+)\${2,}") # $$12$$$$
find_and_replace(r"\$(\d+)\$") # $12$
find_and_replace(r"#(\d+)\$\$") # #12$$
find_and_replace(r"##(\d+)\$") # ##12$
find_and_replace(r"##(\d+)#{2,}") # ##12###
find_and_replace(r"【(\d+)】") # 【12】
find_and_replace(r"ref\s*(\d+)", flags=re.IGNORECASE) # ref12, ref 12, REF 12
return answer, idx
def decorate_answer(answer):
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer
@ -331,15 +363,7 @@ def chat(dialog, messages, stream=True, **kwargs):
if i < len(kbinfos["chunks"]):
idx.add(i)
# handle (ID: 1), ID: 2 etc.
for match in re.finditer(r"\(\s*ID:\s*(\d+)\s*\)|ID[: ]+\s*(\d+)", answer):
full_match = match.group(0)
id = match.group(1) or match.group(2)
if id:
i = int(id)
if i < len(kbinfos["chunks"]):
idx.add(i)
answer = answer.replace(full_match, f"##{i}$$")
answer, idx = repair_bad_citation_formats(answer, kbinfos, idx)
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
@ -599,3 +623,4 @@ def ask(question, kb_ids, tenant_id):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)