add prompt to message (#2099)

### What problem does this PR solve?

#2098

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Parent: 6b7c028578
Commit: 6d3e3e4e3c
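Taken together, the hunks below thread the rendered system prompt through the answer pipeline: `chat()` and `use_sql()` attach a `"prompt"` key to every dict they yield or return, and `completion()` persists that value on the assistant message. As a rough sketch of the resulting record shape (all field values are invented placeholders, not taken from the repository):

```python
# Sketch only: the shape of the stored assistant message after this change.
# Every value below is a placeholder.
assistant_message = {
    "role": "assistant",
    "content": "The answer text produced by the model ...",      # ans["answer"]
    "id": "b2d1c0ae-placeholder",                                 # message_id
    "prompt": "You are an intelligent assistant. Summarize ...",  # ans.get("prompt", "")
}
```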
@@ -140,7 +140,8 @@ def completion():
         if not conv.reference:
             conv.reference.append(ans["reference"])
         else: conv.reference[-1] = ans["reference"]
-        conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
+        conv.message[-1] = {"role": "assistant", "content": ans["answer"],
+                            "id": message_id, "prompt": ans.get("prompt", "")}
 
     def stream():
         nonlocal dia, msg, req, conv
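The `ans.get("prompt", "")` lookup means answer dicts produced without the new key (for example, by older code paths) still store cleanly, just with an empty prompt. A minimal, self-contained illustration of that behaviour:

```python
# Illustration only: answer dicts with and without the new "prompt" key.
ans_new = {"answer": "Hi there.", "reference": {}, "prompt": "You are an assistant."}
ans_old = {"answer": "Hi there.", "reference": {}}

for ans in (ans_new, ans_old):
    message = {"role": "assistant", "content": ans["answer"],
               "id": "msg-placeholder", "prompt": ans.get("prompt", "")}
    print(repr(message["prompt"]))  # 'You are an assistant.' then ''
```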
@@ -179,6 +179,7 @@ def chat(dialog, messages, stream=True, **kwargs):
                     for m in messages if m["role"] != "system"])
     used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
     assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
+    prompt = msg[0]["content"]
 
     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(
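The new `prompt` variable is simply the content of the first message after `message_fit_in` has trimmed the history to the token budget; the `assert` above it guarantees `msg[0]` is still present. A sketch with a hypothetical trimmed message list (not the function's actual output):

```python
# Hypothetical msg list after message_fit_in(); contents are invented.
msg = [
    {"role": "system", "content": "You are an intelligent assistant. Knowledge base:\n..."},
    {"role": "user", "content": "What does the warranty cover?"},
]
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
prompt = msg[0]["content"]  # the rendered system prompt that now travels with each answer
print(prompt.splitlines()[0])
```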
@@ -186,7 +187,7 @@ def chat(dialog, messages, stream=True, **kwargs):
             max_tokens - used_token_count)
 
     def decorate_answer(answer):
-        nonlocal prompt_config, knowledges, kwargs, kbinfos
+        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt
         refs = []
         if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
             answer, idx = retr.insert_citations(answer,
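Adding `prompt` to the `nonlocal` list lets `decorate_answer` name the variable bound in the enclosing `chat()` call. Strictly, reading a free variable would work without the declaration, but it keeps `prompt` explicit alongside the other captured names. A minimal standalone illustration of the closure pattern:

```python
# Standalone illustration of the nonlocal capture used by decorate_answer().
def chat_like():
    prompt = "system prompt text"  # bound in the enclosing function, like chat()

    def decorate_answer(answer):
        nonlocal prompt            # explicit capture, mirroring the patch
        return {"answer": answer, "reference": [], "prompt": prompt}

    return decorate_answer("hello")

print(chat_like())  # {'answer': 'hello', 'reference': [], 'prompt': 'system prompt text'}
```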
@@ -210,17 +211,16 @@ def chat(dialog, messages, stream=True, **kwargs):
 
         if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
             answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
-        return {"answer": answer, "reference": refs}
+        return {"answer": answer, "reference": refs, "prompt": prompt}
 
     if stream:
         answer = ""
-        for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], gen_conf):
+        for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
             answer = ans
-            yield {"answer": answer, "reference": {}}
+            yield {"answer": answer, "reference": {}, "prompt": prompt}
         yield decorate_answer(answer)
     else:
-        answer = chat_mdl.chat(
-            msg[0]["content"], msg[1:], gen_conf)
+        answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
         chat_logger.info("User: {}|Assistant: {}".format(
             msg[-1]["content"], answer))
         yield decorate_answer(answer)
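With both branches passing the same `prompt` to the model and including it in every yielded dict, a consumer of the `chat()` generator can count on the key being present whether or not streaming is enabled. A hypothetical consumer (the helper name and the stand-in generator are illustrative, not the project's API):

```python
# Hypothetical consumer of a chat()-style generator after this change.
def collect_last(chat_gen):
    last = None
    for chunk in chat_gen:
        # every chunk now looks like {"answer": ..., "reference": ..., "prompt": ...}
        last = chunk
    return last["answer"], last["prompt"]

# Tiny stand-in generator so the sketch runs on its own.
def fake_chat():
    yield {"answer": "partial", "reference": {}, "prompt": "sys"}
    yield {"answer": "full answer", "reference": {"chunks": []}, "prompt": "sys"}

print(collect_last(fake_chat()))  # ('full answer', 'sys')
```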
@@ -334,7 +334,8 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
         chat_logger.warning("SQL missing field: " + sql)
         return {
             "answer": "\n".join([clmns, line, rows]),
-            "reference": {"chunks": [], "doc_aggs": []}
+            "reference": {"chunks": [], "doc_aggs": []},
+            "prompt": sys_prompt
         }
 
     docid_idx = list(docid_idx)[0]
@@ -348,7 +349,8 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
         "answer": "\n".join([clmns, line, rows]),
         "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
                       "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
-                                   doc_aggs.items()]}
+                                   doc_aggs.items()]},
+        "prompt": sys_prompt
     }
 
 
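Both `use_sql()` return paths now carry `"prompt": sys_prompt`, so SQL-backed answers expose their generating prompt the same way retrieval-backed ones do, and the `ans.get("prompt", "")` used in `completion()` picks it up unchanged. A stubbed result for illustration (values invented):

```python
# Illustrative stub of a use_sql()-style result after the patch; values are invented.
sql_result = {
    "answer": "doc_name\t amount\n---\ncontract_a.pdf\t 42",
    "reference": {"chunks": [], "doc_aggs": []},
    "prompt": "You are a database assistant ...",  # sys_prompt
}
print(sql_result.get("prompt", ""))  # same lookup completion() applies to every answer
```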