add self-rag (#1070)

### What problem does this PR solve?

#1069 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
KevinHuSh 2024-06-06 11:13:39 +08:00 committed by GitHub
parent 72c6784ff8
commit 4454ba7a1e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 234 additions and 38 deletions

View File

@ -198,15 +198,18 @@ def completion():
else: conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
def rename_field(ans):
for chunk_i in ans['reference'].get('chunks', []):
chunk_i['doc_name'] = chunk_i['docnm_kwd']
chunk_i.pop('docnm_kwd')
def stream():
nonlocal dia, msg, req, conv
try:
for ans in chat(dia, msg, True, **req):
fillin_conv(ans)
for chunk_i in ans['reference'].get('chunks', []):
chunk_i['doc_name'] = chunk_i['docnm_kwd']
chunk_i.pop('docnm_kwd')
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
rename_field(rename_field)
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
@ -554,23 +557,24 @@ def completion_faq():
"content": ""
}
]
for ans in chat(dia, msg, stream=False, **req):
# answer = ans
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
fillin_conv(ans)
API4ConversationService.append_message(conv.id, conv.to_dict())
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
for chunk_idx in chunk_idxs[:1]:
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
try:
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
response = MINIO.get(bkt, nm)
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
data.append(data_type_picture)
except Exception as e:
return server_error_response(e)
ans = ""
for a in chat(dia, msg, stream=False, **req):
ans = a
break
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
fillin_conv(ans)
API4ConversationService.append_message(conv.id, conv.to_dict())
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
for chunk_idx in chunk_idxs[:1]:
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
try:
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
response = MINIO.get(bkt, nm)
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
data.append(data_type_picture)
except Exception as e:
return server_error_response(e)
response = {"code": 200, "msg": "success", "data": data}
return response

112
api/apps/canvas_app.py Normal file
View File

@ -0,0 +1,112 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from flask import request
from flask_login import login_required, current_user
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request
from graph.canvas import Canvas
@manager.route('/templates', methods=['GET'])
@login_required
def templates():
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()])
@manager.route('/list', methods=['GET'])
@login_required
def canvas_list():
return get_json_result(data=[c.to_dict() for c in UserCanvasService.query(user_id=current_user.id)])
@manager.route('/rm', methods=['POST'])
@validate_request("canvas_ids")
@login_required
def rm():
for i in request.json["canvas_ids"]:
UserCanvasService.delete_by_id(i)
return get_json_result(data=True)
@manager.route('/set', methods=['POST'])
@validate_request("dsl", "title")
@login_required
def save():
req = request.json
req["user_id"] = current_user.id
if not isinstance(req["dsl"], str):req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
try:
Canvas(req["dsl"])
except Exception as e:
return server_error_response(e)
req["dsl"] = json.loads(req["dsl"])
if "id" not in req:
req["id"] = get_uuid()
if not UserCanvasService.save(**req):
return server_error_response("Fail to save canvas.")
else:
UserCanvasService.update_by_id(req["id"], req)
return get_json_result(data=req)
@manager.route('/get/<canvas_id>', methods=['GET'])
@login_required
def get(canvas_id):
e, c = UserCanvasService.get_by_id(canvas_id)
if not e:
return server_error_response("canvas not found.")
return get_json_result(data=c.to_dict())
@manager.route('/run', methods=['POST'])
@validate_request("id", "dsl")
@login_required
def run():
req = request.json
if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False)
try:
canvas = Canvas(req["dsl"], current_user.id)
ans = canvas.run()
req["dsl"] = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], dsl=req["dsl"])
return get_json_result(data=req["dsl"])
except Exception as e:
return server_error_response(e)
@manager.route('/reset', methods=['POST'])
@validate_request("canvas_id")
@login_required
def reset():
req = request.json
try:
user_canvas = UserCanvasService.get_by_id(req["canvas_id"])
canvas = Canvas(req["dsl"], current_user.id)
canvas.reset()
req["dsl"] = json.loads(str(canvas))
UserCanvasService.update_by_id(req["canvas_id"], dsl=req["dsl"])
return get_json_result(data=req["dsl"])
except Exception as e:
return server_error_response(e)

