Test chat API and refine ppt chunker (#42)

This commit is contained in:
KevinHuSh 2024-01-23 19:45:36 +08:00 committed by GitHub
parent 34b2ab3b2f
commit e32ef75e99
10 changed files with 226 additions and 91 deletions

View File

@ -17,7 +17,7 @@ from flask import request
from flask_login import login_required
from api.db.services.dialog_service import DialogService, ConversationService
from api.db import LLMType
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
@ -170,12 +170,9 @@ def chat(dialog, messages, **kwargs):
if p["key"] not in kwargs:
prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
model_config = TenantLLMService.get_api_key(dialog.tenant_id, dialog.llm_id)
if not model_config: raise LookupError("LLM({}) API key not found".format(dialog.llm_id))
question = messages[-1]["content"]
embd_mdl = TenantLLMService.model_instance(
dialog.tenant_id, LLMType.EMBEDDING.value)
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
dialog.vector_similarity_weight, top=1024, aggs=False)
knowledges = [ck["content_ltks"] for ck in kbinfos["chunks"]]
@ -189,8 +186,7 @@ def chat(dialog, messages, **kwargs):
used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
if "max_tokens" in gen_conf:
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
mdl = ChatModel[model_config.llm_factory](model_config.api_key, dialog.llm_id)
answer = mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
answer = retrievaler.insert_citations(answer,
[ck["content_ltks"] for ck in kbinfos["chunks"]],

View File

@ -524,6 +524,7 @@ class Dialog(DataBaseModel):
similarity_threshold = FloatField(default=0.2)
vector_similarity_weight = FloatField(default=0.3)
top_n = IntegerField(default=6)
do_refer = CharField(max_length=1, null=False, help_text="it needs to insert reference index into answer or not", default="1")
kb_ids = JSONField(null=False, default=[])
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")

View File

@ -14,12 +14,12 @@
# limitations under the License.
#
from api.db.services.user_service import TenantService
from rag.llm import EmbeddingModel, CvModel
from api.settings import database_logger
from rag.llm import EmbeddingModel, CvModel, ChatModel
from api.db import LLMType
from api.db.db_models import DB, UserTenant
from api.db.db_models import LLMFactories, LLM, TenantLLM
from api.db.services.common_service import CommonService
from api.db import StatusEnum
class LLMFactoriesService(CommonService):
@ -37,13 +37,19 @@ class TenantLLMService(CommonService):
@DB.connection_context()
def get_api_key(cls, tenant_id, model_name):
objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
if not objs: return
if not objs:
return
return objs[0]
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name]
fields = [
cls.model.llm_factory,
LLMFactories.logo,
LLMFactories.tags,
cls.model.model_type,
cls.model.llm_name]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
cls.model.tenant_id == tenant_id).dicts()
@ -51,23 +57,96 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type):
e,tenant = TenantService.get_by_id(tenant_id)
if not e: raise LookupError("Tenant not found")
def model_instance(cls, tenant_id, llm_type, llm_name=None):
e, tenant = TenantService.get_by_id(tenant_id)
if not e:
raise LookupError("Tenant not found")
if llm_type == LLMType.EMBEDDING.value: mdlnm = tenant.embd_id
elif llm_type == LLMType.SPEECH2TEXT.value: mdlnm = tenant.asr_id
elif llm_type == LLMType.IMAGE2TEXT.value: mdlnm = tenant.img2txt_id
elif llm_type == LLMType.CHAT.value: mdlnm = tenant.llm_id
else: assert False, "LLM type error"
if llm_type == LLMType.EMBEDDING.value:
mdlnm = tenant.embd_id
elif llm_type == LLMType.SPEECH2TEXT.value:
mdlnm = tenant.asr_id
elif llm_type == LLMType.IMAGE2TEXT.value:
mdlnm = tenant.img2txt_id
elif llm_type == LLMType.CHAT.value:
mdlnm = tenant.llm_id if not llm_name else llm_name
else:
assert False, "LLM type error"
model_config = cls.get_api_key(tenant_id, mdlnm)
if not model_config: raise LookupError("Model({}) not found".format(mdlnm))
if not model_config:
raise LookupError("Model({}) not found".format(mdlnm))
model_config = model_config.to_dict()
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel: return
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
if model_config["llm_factory"] not in EmbeddingModel:
return
return EmbeddingModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"])
if llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel: return
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
if model_config["llm_factory"] not in CvModel:
return
return CvModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"])
if llm_type == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel:
return
return ChatModel[model_config["llm_factory"]](
model_config["api_key"], model_config["llm_name"])
@classmethod
@DB.connection_context()
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
e, tenant = TenantService.get_by_id(tenant_id)
if not e:
raise LookupError("Tenant not found")
if llm_type == LLMType.EMBEDDING.value:
mdlnm = tenant.embd_id
elif llm_type == LLMType.SPEECH2TEXT.value:
mdlnm = tenant.asr_id
elif llm_type == LLMType.IMAGE2TEXT.value:
mdlnm = tenant.img2txt_id
elif llm_type == LLMType.CHAT.value:
mdlnm = tenant.llm_id if not llm_name else llm_name
else:
assert False, "LLM type error"
num = cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)\
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == mdlnm)\
.execute()
return num
class LLMBundle(object):
def __init__(self, tenant_id, llm_type, llm_name=None):
self.tenant_id = tenant_id
self.llm_type = llm_type
self.llm_name = llm_name
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name)
assert self.mdl, "Can't find mole for {}/{}/{}".format(tenant_id, llm_type, llm_name)
def encode(self, texts: list, batch_size=32):
emd, used_tokens = self.mdl.encode(texts, batch_size)
if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
return emd, used_tokens
def encode_queries(self, query: str):
emd, used_tokens = self.mdl.encode_queries(query)
if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
database_logger.error("Can't update token usage for {}/EMBEDDING".format(self.tenant_id))
return emd, used_tokens
def describe(self, image, max_tokens=300):
txt, used_tokens = self.mdl.describe(image, max_tokens)
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
database_logger.error("Can't update token usage for {}/IMAGE2TEXT".format(self.tenant_id))
return txt
def chat(self, system, history, gen_conf):
txt, used_tokens = self.mdl.chat(system, history, gen_conf)
if TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
database_logger.error("Can't update token usage for {}/CHAT".format(self.tenant_id))
return txt

