# -*- coding: utf-8 -*-
import sys, os, re, inspect, json, traceback, logging, argparse, copy

# Make the project root importable when this script is run directly.
sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe()))) + "/../")

from tornado.web import RequestHandler, Application
from tornado.ioloop import IOLoop
from tornado.httpserver import HTTPServer
from tornado.options import define, options
from util import es_conn, setup_logging
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
from nlp import huqie
from nlp import query as Query
from nlp import search
from llm import HuEmbedding, GptTurbo
import numpy as np
from io import BytesIO
from util import config
from timeit import default_timer as timer
from collections import OrderedDict
from llm import ChatModel, EmbeddingModel

SE = None                # search.Dealer instance, created in __main__
CFIELD = "content_ltks"  # ES field that stores the tokenized chunk content
EMBEDDING = EmbeddingModel
LLM = ChatModel


def get_QA_pairs(hists):
    """Flatten the request history into OpenAI-style chat messages."""
    pa = []
    for h in hists:
        for k in ["user", "assistant"]:
            if h.get(k):
                pa.append({
                    "content": h[k],
                    "role": k,
                })

    # Sanity check: every message dict must carry exactly the two keys above.
    for p in pa[:-1]:
        assert len(p) == 2, p
    return pa
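
# For illustration, with made-up sample values:
#
#   get_QA_pairs([{"user": "hi", "assistant": "Hello!"},
#                 {"user": "What is RAGFlow?"}])
#   => [{"content": "hi", "role": "user"},
#       {"content": "Hello!", "role": "assistant"},
#       {"content": "What is RAGFlow?", "role": "user"}]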


def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"):
    max_len //= len(top_i)
    # Collect the retrieved chunks that go into the prompt.
    instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
    if len(instructions) > 2:
        # LLMs are said to attend most to the first and the last references,
        # so move the second-best chunk to the end.
        instructions.append(copy.deepcopy(instructions[1]))
        instructions.pop(1)

    def token_num(txt):
        # Crude token count: a pure-ASCII word counts as one token,
        # any other segment by its character length.
        c = 0
        for tk in re.split(r"[,。/?‘’”“:;:;!!]", txt):
            if re.match(r"[a-zA-Z-]+$", tk):
                c += 1
                continue
            c += len(tk)
        return c

    _inst = ""
    for ins in instructions:
        if token_num(_inst) > 4096:
            # Budget exhausted: still append the last (important) chunk, then stop.
            _inst += "\n知识库:" + instructions[-1][:max_len]  # "知识库" = "knowledge base"
            break
        _inst += "\n知识库:" + ins[:max_len]
    return _inst
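
# A rough illustration of the token heuristic above, on a made-up input:
#
#   token_num("OK,你好!")  ->  3   ("OK" counts as 1 token, "你好" as 2 characters)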


def prompt_and_answer(history, inst):
    hist = get_QA_pairs(history)

    # Estimate how repetitive the retrieved references are:
    # the share of distinct segments among all segments.
    chks = []
    for s in re.split(r"[::;;。\n\r]+", inst):
        if s:
            chks.append(s)
    chks = len(set(chks)) / (0.1 + len(chks))
    print("Duplication portion:", chks)

    # System prompt (in Chinese): "You are an intelligent assistant. Summarize
    # the knowledge base to answer the question, citing its data in detail
    # [, preferably summarized into a table]. When nothing in the knowledge
    # base is relevant to the question, your answer must include the sentence
    # 'No answer was found in the knowledge base! The following is what I
    # know, for reference only.' Take the chat history into account."
    system = """
你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答%s。当所有知识库内容都与问题无关时,你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话。回答需要考虑聊天历史。
以下是知识库:
%s
以上是知识库。
""" % ((",最好总结成表格" if 0 < chks < 0.6 else ""), inst)

    print("【PROMPT】:", system)
    start = timer()
    response = LLM.chat(system, hist, {"temperature": 0.2, "max_tokens": 512})
    print("GENERATE: ", timer() - start)
    print("===>>", response)
    return response
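
# Illustration of the duplication gate above, with made-up numbers: with 10
# segments of which only 5 are distinct, chks = 5 / (0.1 + 10) ≈ 0.50 < 0.6,
# so the prompt additionally asks the model to summarize into a table.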


class Handler(RequestHandler):
    def post(self):
        global SE
        param = json.loads(self.request.body.decode('utf-8'))
        try:
            # Retrieve candidate chunks for the latest user question.
            question = param.get("history", [{"user": "Hi!"}])[-1]["user"]
            res = SE.search({
                "question": question,
                "kb_ids": param.get("kb_ids", []),
                "size": param.get("topn", 15)},
                search.index_name(param["uid"])
            )

            # Rerank by similarity and keep the chunks above the threshold.
            sim = SE.rerank(res, question)
            rk_idx = np.argsort(sim * -1)
            topidx = [i for i in rk_idx if sim[i] >= param.get("similarity", 0.5)][:param.get("topn", 12)]
            inst = get_instruction(res, topidx)

            ans = prompt_and_answer(param["history"], inst)
            ans = SE.insert_citations(ans, topidx, res)

            # Group the reranked chunks by source document for the "refer" field.
            refer = OrderedDict()
            docnms = {}
            for i in rk_idx:
                did = res.field[res.ids[i]]["doc_id"]
                if did not in docnms:
                    docnms[did] = res.field[res.ids[i]]["docnm_kwd"]
                if did not in refer:
                    refer[did] = []
                refer[did].append({
                    "chunk_id": res.ids[i],
                    "content": res.field[res.ids[i]]["content_ltks"],
                    "image": ""
                })

            print("::::::::::::::", ans)
            self.write(json.dumps({
                "code": 0,
                "msg": "success",
                "data": {
                    "uid": param["uid"],
                    "dialog_id": param["dialog_id"],
                    "assistant": ans,
                    "refer": [{
                        "did": did,
                        "doc_name": docnms[did],
                        "chunks": chunks
                    } for did, chunks in refer.items()]
                }
            }))
            logging.info("SUCCESS[%d]" % (res.total) + json.dumps(param, ensure_ascii=False))

        except Exception as e:
            logging.error("Request 500: " + str(e))
            self.write(json.dumps({
                "code": 500,
                "msg": str(e),
                "data": {}
            }))
            print(traceback.format_exc())
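
# A request this handler accepts looks roughly like the following (field
# values are illustrative; "topn" and "similarity" are optional):
#
#   curl -X POST http://localhost:4455/v1/chat/completions \
#        -d '{"uid": "u1", "dialog_id": "d1", "kb_ids": ["kb1"],
#             "history": [{"user": "What is RAGFlow?"}]}'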


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", default=4455, type=int, help="Port used for service")
    ARGS = parser.parse_args()

    # Search dealer over the Elasticsearch connection, shared by all handlers.
    SE = search.Dealer(es_conn.HuEs("infiniflow"), EMBEDDING)

    app = Application([(r'/v1/chat/completions', Handler)], debug=False)
    http_server = HTTPServer(app)
    http_server.bind(ARGS.port)
    http_server.start(3)  # fork three worker processes

    IOLoop.current().start()
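
# To launch (assuming this file is saved as dialog_svr.py):
#   python dialog_svr.py --port 4455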