View File

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request, Response, jsonify
from copy import deepcopy
from flask import request, Response
from flask_login import login_required
from api.db.services.dialog_service import DialogService, ConversationService, chat
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
@ -121,7 +122,7 @@ def completion():
e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(retmsg="Conversation not found!")
conv.message.append(msg[-1])
conv.message.append(deepcopy(msg[-1]))
e, dia = DialogService.get_by_id(conv.dialog_id)
if not e:
return get_data_error_result(retmsg="Dialog not found!")

View File

@ -31,8 +31,8 @@ def set_dialog():
req = request.json
dialog_id = req.get("dialog_id")
name = req.get("name", "New Dialog")
icon = req.get("icon", "")
description = req.get("description", "A helpful Dialog")
icon = req.get("icon", "")
top_n = req.get("top_n", 6)
top_k = req.get("top_k", 1024)
rerank_id = req.get("rerank_id", "")
@ -92,7 +92,7 @@ def set_dialog():
"rerank_id": rerank_id,
"similarity_threshold": similarity_threshold,
"vector_similarity_weight": vector_similarity_weight,
"icon": icon,
"icon": icon
}
if not DialogService.save(**dia):
return get_data_error_result(retmsg="Fail to new a dialog!")

View File

@ -0,0 +1,26 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
import peewee
from api.db.db_models import DB, API4Conversation, APIToken, Dialog, CanvasTemplate, UserCanvas
from api.db.services.common_service import CommonService
class CanvasTemplateService(CommonService):
model = CanvasTemplate
class UserCanvasService(CommonService):
model = UserCanvas

View File

