mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-10 23:49:02 +08:00
Refa: more fallbacks for bad citation format (#7710)
### What problem does this PR solve? More fallbacks for bad citation format ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Refactoring
This commit is contained in:
parent
b908c33464
commit
e8e2a95165
@ -14,11 +14,11 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import binascii
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from timeit import default_timer as timer
|
||||
|
||||
@ -36,8 +36,7 @@ from api.utils import current_timestamp, datetime_format
|
||||
from rag.app.resume import forbidden_select_fields4resume
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp.search import index_name
|
||||
from rag.prompts import chunks_format, citation_prompt, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in, \
|
||||
cross_languages
|
||||
from rag.prompts import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, llm_id2llm_type, message_fit_in
|
||||
from rag.utils import num_tokens_from_string, rmSpace
|
||||
from rag.utils.tavily_conn import Tavily
|
||||
|
||||
@ -303,6 +302,39 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)
|
||||
|
||||
def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: dict):
|
||||
max_index = len(kbinfos["chunks"])
|
||||
|
||||
def safe_add(i):
|
||||
if 0 <= i < max_index:
|
||||
idx.add(i)
|
||||
return True
|
||||
return False
|
||||
|
||||
def find_and_replace(pattern, group_index=1, repl=lambda i: f"##{i}$$", flags=0):
|
||||
nonlocal answer
|
||||
for match in re.finditer(pattern, answer, flags=flags):
|
||||
try:
|
||||
i = int(match.group(group_index))
|
||||
if safe_add(i):
|
||||
answer = answer.replace(match.group(0), repl(i))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
find_and_replace(r"\(\s*ID:\s*(\d+)\s*\)") # (ID: 12)
|
||||
find_and_replace(r"ID[: ]+(\d+)") # ID: 12, ID 12
|
||||
find_and_replace(r"\$\$(\d+)\$\$") # $$12$$
|
||||
find_and_replace(r"\$\[(\d+)\]\$") # $[12]$
|
||||
find_and_replace(r"\$\$(\d+)\${2,}") # $$12$$$$
|
||||
find_and_replace(r"\$(\d+)\$") # $12$
|
||||
find_and_replace(r"#(\d+)\$\$") # #12$$
|
||||
find_and_replace(r"##(\d+)\$") # ##12$
|
||||
find_and_replace(r"##(\d+)#{2,}") # ##12###
|
||||
find_and_replace(r"【(\d+)】") # 【12】
|
||||
find_and_replace(r"ref\s*(\d+)", flags=re.IGNORECASE) # ref12, ref 12, REF 12
|
||||
|
||||
return answer, idx
|
||||
|
||||
def decorate_answer(answer):
|
||||
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer
|
||||
|
||||
@ -331,15 +363,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if i < len(kbinfos["chunks"]):
|
||||
idx.add(i)
|
||||
|
||||
# handle (ID: 1), ID: 2 etc.
|
||||
for match in re.finditer(r"\(\s*ID:\s*(\d+)\s*\)|ID[: ]+\s*(\d+)", answer):
|
||||
full_match = match.group(0)
|
||||
id = match.group(1) or match.group(2)
|
||||
if id:
|
||||
i = int(id)
|
||||
if i < len(kbinfos["chunks"]):
|
||||
idx.add(i)
|
||||
answer = answer.replace(full_match, f"##{i}$$")
|
||||
answer, idx = repair_bad_citation_formats(answer, kbinfos, idx)
|
||||
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
@ -599,3 +623,4 @@ def ask(question, kb_ids, tenant_id):
|
||||
answer = ans
|
||||
yield {"answer": answer, "reference": {}}
|
||||
yield decorate_answer(answer)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user