View File

@ -143,11 +143,11 @@ def filename_type(filename):
if re.match(r".*\.pdf$", filename):
return FileType.PDF.value
if re.match(r".*\.(docx|doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
if re.match(r".*\.(docx|doc|ppt|pptx|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
return FileType.DOC.value
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
return FileType.AURAL.value
if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
return FileType.VISUAL
return FileType.VISUAL

View File

@ -37,7 +37,7 @@ class GptTurbo(Base):
model=self.model_name,
messages=history,
**gen_conf)
return res.choices[0].message.content.strip()
return res.choices[0].message.content.strip(), res.usage.completion_tokens
from dashscope import Generation
@ -56,5 +56,5 @@ class QWenChat(Base):
result_format='message'
)
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content']
return response.message
return response.output.choices[0]['message']['content'], response.usage.output_tokens
return response.message, 0

View File

@ -72,7 +72,7 @@ class GptV4(Base):
messages=self.prompt(b64),
max_tokens=max_tokens,
)
return res.choices[0].message.content.strip()
return res.choices[0].message.content.strip(), res.usage.total_tokens
class QWenCV(Base):
@ -87,5 +87,5 @@ class QWenCV(Base):
response = MultiModalConversation.call(model=self.model_name,
messages=self.prompt(self.image2base64(image)))
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content']
return response.message
return response.output.choices[0]['message']['content'], response.usage.output_tokens
return response.message, 0

View File

@ -36,6 +36,9 @@ class Base(ABC):
def encode(self, texts: list, batch_size=32):
raise NotImplementedError("Please implement encode method!")
def encode_queries(self, text: str):
raise NotImplementedError("Please implement encode method!")
class HuEmbedding(Base):
def __init__(self, key="", model_name=""):
@ -68,15 +71,18 @@ class HuEmbedding(Base):
class OpenAIEmbed(Base):
def __init__(self, key, model_name="text-embedding-ada-002"):
self.client = OpenAI(key)
self.client = OpenAI(api_key=key)
self.model_name = model_name
def encode(self, texts: list, batch_size=32):
token_count = 0
for t in texts: token_count += num_tokens_from_string(t)
res = self.client.embeddings.create(input=texts,
model=self.model_name)
return [d["embedding"] for d in res["data"]], token_count
return np.array([d.embedding for d in res.data]), res.usage.total_tokens
def encode_queries(self, text):
res = self.client.embeddings.create(input=[text],
model=self.model_name)
return np.array(res.data[0].embedding), res.usage.total_tokens
class QWenEmbed(Base):
@ -84,16 +90,28 @@ class QWenEmbed(Base):
dashscope.api_key = key
self.model_name = model_name
def encode(self, texts: list, batch_size=32, text_type="document"):
def encode(self, texts: list, batch_size=10):
import dashscope
res = []
token_count = 0
for txt in texts:
texts = [txt[:2048] for txt in texts]
for i in range(0, len(texts), batch_size):
resp = dashscope.TextEmbedding.call(
model=self.model_name,
input=txt[:2048],
text_type=text_type
input=texts[i:i+batch_size],
text_type="document"
)
res.append(resp["output"]["embeddings"][0]["embedding"])
token_count += resp["usage"]["total_tokens"]
return res, token_count
embds = [[]] * len(resp["output"]["embeddings"])
for e in resp["output"]["embeddings"]:
embds[e["text_index"]] = e["embedding"]
res.extend(embds)
token_count += resp["usage"]["input_tokens"]
return np.array(res), token_count
def encode_queries(self, text):
resp = dashscope.TextEmbedding.call(
model=self.model_name,
input=text[:2048],
text_type="query"
)
return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"]