@ -23,6 +23,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.settings import chat_logger, retrievaler
from rag.app.resume import forbidden_select_fields4resume
from rag.nlp.rag_tokenizer import is_chinese
from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder
@ -80,7 +81,8 @@ def chat(dialog, messages, stream=True, **kwargs):
if not llm:
raise LookupError("LLM(%s) not found" % dialog.llm_id)
max_tokens = 1024
else: max_tokens = llm[0].max_tokens
else:
max_tokens = llm[0].max_tokens
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
embd_nms = list(set([kb.embd_id for kb in kbs]))
if len(embd_nms) != 1:
@ -124,6 +126,16 @@ def chat(dialog, messages, stream=True, **kwargs):
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
top=1024, aggs=False, rerank_mdl=rerank_mdl)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
#self-rag
if dialog.prompt_config.get("self_rag") and not relevant(dialog.tenant_id, dialog.llm_id, questions[-1], knowledges):
questions[-1] = rewrite(dialog.tenant_id, dialog.llm_id, questions[-1])
kbinfos = retrievaler.retrieval(" ".join(questions), embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n,
dialog.similarity_threshold,
dialog.vector_similarity_weight,
doc_ids=kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None,
top=1024, aggs=False, rerank_mdl=rerank_mdl)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
chat_logger.info(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
@ -136,7 +148,7 @@ def chat(dialog, messages, stream=True, **kwargs):
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
msg.extend([{"role": m["role"], "content": m["content"]}
for m in messages if m["role"] != "system"])
for m in messages if m["role"] != "system"])
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
@ -150,9 +162,9 @@ def chat(dialog, messages, stream=True, **kwargs):
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
answer, idx = retrievaler.insert_citations(answer,
[ck["content_ltks"]
for ck in kbinfos["chunks"]],
for ck in kbinfos["chunks"]],
[ck["vector"]
for ck in kbinfos["chunks"]],
for ck in kbinfos["chunks"]],
embd_mdl,
tkweight=1 - dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight)
@ -166,7 +178,7 @@ def chat(dialog, messages, stream=True, **kwargs):
for c in refs["chunks"]:
if c.get("vector"):
del c["vector"]
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api")>=0:
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
return {"answer": answer, "reference": refs}
@ -204,7 +216,7 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
def get_table():
nonlocal sys_prompt, user_promt, question, tried_times
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {
"temperature": 0.06})
"temperature": 0.06})
print(user_promt, sql)
chat_logger.info(f"{question}”==>{user_promt} get SQL: {sql}")
sql = re.sub(r"[\r\n]+", " ", sql.lower())
@ -273,17 +285,19 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
# compose markdown table
clmns = "|" + "|".join([re.sub(r"(/.*|[^]+)", "", field_map.get(tbl["columns"][i]["name"],
tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
tbl["columns"][i]["name"])) for i in
clmn_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
("|------|" if docid_idx and docid_idx else "")
("|------|" if docid_idx and docid_idx else "")
rows = ["|" +
"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
"|" for r in tbl["rows"]]
if quota:
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
else: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
else:
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
if not docid_idx or not docnm_idx:
@ -303,5 +317,40 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
return {
"answer": "\n".join([clmns, line, rows]),
"reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()]}
"doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
doc_aggs.items()]}
}
def relevant(tenant_id, llm_id, question, contents: list):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
prompt = """
You are a grader assessing relevance of a retrieved document to a user question.
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
No other words needed except 'yes' or 'no'.
"""
if not contents:return False
contents = "Documents: \n" + " - ".join(contents)
contents = f"Question: {question}\n" + contents
if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
if ans.lower().find("yes") >= 0: return True
return False
def rewrite(tenant_id, llm_id, question):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
prompt = """
You are an expert at query expansion to generate a paraphrasing of a question.
I can't retrieval relevant information from the knowledge base by using user's question directly.
You need to expand or paraphrase user's question by multiple ways such as using synonyms words/phrase,
writing the abbreviation in its entirety, adding some extra descriptions or explanations,
changing the way of expression, translating the original question into another language (English/Chinese), etc.
And return 5 versions of question and one is from translation.
Just list the question. No other words are needed.
"""
ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
return ans

View File

@ -1021,6 +1021,8 @@ class RAGFlowPdfParser:
self.page_cum_height = np.cumsum(self.page_cum_height)
assert len(self.page_cum_height) == len(self.page_images) + 1
if len(self.boxes) == 0 and zoomin < 9: self.__images__(fnm, zoomin * 3, page_from,
page_to, callback)
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
self.__images__(fnm, zoomin)

View File

@ -129,4 +129,3 @@ class YoudaoRerank(DefaultRerank):
return np.array(res), token_count

View File

@ -48,7 +48,7 @@ class EsQueryer:
@staticmethod
def rmWWW(txt):
patts = [
(r"是*(什么样的|哪家|一下|那家|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
(r"是*(什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀)是*", ""),
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
(r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down) ", " ")
]
@ -68,7 +68,9 @@ class EsQueryer:
if not self.isChinese(txt):
tks = rag_tokenizer.tokenize(txt).split(" ")
tks_w = self.tw.weights(tks)
tks_w = [(re.sub(r"[ \\\"']+", "", tk), w) for tk, w in tks_w]
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
q = ["{}^{:.4f}".format(tk, w) for tk, w in tks_w if tk]
for i in range(1, len(tks_w)):
q.append("\"%s %s\"^%.4f" % (tks_w[i - 1][0], tks_w[i][0], max(tks_w[i - 1][1], tks_w[i][1])*2))
@ -118,7 +120,8 @@ class EsQueryer:
if sm:
tk = f"{tk} OR \"%s\" OR (\"%s\"~2)^0.5" % (
" ".join(sm), " ".join(sm))
tms.append((tk, w))
if tk.strip():
tms.append((tk, w))
tms = " ".join([f"({t})^{w}" for t, w in tms])