Fix: add fallback for bad citation output (#7014)

### What problem does this PR solve?

Add fallback for bad citation output. #6948

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Yongteng Lei 2025-04-15 09:33:53 +08:00 committed by GitHub
parent b1fa5a0754
commit 7a34159737
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 43 additions and 35 deletions

View File

@ -271,8 +271,10 @@ def chat(dialog, messages, stream=True, **kwargs):
if len(ans) == 2: if len(ans) == 2:
think = ans[0] + "</think>" think = ans[0] + "</think>"
answer = ans[1] answer = ans[1]
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL) answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
idx = set([])
if not re.search(r"##[0-9]+\$\$", answer): if not re.search(r"##[0-9]+\$\$", answer):
answer, idx = retriever.insert_citations( answer, idx = retriever.insert_citations(
answer, answer,
@ -283,12 +285,21 @@ def chat(dialog, messages, stream=True, **kwargs):
vtweight=dialog.vector_similarity_weight, vtweight=dialog.vector_similarity_weight,
) )
else: else:
idx = set([]) for match in re.finditer(r"##([0-9]+)\$\$", answer):
for r in re.finditer(r"##([0-9]+)\$\$", answer): i = int(match.group(1))
i = int(r.group(1))
if i < len(kbinfos["chunks"]): if i < len(kbinfos["chunks"]):
idx.add(i) idx.add(i)
# handle (ID: 1), ID: 2 etc.
for match in re.finditer(r"\(\s*ID:\s*(\d+)\s*\)|ID[: ]+\s*(\d+)", answer):
full_match = match.group(0)
id = match.group(1) or match.group(2)
if id:
i = int(id)
if i < len(kbinfos["chunks"]):
idx.add(i)
answer = answer.replace(full_match, f"##{i}$$")
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
if not recall_docs: if not recall_docs:

View File