View File

@ -11,6 +11,11 @@ from io import BytesIO
class HuChunker:
@dataclass
class Fields:
text_chunks: List = None
table_chunks: List = None
def __init__(self):
self.MAX_LVL = 12
self.proj_patt = [
@ -228,11 +233,6 @@ class HuChunker:
class PdfChunker(HuChunker):
@dataclass
class Fields:
text_chunks: List = None
table_chunks: List = None
def __init__(self, pdf_parser):
self.pdf = pdf_parser
super().__init__()
@ -293,11 +293,6 @@ class PdfChunker(HuChunker):
class DocxChunker(HuChunker):
@dataclass
class Fields:
text_chunks: List = None
table_chunks: List = None
def __init__(self, doc_parser):
self.doc = doc_parser
super().__init__()
@ -344,11 +339,6 @@ class DocxChunker(HuChunker):
class ExcelChunker(HuChunker):
@dataclass
class Fields:
text_chunks: List = None
table_chunks: List = None
def __init__(self, excel_parser):
self.excel = excel_parser
super().__init__()
@ -370,18 +360,51 @@ class PptChunker(HuChunker):
def __init__(self):
super().__init__()
def __extract(self, shape):
if shape.shape_type == 19:
tb = shape.table
rows = []
for i in range(1, len(tb.rows)):
rows.append("; ".join([tb.cell(0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
return "\n".join(rows)
if shape.has_text_frame:
return shape.text_frame.text
if shape.shape_type == 6:
texts = []
for p in shape.shapes:
t = self.__extract(p)
if t: texts.append(t)
return "\n".join(texts)
def __call__(self, fnm):
from pptx import Presentation
ppt = Presentation(fnm) if isinstance(
fnm, str) else Presentation(
BytesIO(fnm))
flds = self.Fields()
flds.text_chunks = []
txts = []
for slide in ppt.slides:
texts = []
for shape in slide.shapes:
if hasattr(shape, "text"):
flds.text_chunks.append((shape.text, None))
txt = self.__extract(shape)
if txt: texts.append(txt)
txts.append("\n".join(texts))
import aspose.slides as slides
import aspose.pydrawing as drawing
imgs = []
with slides.Presentation(BytesIO(fnm)) as presentation:
for slide in presentation.slides:
buffered = BytesIO()
slide.get_thumbnail(0.5, 0.5).save(buffered, drawing.imaging.ImageFormat.jpeg)
imgs.append(buffered.getvalue())
assert len(imgs) == len(txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
flds = self.Fields()
flds.text_chunks = [(txts[i], imgs[i]) for i in range(len(txts))]
flds.table_chunks = []
return flds

View File

@ -58,7 +58,8 @@ class Dealer:
if req["available_int"] == 0:
bqry.filter.append(Q("range", available_int={"lt": 1}))
else:
bqry.filter.append(Q("bool", must_not=Q("range", available_int={"lt": 1})))
bqry.filter.append(
Q("bool", must_not=Q("range", available_int={"lt": 1})))
bqry.boost = 0.05
s = Search()
@ -87,9 +88,12 @@ class Dealer:
q_vec = []
if req.get("vector"):
assert emb_mdl, "No embedding model selected"
s["knn"] = self._vector(qst, emb_mdl, req.get("similarity", 0.4), ps)
s["knn"] = self._vector(
qst, emb_mdl, req.get(
"similarity", 0.4), ps)
s["knn"]["filter"] = bqry.to_dict()
if "highlight" in s: del s["highlight"]
if "highlight" in s:
del s["highlight"]
q_vec = s["knn"]["query_vector"]
es_logger.info("【Q】: {}".format(json.dumps(s)))
res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
@ -175,7 +179,8 @@ class Dealer:
def trans2floats(txt):
return [float(t) for t in txt.split("\t")]
def insert_citations(self, answer, chunks, chunk_v, embd_mdl, tkweight=0.3, vtweight=0.7):
def insert_citations(self, answer, chunks, chunk_v,
embd_mdl, tkweight=0.3, vtweight=0.7):
pieces = re.split(r"([;。?!\n]|[a-z][.?;!][ \n])", answer)
for i in range(1, len(pieces)):
if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
@ -184,47 +189,57 @@ class Dealer:
idx = []
pieces_ = []
for i, t in enumerate(pieces):
if len(t) < 5: continue
if len(t) < 5:
continue
idx.append(i)
pieces_.append(t)
es_logger.info("{} => {}".format(answer, pieces_))
if not pieces_: return answer
if not pieces_:
return answer
ans_v, c = embd_mdl.encode(pieces_)
ans_v, _ = embd_mdl.encode(pieces_)
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
len(ans_v[0]), len(chunk_v[0]))
chunks_tks = [huqie.qie(ck).split(" ") for ck in chunks]
cites = {}
for i,a in enumerate(pieces_):
for i, a in enumerate(pieces_):
sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
chunk_v,
huqie.qie(pieces_[i]).split(" "),
huqie.qie(
pieces_[i]).split(" "),
chunks_tks,
tkweight, vtweight)
mx = np.max(sim) * 0.99
if mx < 0.55: continue
cites[idx[i]] = list(set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4]
if mx < 0.55:
continue
cites[idx[i]] = list(
set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4]
res = ""
for i,p in enumerate(pieces):
for i, p in enumerate(pieces):
res += p
if i not in idx:continue
if i not in cites:continue
res += "##%s$$"%"$".join(cites[i])
if i not in idx:
continue
if i not in cites:
continue
res += "##%s$$" % "$".join(cites[i])
return res
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks"):
def rerank(self, sres, query, tkweight=0.3,
vtweight=0.7, cfield="content_ltks"):
ins_embd = [
Dealer.trans2floats(
sres.field[i]["q_%d_vec" % len(sres.query_vector)]) for i in sres.ids]
sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids]
if not ins_embd:
return [], [], []
ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ") for i in sres.ids]
ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ")
for i in sres.ids]
sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd,
huqie.qie(query).split(" "),
huqie.qie(
query).split(" "),
ins_tw, tkweight, vtweight)
return sim, tksim, vtsim
@ -237,7 +252,8 @@ class Dealer:
def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, similarity_threshold=0.2,
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True):
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
if not question: return ranks
if not question:
return ranks
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": top,
"question": question, "vector": True,
"similarity": similarity_threshold}

