mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-19 04:30:01 +08:00
Fix: add fallback for bad citation output (#7014)
### What problem does this PR solve? Add fallback for bad citation output. #6948 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
parent
b1fa5a0754
commit
7a34159737
@ -271,8 +271,10 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if len(ans) == 2:
|
||||
think = ans[0] + "</think>"
|
||||
answer = ans[1]
|
||||
|
||||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||
answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
|
||||
idx = set([])
|
||||
if not re.search(r"##[0-9]+\$\$", answer):
|
||||
answer, idx = retriever.insert_citations(
|
||||
answer,
|
||||
@ -283,12 +285,21 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
vtweight=dialog.vector_similarity_weight,
|
||||
)
|
||||
else:
|
||||
idx = set([])
|
||||
for r in re.finditer(r"##([0-9]+)\$\$", answer):
|
||||
i = int(r.group(1))
|
||||
for match in re.finditer(r"##([0-9]+)\$\$", answer):
|
||||
i = int(match.group(1))
|
||||
if i < len(kbinfos["chunks"]):
|
||||
idx.add(i)
|
||||
|
||||
# handle (ID: 1), ID: 2 etc.
|
||||
for match in re.finditer(r"\(\s*ID:\s*(\d+)\s*\)|ID[: ]+\s*(\d+)", answer):
|
||||
full_match = match.group(0)
|
||||
id = match.group(1) or match.group(2)
|
||||
if id:
|
||||
i = int(id)
|
||||
if i < len(kbinfos["chunks"]):
|
||||
idx.add(i)
|
||||
answer = answer.replace(full_match, f"##{i}$$")
|
||||
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
if not recall_docs:
|
||||
|
@ -31,7 +31,8 @@ def chunks_format(reference):
|
||||
def get_value(d, k1, k2):
|
||||
return d.get(k1, d.get(k2))
|
||||
|
||||
return [{
|
||||
return [
|
||||
{
|
||||
"id": get_value(chunk, "chunk_id", "id"),
|
||||
"content": get_value(chunk, "content", "content_with_weight"),
|
||||
"document_id": get_value(chunk, "doc_id", "document_id"),
|
||||
@ -43,7 +44,9 @@ def chunks_format(reference):
|
||||
"similarity": chunk.get("similarity"),
|
||||
"vector_similarity": chunk.get("vector_similarity"),
|
||||
"term_similarity": chunk.get("term_similarity"),
|
||||
} for chunk in reference.get("chunks", [])]
|
||||
}
|
||||
for chunk in reference.get("chunks", [])
|
||||
]
|
||||
|
||||
|
||||
def llm_id2llm_type(llm_id):
|
||||
@ -63,8 +66,7 @@ def message_fit_in(msg, max_length=4000):
|
||||
nonlocal msg
|
||||
tks_cnts = []
|
||||
for m in msg:
|
||||
tks_cnts.append(
|
||||
{"role": m["role"], "count": num_tokens_from_string(m["content"])})
|
||||
tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
|
||||
total = 0
|
||||
for m in tks_cnts:
|
||||
total += m["count"]
|
||||
@ -137,6 +139,10 @@ def citation_prompt():
|
||||
- Inserts CITATIONS in format '##i$$ ##j$$' where i,j are the ID of the content you are citing and encapsulated with '##' and '$$'.
|
||||
- Inserts the CITATION symbols at the end of a sentence, AND NO MORE than 4 citations.
|
||||
- DO NOT insert CITATION in the answer if the content is not from retrieved chunks.
|
||||
- DO NOT use standalone Document IDs (e.g., '#ID#').
|
||||
- Under NO circumstances any other citation styles or formats (e.g., '~~i==', '[i]', '(i)', etc.) be used.
|
||||
- Citations ALWAYS the '##i$$' format.
|
||||
- Any failure to adhere to the above rules, including but not limited to incorrect formatting, use of prohibited styles, or unsupported citations, will be considered a error, should skip adding Citation for this sentence.
|
||||
|
||||
--- Example START ---
|
||||
<SYSTEM>: Here is the knowledge base:
|
||||
@ -185,10 +191,7 @@ Requirements:
|
||||
{content}
|
||||
|
||||
"""
|
||||
msg = [
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": "Output: "}
|
||||
]
|
||||
msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(kwd, tuple):
|
||||
@ -215,10 +218,7 @@ Requirements:
|
||||
{content}
|
||||
|
||||
"""
|
||||
msg = [
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": "Output: "}
|
||||
]
|
||||
msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(kwd, tuple):
|
||||
@ -345,10 +345,7 @@ Output:
|
||||
{content}
|
||||
|
||||
"""
|
||||
msg = [
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": "Output: "}
|
||||
]
|
||||
msg = [{"role": "system", "content": prompt}, {"role": "user", "content": "Output: "}]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
|
||||
if isinstance(kwd, tuple):
|
||||
@ -361,8 +358,8 @@ Output:
|
||||
return json_repair.loads(kwd)
|
||||
except json_repair.JSONDecodeError:
|
||||
try:
|
||||
result = kwd.replace(prompt[:-1], '').replace('user', '').replace('model', '').strip()
|
||||
result = '{' + result.split('{')[1].split('}')[0] + '}'
|
||||
result = kwd.replace(prompt[:-1], "").replace("user", "").replace("model", "").strip()
|
||||
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
||||
return json_repair.loads(result)
|
||||
except Exception as e:
|
||||
logging.exception(f"JSON parsing error: {result} -> {e}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user