@ -31,19 +31,22 @@ def chunks_format(reference):
def get_value(d, k1, k2): def get_value(d, k1, k2):
return d.get(k1, d.get(k2)) return d.get(k1, d.get(k2))
return [{ return [
"id": get_value(chunk, "chunk_id", "id"), {
"content": get_value(chunk, "content", "content_with_weight"), "id": get_value(chunk, "chunk_id", "id"),
"document_id": get_value(chunk, "doc_id", "document_id"), "content": get_value(chunk, "content", "content_with_weight"),
"document_name": get_value(chunk, "docnm_kwd", "document_name"), "document_id": get_value(chunk, "doc_id", "document_id"),
"dataset_id": get_value(chunk, "kb_id", "dataset_id"), "document_name": get_value(chunk, "docnm_kwd", "document_name"),
"image_id": get_value(chunk, "image_id", "img_id"), "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
"positions": get_value(chunk, "positions", "position_int"), "image_id": get_value(chunk, "image_id", "img_id"),
"url": chunk.get("url"), "positions": get_value(chunk, "positions", "position_int"),
"similarity": chunk.get("similarity"), "url": chunk.get("url"),
"vector_similarity": chunk.get("vector_similarity"), "similarity": chunk.get("similarity"),
"term_similarity": chunk.get("term_similarity"), "vector_similarity": chunk.get("vector_similarity"),
} for chunk in reference.get("chunks", [])] "term_similarity": chunk.get("term_similarity"),
}
for chunk in reference.get("chunks", [])
]
def llm_id2llm_type(llm_id): def llm_id2llm_type(llm_id):
@ -63,8 +66,7 @@ def message_fit_in(msg, max_length=4000):
nonlocal msg nonlocal msg
tks_cnts = [] tks_cnts = []
for m in msg: for m in msg:
tks_cnts.append( tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
{"role": m["role"], "count": num_tokens_from_string(m["content"])})
total = 0 total = 0
for m in tks_cnts: for m in tks_cnts:
total += m["count"] total += m["count"]
@ -86,12 +88,12 @@ def message_fit_in(msg, max_length=4000):
ll2 = num_tokens_from_string(msg_[-1]["content"]) ll2 = num_tokens_from_string(msg_[-1]["content"])
if ll / (ll + ll2) > 0.8: if ll / (ll + ll2) > 0.8:
m = msg_[0]["content"] m = msg_[0]["content"]
m = encoder.decode(encoder.encode(m)[:max_length - ll2]) m = encoder.decode(encoder.encode(m)[: max_length - ll2])
msg[0]["content"] = m msg[0]["content"] = m
return max_length, msg return max_length, msg
m = msg_[-1]["content"] m = msg_[-1]["content"]
m = encoder.decode(encoder.encode(m)[:max_length - ll2]) m = encoder.decode(encoder.encode(m)[: max_length - ll2])
msg[-1]["content"] = m msg[-1]["content"] = m
return max_length, msg return max_length, msg
@ -107,7 +109,7 @@ def kb_prompt(kbinfos, max_tokens):
chunks_num += 1 chunks_num += 1
if max_tokens * 0.97 < used_token_count: if max_tokens * 0.97 < used_token_count:
knowledges = knowledges[:i] knowledges = knowledges[:i]
logging.warning(f"Not all the retrieval into prompt: {i+1}/{len(knowledges)}") logging.warning(f"Not all the retrieval into prompt: {i + 1}/{len(knowledges)}")
break break
docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]]) docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
@ -137,6 +139,10 @@ def citation_prompt():
- Inserts CITATIONS in format '##i$$ ##j$$' where i,j are the ID of the content you are citing and encapsulated with '##' and '$$'. - Inserts CITATIONS in format '##i$$ ##j$$' where i,j are the ID of the content you are citing and encapsulated with '##' and '$$'.
- Inserts the CITATION symbols at the end of a sentence, AND NO MORE than 4 citations. - Inserts the CITATION symbols at the end of a sentence, AND NO MORE than 4 citations.
- DO NOT insert CITATION in the answer if the content is not from retrieved chunks. - DO NOT insert CITATION in the answer if the content is not from retrieved chunks.
- DO NOT use standalone Document IDs (e.g., '#ID#').
- Under NO circumstances any other citation styles or formats (e.g., '~~i==', '[i]', '(i)', etc.) be used.
- Citations ALWAYS the '##i$$' format.
- Any failure to adhere to the above rules, including but not limited to incorrect formatting, use of prohibited styles, or unsupported citations, will be considered a error, should skip adding Citation for this sentence.
--- Example START --- --- Example START ---
<SYSTEM>: Here is the knowledge base: <SYSTEM>: Here is the knowledge base:
@ -185,10 +191,7 @@ Requirements:
{content} {content}
""" """
msg = [ msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
{"role": "system", "content": prompt},
{"role": "user", "content": "Output: "}
]
_, msg = message_fit_in(msg, chat_mdl.max_length) _, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2}) kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple): if isinstance(kwd, tuple):
@ -215,10 +218,7 @@ Requirements:
{content} {content}
""" """
msg = [ msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
{"role": "system", "content": prompt},
{"role": "user", "content": "Output: "}
]
_, msg = message_fit_in(msg, chat_mdl.max_length) _, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2}) kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple): if isinstance(kwd, tuple):
@ -345,10 +345,7 @@ Output:
{content} {content}
""" """
msg = [ msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
{"role": "system", "content": prompt},
{"role": "user", "content": "Output: "}
]
_, msg = message_fit_in(msg, chat_mdl.max_length) _, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5}) kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
if isinstance(kwd, tuple): if isinstance(kwd, tuple):
@ -361,8 +358,8 @@ Output:
return json_repair.loads(kwd) return json_repair.loads(kwd)
except json_repair.JSONDecodeError: except json_repair.JSONDecodeError:
try: try:
result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip() result = kwd.replace(prompt[:-1], "").replace("user", "").replace("model", "").strip()
result = '{' + result.split('{')[1].split('}')[0] + '}' result = "{" + result.split("{")[1].split("}")[0] + "}"
return json_repair.loads(result) return json_repair.loads(result)
except Exception as e: except Exception as e:
logging.exception(f"JSON parsing error: {result} -> {e}") logging.exception(f"JSON parsing error: {result} -> {e}")