View File

@ -49,7 +49,7 @@ from rag.nlp.huchunk import (
)
from api.db import LLMType
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import TenantLLMService
from api.db.services.llm_service import TenantLLMService, LLMBundle
from api.settings import database_logger
from api.utils import get_format_time
from api.utils.file_utils import get_project_base_directory
@ -62,7 +62,7 @@ EXC = ExcelChunker(ExcelParser())
PPT = PptChunker()
def chuck_doc(name, binary, cvmdl=None):
def chuck_doc(name, binary, tenant_id, cvmdl=None):
suff = os.path.split(name)[-1].lower().split(".")[-1]
if suff.find("pdf") >= 0:
return PDF(binary)
@ -127,7 +127,7 @@ def build(row, cvmdl):
100., "Finished preparing! Start to slice file!", True)
try:
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), row["tenant_id"], cvmdl)
except Exception as e:
if re.search("(No such file|not found)", str(e)):
set_progress(
@ -236,12 +236,14 @@ def main(comm, mod):
tmf = open(tm_fnm, "a+")
for _, r in rows.iterrows():
embd_mdl = TenantLLMService.model_instance(r["tenant_id"], LLMType.EMBEDDING)
if not embd_mdl:
set_progress(r["id"], -1, "Can't find embedding model!")
cron_logger.error("Tenant({}) can't find embedding model!".format(r["tenant_id"]))
try:
embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING)
cv_mdl = LLMBundle(r["tenant_id"], LLMType.IMAGE2TEXT)
#TODO: sequence2text model
except Exception as e:
set_progress(r["id"], -1, str(e))
continue
cv_mdl = TenantLLMService.model_instance(r["tenant_id"], LLMType.IMAGE2TEXT)
st_tm = timer()
cks = build(r, cv_mdl)
if not cks: