def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set) -> tuple:
    """Normalise malformed citation markers in ``answer`` to ``##<n>$$``.

    Each fallback pattern below is applied in turn.  A match is rewritten
    only when its number is a valid chunk index, i.e.
    ``0 <= n < len(kbinfos["chunks"])``; otherwise the text is left as it
    was.  Every accepted index is added to ``idx``.

    Compared with a ``finditer`` + ``str.replace`` loop this implementation:

    * substitutes exactly the matched span (``re.sub`` with a callback), so
      ``"ID: 1"`` can no longer clobber the prefix of ``"ID: 12"`` and a
      repeated marker is not rewritten more than once;
    * never re-scans text that is already a well-formed ``##<n>$$`` marker,
      so looser patterns such as ``#(\\d+)\\$\\$`` cannot turn ``##12$$`` into
      ``###12$$$``;
    * handles ``$$12$$`` and ``$$12$$$$`` with a single pattern so that the
      surplus ``$`` characters are consumed instead of being left behind.

    Args:
        answer: Generated text that may contain badly formatted citations.
        kbinfos: Only ``len(kbinfos["chunks"])`` is read; it is the exclusive
            upper bound for a valid citation index.
        idx: Set of cited chunk indices; updated in place.

    Returns:
        ``(answer, idx)`` - the repaired text and the same ``idx`` object.
    """
    max_index = len(kbinfos["chunks"])
    # The capturing group makes ``split`` keep the markers, at odd positions.
    well_formed = re.compile(r"(##\d+\$\$)")

    def register(digits):
        """Return the chunk index for ``digits`` and record it, else None."""
        try:
            i = int(digits)
        except ValueError:
            # int() may refuse pathologically long digit runs.
            return None
        if 0 <= i < max_index:
            idx.add(i)
            return i
        return None

    def normalise(match):
        """``re.sub`` callback: canonical marker, or the untouched match."""
        i = register(match.group(1))
        return match.group(0) if i is None else f"##{i}$$"

    def find_and_replace(pattern, flags=0):
        """Apply one fallback pattern to the text between valid markers."""
        nonlocal answer
        pieces = well_formed.split(answer)
        # Even positions are free text; odd positions are ``##<n>$$`` markers
        # that must be shielded from the looser patterns.
        pieces[::2] = [re.sub(pattern, normalise, piece, flags=flags) for piece in pieces[::2]]
        answer = "".join(pieces)

    # Markers that are already well formed still count as citations.
    for marker in well_formed.findall(answer):
        register(marker[2:-2])

    find_and_replace(r"\(\s*ID:\s*(\d+)\s*\)")  # (ID: 12)
    find_and_replace(r"ID[: ]+(\d+)")  # ID: 12, ID 12
    find_and_replace(r"\$\[(\d+)\]\$")  # $[12]$
    find_and_replace(r"\$\$(\d+)\${2,}")  # $$12$$, $$12$$$$
    find_and_replace(r"\$(\d+)\$")  # $12$
    find_and_replace(r"#(\d+)\$\$")  # #12$$
    find_and_replace(r"##(\d+)\$")  # ##12$
    find_and_replace(r"##(\d+)#{2,}")  # ##12###
    find_and_replace(r"【(\d+)】")  # 【12】
    find_and_replace(r"ref\s*(\d+)", flags=re.IGNORECASE)  # ref12, ref 12, REF 12

    return answer, idx
# compose Markdown table columns = ( - "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|") + "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ("|Source|" if docid_idx and docid_idx else "|") ) line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "") @@ -598,4 +622,5 @@ def ask(question, kb_ids, tenant_id): for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}): answer = ans yield {"answer": answer, "reference": {}} - yield decorate_answer(answer) \ No newline at end of file + yield decorate_answer(answer) +