ragflow/python/svr/dialog_svr.py
KevinHuSh d0db329fef add llm API (#19)
* add llm API

* refine llm API
2023-12-28 13:50:13 +08:00

166 lines
5.5 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#-*- coding:utf-8 -*-
import sys, os, re,inspect,json,traceback,logging,argparse, copy
sys.path.append(os.path.realpath(os.path.dirname(inspect.getfile(inspect.currentframe())))+"/../")
from tornado.web import RequestHandler,Application
from tornado.ioloop import IOLoop
from tornado.httpserver import HTTPServer
from tornado.options import define,options
from util import es_conn, setup_logging
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
from nlp import huqie
from nlp import query as Query
from nlp import search
from llm import HuEmbedding, GptTurbo
import numpy as np
from io import BytesIO
from util import config
from timeit import default_timer as timer
from collections import OrderedDict
from llm import ChatModel, EmbeddingModel
SE = None
CFIELD="content_ltks"
EMBEDDING = EmbeddingModel
LLM = ChatModel
def get_QA_pairs(hists):
pa = []
for h in hists:
for k in ["user", "assistant"]:
if h.get(k):
pa.append({
"content": h[k],
"role": k,
})
for p in pa[:-1]: assert len(p) == 2, p
return pa
def get_instruction(sres, top_i, max_len=8096, fld="content_ltks"):
max_len //= len(top_i)
# add instruction to prompt
instructions = [re.sub(r"[\r\n]+", " ", sres.field[sres.ids[i]][fld]) for i in top_i]
if len(instructions)>2:
# Said that LLM is sensitive to the first and the last one, so
# rearrange the order of references
instructions.append(copy.deepcopy(instructions[1]))
instructions.pop(1)
def token_num(txt):
c = 0
for tk in re.split(r"[,。/?‘’”“:;:;!]", txt):
if re.match(r"[a-zA-Z-]+$", tk):
c += 1
continue
c += len(tk)
return c
_inst = ""
for ins in instructions:
if token_num(_inst) > 4096:
_inst += "\n知识库:" + instructions[-1][:max_len]
break
_inst += "\n知识库:" + ins[:max_len]
return _inst
def prompt_and_answer(history, inst):
hist = get_QA_pairs(history)
chks = []
for s in re.split(r"[:;;。\n\r]+", inst):
if s: chks.append(s)
chks = len(set(chks))/(0.1+len(chks))
print("Duplication portion:", chks)
system = """
你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答%s。当所有知识库内容都与问题无关时,你的回答必须包括"知识库中未找到您要的答案!这是我所知道的,仅作参考。"这句话。回答需要考虑聊天历史。
以下是知识库:
%s
以上是知识库。
"""%((",最好总结成表格" if chks<0.6 and chks>0 else ""), inst)
print("【PROMPT】:", system)
start = timer()
response = LLM.chat(system, hist, {"temperature": 0.2, "max_tokens": 512})
print("GENERATE: ", timer()-start)
print("===>>", response)
return response
class Handler(RequestHandler):
def post(self):
global SE,MUST_TK_NUM
param = json.loads(self.request.body.decode('utf-8'))
try:
question = param.get("history",[{"user": "Hi!"}])[-1]["user"]
res = SE.search({
"question": question,
"kb_ids": param.get("kb_ids", []),
"size": param.get("topn", 15)},
search.index_name(param["uid"])
)
sim = SE.rerank(res, question)
rk_idx = np.argsort(sim*-1)
topidx = [i for i in rk_idx if sim[i] >= aram.get("similarity", 0.5)][:param.get("topn",12)]
inst = get_instruction(res, topidx)
ans, topidx = prompt_and_answer(param["history"], inst)
ans = SE.insert_citations(ans, topidx, res)
refer = OrderedDict()
docnms = {}
for i in rk_idx:
did = res.field[res.ids[i]]["doc_id"]
if did not in docnms: docnms[did] = res.field[res.ids[i]]["docnm_kwd"]
if did not in refer: refer[did] = []
refer[did].append({
"chunk_id": res.ids[i],
"content": res.field[res.ids[i]]["content_ltks"],
"image": ""
})
print("::::::::::::::", ans)
self.write(json.dumps({
"code":0,
"msg":"success",
"data":{
"uid": param["uid"],
"dialog_id": param["dialog_id"],
"assistant": ans,
"refer": [{
"did": did,
"doc_name": docnms[did],
"chunks": chunks
} for did, chunks in refer.items()]
}
}))
logging.info("SUCCESS[%d]"%(res.total)+json.dumps(param, ensure_ascii=False))
except Exception as e:
logging.error("Request 500: "+str(e))
self.write(json.dumps({
"code":500,
"msg":str(e),
"data":{}
}))
print(traceback.format_exc())
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--port", default=4455, type=int, help="Port used for service")
ARGS = parser.parse_args()
SE = search.Dealer(es_conn.HuEs("infiniflow"), EMBEDDING)
app = Application([(r'/v1/chat/completions', Handler)],debug=False)
http_server = HTTPServer(app)
http_server.bind(ARGS.port)
http_server.start(3)
IOLoop.current().start()