add llm API (#19)

* add llm API

* refine llm API
Authored by KevinHuSh on 2023-12-28 13:50:13 +08:00; committed by GitHub
parent cdd956568d
commit d0db329fef
17 changed files with 349 additions and 170 deletions

View File

@@ -121,7 +121,6 @@
   "match": "*_vec",
   "mapping": {
     "type": "dense_vector",
-    "dims": 1024,
     "index": true,
     "similarity": "cosine"
   }

View File

@@ -1,10 +1,9 @@
 [infiniflow]
 es=http://es01:9200
-pgdb_usr=root
-pgdb_pwd=infiniflow_docgpt
-pgdb_host=postgres
-pgdb_port=5432
+postgres_user=root
+postgres_password=infiniflow_docgpt
+postgres_host=postgres
+postgres_port=5432
 minio_host=minio:9000
-minio_usr=infiniflow
-minio_pwd=infiniflow_docgpt
+minio_user=infiniflow
+minio_password=infiniflow_docgpt

View File

@@ -1,2 +1,21 @@
-from .embedding_model import HuEmbedding
-from .chat_model import GptTurbo
+import os
+from .embedding_model import *
+from .chat_model import *
+from .cv_model import *
+
+EmbeddingModel = None
+ChatModel = None
+CvModel = None
+
+if os.environ.get("OPENAI_API_KEY"):
+    EmbeddingModel = GptEmbed()
+    ChatModel = GptTurbo()
+    CvModel = GptV4()
+elif os.environ.get("DASHSCOPE_API_KEY"):
+    EmbeddingModel = QWenEmbd()
+    ChatModel = QWenChat()
+    CvModel = QWenCV()
+else:
+    EmbeddingModel = HuEmbedding()
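
A minimal usage sketch of the selection above (illustrative, not part of the diff). It assumes python/llm is importable as "llm" and that OPENAI_API_KEY or DASHSCOPE_API_KEY is exported before import; without a key only EmbeddingModel falls back to HuEmbedding, while ChatModel and CvModel stay None.

# Sketch: the factory binds concrete classes at import time.
from llm import EmbeddingModel, ChatModel   # import path assumed

vecs = EmbeddingModel.encode(["你好", "hello"])          # GptEmbed / QWenEmbd / HuEmbedding
reply = ChatModel.chat("You are a helpful assistant.",
                       [{"role": "user", "content": "Say hi in one word."}],
                       {"temperature": 0.7})
print(len(vecs), reply)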

View File

@@ -1,7 +1,8 @@
 from abc import ABC
-import openapi
+from openai import OpenAI
 import os


 class Base(ABC):
     def chat(self, system, history, gen_conf):
         raise NotImplementedError("Please implement encode method!")
@@ -9,26 +10,27 @@ class Base(ABC):
 class GptTurbo(Base):
     def __init__(self):
-        openapi.api_key = os.environ["OPENAPI_KEY"]
+        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

     def chat(self, system, history, gen_conf):
         history.insert(0, {"role": "system", "content": system})
-        res = openapi.ChatCompletion.create(model="gpt-3.5-turbo",
-                                            messages=history,
-                                            **gen_conf)
+        res = self.client.chat.completions.create(
+            model="gpt-3.5-turbo",
+            messages=history,
+            **gen_conf)
         return res.choices[0].message.content.strip()


-class QWen(Base):
+class QWenChat(Base):
     def chat(self, system, history, gen_conf):
         from http import HTTPStatus
         from dashscope import Generation
-        from dashscope.api_entities.dashscope_response import Role
         # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
+        history.insert(0, {"role": "system", "content": system})
         response = Generation.call(
             Generation.Models.qwen_turbo,
-            messages=messages,
+            messages=history,
             result_format='message'
         )
         if response.status_code == HTTPStatus.OK:
             return response.output.choices[0]['message']['content']
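
A hedged example of calling the chat interface directly (the import path and generation parameters are assumptions; gen_conf is forwarded verbatim to the provider, so it should only contain options the provider accepts).

# Sketch: direct use of GptTurbo.chat (QWenChat.chat has the same signature).
from llm.chat_model import GptTurbo      # assumes OPENAI_API_KEY is exported

mdl = GptTurbo()
history = [{"role": "user", "content": "What does a dense_vector mapping do in Elasticsearch?"}]
print(mdl.chat("Answer in one sentence.", history, {"temperature": 0.2, "max_tokens": 64}))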

python/llm/cv_model.py (new file, 66 lines)
View File

@@ -0,0 +1,66 @@
from abc import ABC
from openai import OpenAI
import os
import base64
from io import BytesIO


class Base(ABC):
    def describe(self, image, max_tokens=300):
        raise NotImplementedError("Please implement encode method!")

    def image2base64(self, image):
        if isinstance(image, BytesIO):
            return base64.b64encode(image.getvalue()).decode("utf-8")
        buffered = BytesIO()
        try:
            image.save(buffered, format="JPEG")
        except Exception as e:
            image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def prompt(self, b64):
        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{b64}"
                        },
                    },
                ],
            }
        ]


class GptV4(Base):
    def __init__(self):
        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    def describe(self, image, max_tokens=300):
        b64 = self.image2base64(image)
        res = self.client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=self.prompt(b64),
            max_tokens=max_tokens,
        )
        return res.choices[0].message.content.strip()


class QWenCV(Base):
    def describe(self, image, max_tokens=300):
        from http import HTTPStatus
        from dashscope import MultiModalConversation
        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
        response = MultiModalConversation.call(model=MultiModalConversation.Models.qwen_vl_chat_v1,
                                               messages=self.prompt(self.image2base64(image)))
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content']
        return response.message
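
A short usage sketch for the vision wrapper above (illustrative, not part of the diff; the file name is an assumption). image2base64 accepts a BytesIO as well as a PIL image, so raw file bytes can be wrapped directly.

# Sketch: caption a local image with GptV4 (assumes OPENAI_API_KEY is exported).
from io import BytesIO
from llm.cv_model import GptV4           # import path assumed

with open("slide_01.png", "rb") as f:    # hypothetical image file
    img = BytesIO(f.read())

print(GptV4().describe(img, max_tokens=200))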

View File

@@ -1,8 +1,11 @@
 from abc import ABC
+from openai import OpenAI
 from FlagEmbedding import FlagModel
 import torch
+import os
 import numpy as np


 class Base(ABC):
     def encode(self, texts: list, batch_size=32):
         raise NotImplementedError("Please implement encode method!")
@@ -22,11 +25,37 @@ class HuEmbedding(Base):
         """
         self.model = FlagModel("BAAI/bge-large-zh-v1.5",
                                query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                                use_fp16=torch.cuda.is_available())

     def encode(self, texts: list, batch_size=32):
         res = []
         for i in range(0, len(texts), batch_size):
-            res.extend(self.model.encode(texts[i:i+batch_size]).tolist())
+            res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
         return np.array(res)
+
+
+class GptEmbed(Base):
+    def __init__(self):
+        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+
+    def encode(self, texts: list, batch_size=32):
+        res = self.client.embeddings.create(input=texts,
+                                            model="text-embedding-ada-002")
+        return [d.embedding for d in res.data]
+
+
+class QWenEmbd(Base):
+    def encode(self, texts: list, batch_size=32, text_type="document"):
+        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
+        import dashscope
+        from http import HTTPStatus
+        res = []
+        for txt in texts:
+            resp = dashscope.TextEmbedding.call(
+                model=dashscope.TextEmbedding.Models.text_embedding_v2,
+                input=txt[:2048],
+                text_type=text_type
+            )
+            res.append(resp["output"]["embeddings"][0]["embedding"])
+        return res
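
For a key-free smoke test, the FlagEmbedding-backed HuEmbedding above can be exercised locally (a sketch only; the model weights are downloaded on first use and the sentences are illustrative).

# Sketch: local embedding plus cosine similarity with HuEmbedding.
import numpy as np
from llm.embedding_model import HuEmbedding   # import path assumed

mdl = HuEmbedding()
vecs = mdl.encode(["向量检索怎么配置?", "How do I configure vector search?"])
a, b = vecs[0], vecs[1]
print(float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b))))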

View File

@@ -372,7 +372,9 @@ class PptChunker(HuChunker):
     def __call__(self, fnm):
         from pptx import Presentation
-        ppt = Presentation(fnm) if isinstance(fnm, str) else Presentation(BytesIO(fnm))
+        ppt = Presentation(fnm) if isinstance(
+            fnm, str) else Presentation(
+            BytesIO(fnm))
         flds = self.Fields()
         flds.text_chunks = []
         for slide in ppt.slides:
@@ -398,7 +400,8 @@ class TextChunker(HuChunker):
         mime = magic.Magic(mime=True)
         if isinstance(file_path, str):
             file_type = mime.from_file(file_path)
-        else:file_type = mime.from_buffer(file_path)
+        else:
+            file_type = mime.from_buffer(file_path)
         if 'text' in file_type:
             return False
         else:
@@ -406,7 +409,8 @@ class TextChunker(HuChunker):
     def __call__(self, fnm):
         flds = self.Fields()
-        if self.is_binary_file(fnm):return flds
+        if self.is_binary_file(fnm):
+            return flds
         with open(fnm, "r") as f:
             txt = f.read()
         flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]

View File

@ -1,6 +1,6 @@
import re import re
from elasticsearch_dsl import Q,Search,A from elasticsearch_dsl import Q, Search, A
from typing import List, Optional, Tuple,Dict, Union from typing import List, Optional, Tuple, Dict, Union
from dataclasses import dataclass from dataclasses import dataclass
from util import setup_logging, rmSpace from util import setup_logging, rmSpace
from nlp import huqie, query from nlp import huqie, query
@ -9,18 +9,24 @@ from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np import numpy as np
from copy import deepcopy from copy import deepcopy
def index_name(uid):return f"docgpt_{uid}"
def index_name(uid): return f"docgpt_{uid}"
class Dealer: class Dealer:
def __init__(self, es, emb_mdl): def __init__(self, es, emb_mdl):
self.qryr = query.EsQueryer(es) self.qryr = query.EsQueryer(es)
self.qryr.flds = ["title_tks^10", "title_sm_tks^5", "content_ltks^2", "content_sm_ltks"] self.qryr.flds = [
"title_tks^10",
"title_sm_tks^5",
"content_ltks^2",
"content_sm_ltks"]
self.es = es self.es = es
self.emb_mdl = emb_mdl self.emb_mdl = emb_mdl
@dataclass @dataclass
class SearchResult: class SearchResult:
total:int total: int
ids: List[str] ids: List[str]
query_vector: List[float] = None query_vector: List[float] = None
field: Optional[Dict] = None field: Optional[Dict] = None
@ -42,71 +48,78 @@ class Dealer:
keywords = [] keywords = []
qst = req.get("question", "") qst = req.get("question", "")
bqry,keywords = self.qryr.question(qst) bqry, keywords = self.qryr.question(qst)
if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"])) if req.get("kb_ids"):
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
bqry.filter.append(Q("exists", field="q_tks")) bqry.filter.append(Q("exists", field="q_tks"))
bqry.boost = 0.05 bqry.boost = 0.05
print(bqry) print(bqry)
s = Search() s = Search()
pg = int(req.get("page", 1))-1 pg = int(req.get("page", 1)) - 1
ps = int(req.get("size", 1000)) ps = int(req.get("size", 1000))
src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id", src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id",
"image_id", "doc_id", "q_vec"]) "image_id", "doc_id", "q_vec"])
s = s.query(bqry)[pg*ps:(pg+1)*ps] s = s.query(bqry)[pg * ps:(pg + 1) * ps]
s = s.highlight("content_ltks") s = s.highlight("content_ltks")
s = s.highlight("title_ltks") s = s.highlight("title_ltks")
if not qst: s = s.sort({"create_time":{"order":"desc", "unmapped_type":"date"}}) if not qst:
s = s.sort(
{"create_time": {"order": "desc", "unmapped_type": "date"}})
s = s.highlight_options( s = s.highlight_options(
fragment_size = 120, fragment_size=120,
number_of_fragments=5, number_of_fragments=5,
boundary_scanner_locale="zh-CN", boundary_scanner_locale="zh-CN",
boundary_scanner="SENTENCE", boundary_scanner="SENTENCE",
boundary_chars=",./;:\\!(),。?:!……()——、" boundary_chars=",./;:\\!(),。?:!……()——、"
) )
s = s.to_dict() s = s.to_dict()
q_vec = [] q_vec = []
if req.get("vector"): if req.get("vector"):
s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps) s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
s["knn"]["filter"] = bqry.to_dict() s["knn"]["filter"] = bqry.to_dict()
del s["highlight"] del s["highlight"]
q_vec = s["knn"]["query_vector"] q_vec = s["knn"]["query_vector"]
res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src) res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
print("TOTAL: ", self.es.getTotal(res)) print("TOTAL: ", self.es.getTotal(res))
if self.es.getTotal(res) == 0 and "knn" in s: if self.es.getTotal(res) == 0 and "knn" in s:
bqry,_ = self.qryr.question(qst, min_match="10%") bqry, _ = self.qryr.question(qst, min_match="10%")
if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"])) if req.get("kb_ids"):
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
s["query"] = bqry.to_dict() s["query"] = bqry.to_dict()
s["knn"]["filter"] = bqry.to_dict() s["knn"]["filter"] = bqry.to_dict()
s["knn"]["similarity"] = 0.7 s["knn"]["similarity"] = 0.7
res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src) res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
kwds = set([]) kwds = set([])
for k in keywords: for k in keywords:
kwds.add(k) kwds.add(k)
for kk in huqie.qieqie(k).split(" "): for kk in huqie.qieqie(k).split(" "):
if len(kk) < 2:continue if len(kk) < 2:
if kk in kwds:continue continue
if kk in kwds:
continue
kwds.add(kk) kwds.add(kk)
aggs = self.getAggregation(res, "docnm_kwd") aggs = self.getAggregation(res, "docnm_kwd")
return self.SearchResult( return self.SearchResult(
total = self.es.getTotal(res), total=self.es.getTotal(res),
ids = self.es.getDocIds(res), ids=self.es.getDocIds(res),
query_vector = q_vec, query_vector=q_vec,
aggregation = aggs, aggregation=aggs,
highlight = self.getHighlight(res), highlight=self.getHighlight(res),
field = self.getFields(res, ["docnm_kwd", "content_ltks", field=self.getFields(res, ["docnm_kwd", "content_ltks",
"kb_id","image_id", "doc_id", "q_vec"]), "kb_id", "image_id", "doc_id", "q_vec"]),
keywords = list(kwds) keywords=list(kwds)
) )
def getAggregation(self, res, g): def getAggregation(self, res, g):
if not "aggregations" in res or "aggs_"+g not in res["aggregations"]:return if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
bkts = res["aggregations"]["aggs_"+g]["buckets"] return
bkts = res["aggregations"]["aggs_" + g]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts] return [(b["key"], b["doc_count"]) for b in bkts]
def getHighlight(self, res): def getHighlight(self, res):
@ -114,8 +127,11 @@ class Dealer:
eng = set(list("qwertyuioplkjhgfdsazxcvbnm")) eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
r = [] r = []
for t in line.split(" "): for t in line.split(" "):
if not t:continue if not t:
if len(r)>0 and len(t)>0 and r[-1][-1] in eng and t[0] in eng:r.append(" ") continue
if len(r) > 0 and len(
t) > 0 and r[-1][-1] in eng and t[0] in eng:
r.append(" ")
r.append(t) r.append(t)
r = "".join(r) r = "".join(r)
return r return r
@ -123,66 +139,76 @@ class Dealer:
ans = {} ans = {}
for d in res["hits"]["hits"]: for d in res["hits"]["hits"]:
hlts = d.get("highlight") hlts = d.get("highlight")
if not hlts:continue if not hlts:
continue
ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]]) ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
return ans return ans
def getFields(self, sres, flds): def getFields(self, sres, flds):
res = {} res = {}
if not flds:return {} if not flds:
for d in self.es.getSource(sres): return {}
m = {n:d.get(n) for n in flds if d.get(n) is not None} for d in self.es.getSource(sres):
for n,v in m.items(): m = {n: d.get(n) for n in flds if d.get(n) is not None}
if type(v) == type([]): for n, v in m.items():
if isinstance(v, type([])):
m[n] = "\t".join([str(vv) for vv in v]) m[n] = "\t".join([str(vv) for vv in v])
continue continue
if type(v) != type(""):m[n] = str(m[n]) if not isinstance(v, type("")):
m[n] = str(m[n])
m[n] = rmSpace(m[n]) m[n] = rmSpace(m[n])
if m:res[d["id"]] = m if m:
res[d["id"]] = m
return res return res
@staticmethod @staticmethod
def trans2floats(txt): def trans2floats(txt):
return [float(t) for t in txt.split("\t")] return [float(t) for t in txt.split("\t")]
def insert_citations(self, ans, top_idx, sres,
vfield="q_vec", cfield="content_ltks"):
def insert_citations(self, ans, top_idx, sres, vfield = "q_vec", cfield="content_ltks"): ins_embd = [Dealer.trans2floats(
sres.field[sres.ids[i]][vfield]) for i in top_idx]
ins_embd = [Dealer.trans2floats(sres.field[sres.ids[i]][vfield]) for i in top_idx] ins_tw = [sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
ins_tw =[sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
s = 0 s = 0
e = 0 e = 0
res = "" res = ""
def citeit(): def citeit():
nonlocal s, e, ans, res nonlocal s, e, ans, res
if not ins_embd:return if not ins_embd:
return
embd = self.emb_mdl.encode(ans[s: e]) embd = self.emb_mdl.encode(ans[s: e])
sim = self.qryr.hybrid_similarity(embd, sim = self.qryr.hybrid_similarity(embd,
ins_embd, ins_embd,
huqie.qie(ans[s:e]).split(" "), huqie.qie(ans[s:e]).split(" "),
ins_tw) ins_tw)
print(ans[s: e], sim) print(ans[s: e], sim)
mx = np.max(sim)*0.99 mx = np.max(sim) * 0.99
if mx < 0.55:return if mx < 0.55:
cita = list(set([top_idx[i] for i in range(len(ins_embd)) if sim[i] >mx]))[:4] return
for i in cita: res += f"@?{i}?@" cita = list(set([top_idx[i]
for i in range(len(ins_embd)) if sim[i] > mx]))[:4]
for i in cita:
res += f"@?{i}?@"
return cita return cita
punct = set(";。?!") punct = set(";。?!")
if not self.qryr.isChinese(ans): if not self.qryr.isChinese(ans):
punct.add("?") punct.add("?")
punct.add(".") punct.add(".")
while e < len(ans): while e < len(ans):
if e - s < 12 or ans[e] not in punct: if e - s < 12 or ans[e] not in punct:
e += 1 e += 1
continue continue
if ans[e] == "." and e+1<len(ans) and re.match(r"[0-9]", ans[e+1]): if ans[e] == "." and e + \
1 < len(ans) and re.match(r"[0-9]", ans[e + 1]):
e += 1 e += 1
continue continue
if ans[e] == "." and e-2>=0 and ans[e-2] == "\n": if ans[e] == "." and e - 2 >= 0 and ans[e - 2] == "\n":
e += 1 e += 1
continue continue
res += ans[s: e] res += ans[s: e]
@ -191,33 +217,36 @@ class Dealer:
e += 1 e += 1
s = e s = e
if s< len(ans): if s < len(ans):
res += ans[s:] res += ans[s:]
citeit() citeit()
return res return res
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7,
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, vfield="q_vec", cfield="content_ltks"): vfield="q_vec", cfield="content_ltks"):
ins_embd = [Dealer.trans2floats(sres.field[i]["q_vec"]) for i in sres.ids] ins_embd = [
if not ins_embd: return [] Dealer.trans2floats(
ins_tw =[sres.field[i][cfield].split(" ") for i in sres.ids] sres.field[i]["q_vec"]) for i in sres.ids]
#return CosineSimilarity([sres.query_vector], ins_embd)[0] if not ins_embd:
sim = self.qryr.hybrid_similarity(sres.query_vector, return []
ins_embd, ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids]
# return CosineSimilarity([sres.query_vector], ins_embd)[0]
sim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd,
huqie.qie(query).split(" "), huqie.qie(query).split(" "),
ins_tw, tkweight, vtweight) ins_tw, tkweight, vtweight)
return sim return sim
if __name__ == "__main__":
if __name__ == "__main__":
from util import es_conn from util import es_conn
SE = Dealer(es_conn.HuEs("infiniflow")) SE = Dealer(es_conn.HuEs("infiniflow"))
qs = [ qs = [
"胡凯", "胡凯",
"" ""
] ]
for q in qs: for q in qs:
print(">>>>>>>>>>>>>>>>>>>>", q) print(">>>>>>>>>>>>>>>>>>>>", q)
print(SE.search({"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*")) print(SE.search(
{"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))

View File

@@ -5,8 +5,10 @@ from io import BytesIO
 class HuExcelParser:
     def __call__(self, fnm):
-        if isinstance(fnm, str):wb = load_workbook(fnm)
-        else: wb = load_workbook(BytesIO(fnm))
+        if isinstance(fnm, str):
+            wb = load_workbook(fnm)
+        else:
+            wb = load_workbook(BytesIO(fnm))
         res = []
         for sheetname in wb.sheetnames:
             ws = wb[sheetname]

View File

@ -53,7 +53,7 @@ class HuParser:
def _y_dis( def _y_dis(
self, a, b): self, a, b):
return ( return (
b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2 b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
def _match_proj(self, b): def _match_proj(self, b):
proj_patt = [ proj_patt = [
@ -76,9 +76,9 @@ class HuParser:
tks_down = huqie.qie(down["text"][:LEN]).split(" ") tks_down = huqie.qie(down["text"][:LEN]).split(" ")
tks_up = huqie.qie(up["text"][-LEN:]).split(" ") tks_up = huqie.qie(up["text"][-LEN:]).split(" ")
tks_all = up["text"][-LEN:].strip() \ tks_all = up["text"][-LEN:].strip() \
+ (" " if re.match(r"[a-zA-Z0-9]+", + (" " if re.match(r"[a-zA-Z0-9]+",
up["text"][-1] + down["text"][0]) else "") \ up["text"][-1] + down["text"][0]) else "") \
+ down["text"][:LEN].strip() + down["text"][:LEN].strip()
tks_all = huqie.qie(tks_all).split(" ") tks_all = huqie.qie(tks_all).split(" ")
fea = [ fea = [
up.get("R", -1) == down.get("R", -1), up.get("R", -1) == down.get("R", -1),
@ -100,7 +100,7 @@ class HuParser:
True if re.search(r"[,][^。.]+$", up["text"]) else False, True if re.search(r"[,][^。.]+$", up["text"]) else False,
True if re.search(r"[,][^。.]+$", up["text"]) else False, True if re.search(r"[,][^。.]+$", up["text"]) else False,
True if re.search(r"[\(][^\)]+$", up["text"]) True if re.search(r"[\(][^\)]+$", up["text"])
and re.search(r"[\)]", down["text"]) else False, and re.search(r"[\)]", down["text"]) else False,
self._match_proj(down), self._match_proj(down),
True if re.match(r"[A-Z]", down["text"]) else False, True if re.match(r"[A-Z]", down["text"]) else False,
True if re.match(r"[A-Z]", up["text"][-1]) else False, True if re.match(r"[A-Z]", up["text"][-1]) else False,
@ -217,7 +217,7 @@ class HuParser:
assert tp_ <= btm_, "Fuckedup! T:{},B:{},X0:{},X1:{} => {}".format( assert tp_ <= btm_, "Fuckedup! T:{},B:{},X0:{},X1:{} => {}".format(
tp, btm, x0, x1, b) tp, btm, x0, x1, b)
ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \ ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
x0 != 0 and btm - tp != 0 else 0 x0 != 0 and btm - tp != 0 else 0
if ov > 0 and ratio: if ov > 0 and ratio:
ov /= (x1 - x0) * (btm - tp) ov /= (x1 - x0) * (btm - tp)
return ov return ov
@ -382,7 +382,7 @@ class HuParser:
continue continue
for tb in tbls: # for table for tb in tbls: # for table
left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \ left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
tb["x1"] + MARGIN, tb["bottom"] + MARGIN tb["x1"] + MARGIN, tb["bottom"] + MARGIN
left *= ZM left *= ZM
top *= ZM top *= ZM
right *= ZM right *= ZM
@ -899,7 +899,7 @@ class HuParser:
lst_r = rows[-1] lst_r = rows[-1]
if lst_r[-1].get("R", "") != b.get("R", "") \ if lst_r[-1].get("R", "") != b.get("R", "") \
or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2") or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
): # new row ): # new row
btm = b["bottom"] btm = b["bottom"]
b["rn"] += 1 b["rn"] += 1
rows.append([b]) rows.append([b])
@ -949,9 +949,9 @@ class HuParser:
j += 1 j += 1
continue continue
f = (j > 0 and tbl[ii][j - 1] and tbl[ii] f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
[j - 1][0].get("text")) or j == 0 [j - 1][0].get("text")) or j == 0
ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii] ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
[j + 1][0].get("text")) or j + 1 >= len(tbl[ii]) [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
if f and ff: if f and ff:
j += 1 j += 1
continue continue
@ -1012,9 +1012,9 @@ class HuParser:
i += 1 i += 1
continue continue
f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1] f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
[jj][0].get("text")) or i == 0 [jj][0].get("text")) or i == 0
ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1] ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
[jj][0].get("text")) or i + 1 >= len(tbl) [jj][0].get("text")) or i + 1 >= len(tbl)
if f and ff: if f and ff:
i += 1 i += 1
continue continue
@ -1169,8 +1169,8 @@ class HuParser:
else "") + headers[j - 1][k] else "") + headers[j - 1][k]
else: else:
headers[j][k] = headers[j - 1][k] \ headers[j][k] = headers[j - 1][k] \
+ ("" if headers[j - 1][k] else "") \ + ("" if headers[j - 1][k] else "") \
+ headers[j][k] + headers[j][k]
logging.debug( logging.debug(
f">>>>>>>>>>>>>>>>>{cap}SIZE:{rowno}X{clmno} Header: {hdr_rowno}") f">>>>>>>>>>>>>>>>>{cap}SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
@ -1247,7 +1247,7 @@ class HuParser:
i += 1 i += 1
continue continue
lout_no = str(self.boxes[i]["page_number"]) + \ lout_no = str(self.boxes[i]["page_number"]) + \
"-" + str(self.boxes[i]["layoutno"]) "-" + str(self.boxes[i]["layoutno"])
if self.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", "title", if self.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", "title",
"figure caption", "reference"]: "figure caption", "reference"]:
nomerge_lout_no.append(lst_lout_no) nomerge_lout_no.append(lst_lout_no)
@ -1526,7 +1526,8 @@ class HuParser:
return "\n\n".join(res) return "\n\n".join(res)
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False): def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
self.pdf = pdfplumber.open(fnm) if isinstance(fnm, str) else pdfplumber.open(BytesIO(fnm)) self.pdf = pdfplumber.open(fnm) if isinstance(
fnm, str) else pdfplumber.open(BytesIO(fnm))
self.lefted_chars = [] self.lefted_chars = []
self.mean_height = [] self.mean_height = []
self.mean_width = [] self.mean_width = []
@ -1601,7 +1602,7 @@ class HuParser:
self.page_images[pns[0]].crop((left * ZM, top * ZM, self.page_images[pns[0]].crop((left * ZM, top * ZM,
right * right *
ZM, min( ZM, min(
bottom, self.page_images[pns[0]].size[1]) bottom, self.page_images[pns[0]].size[1])
)) ))
) )
bottom -= self.page_images[pns[0]].size[1] bottom -= self.page_images[pns[0]].size[1]

View File

@@ -16,11 +16,12 @@ from io import BytesIO
 from util import config
 from timeit import default_timer as timer
 from collections import OrderedDict
+from llm import ChatModel, EmbeddingModel

 SE = None
 CFIELD="content_ltks"
-EMBEDDING = HuEmbedding()
-LLM = GptTurbo()
+EMBEDDING = EmbeddingModel
+LLM = ChatModel


 def get_QA_pairs(hists):
     pa = []

View File

@@ -1,4 +1,4 @@
-import json, os, sys, hashlib, copy, time, random, re, logging, torch
+import json, os, sys, hashlib, copy, time, random, re
 from os.path import dirname, realpath
 sys.path.append(dirname(realpath(__file__)) + "/../")
 from util.es_conn import HuEs
@@ -7,10 +7,10 @@ from util.minio_conn import HuMinio
 from util import rmSpace, findMaxDt
 from FlagEmbedding import FlagModel
 from nlp import huchunk, huqie, search
-import base64, hashlib
 from io import BytesIO
 import pandas as pd
 from elasticsearch_dsl import Q
+from PIL import Image
 from parser import (
     PdfParser,
     DocxParser,
@@ -40,6 +40,15 @@ def chuck_doc(name, binary):
     if suff.find("doc") >= 0: return DOC(binary)
     if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(binary)
     if suff.find("ppt") >= 0: return PPT(binary)
+    if os.environ.get("PARSE_IMAGE") \
+            and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
+                          name.lower()):
+        from llm import CvModel
+        txt = CvModel.describe(binary)
+        field = TextChunker.Fields()
+        field.text_chunks = [(txt, binary)]
+        field.table_chunks = []
+        return field

     return TextChunker()(binary)
@@ -119,7 +128,6 @@
         set_progress(row["kb2doc_id"], -1, f"Internal system error: %s"%str(e).replace("'", ""))
         return []

-    print(row["doc_name"], obj)
     if not obj.text_chunks and not obj.table_chunks:
         set_progress(row["kb2doc_id"], 1, "Nothing added! Mostly, file type unsupported yet.")
         return []
@@ -146,7 +154,10 @@
         if not img:
             docs.append(d)
             continue
-        img.save(output_buffer, format='JPEG')
+        if isinstance(img, Image.Image): img.save(output_buffer, format='JPEG')
+        else: output_buffer = BytesIO(img)
         MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
                   output_buffer.getvalue())
         d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])

View File

@@ -1,19 +1,24 @@
 import re


 def rmSpace(txt):
     txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt)
     return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt)


 def findMaxDt(fnm):
     m = "1970-01-01 00:00:00"
     try:
         with open(fnm, "r") as f:
             while True:
                 l = f.readline()
-                if not l:break
+                if not l:
+                    break
                 l = l.strip("\n")
-                if l == 'nan':continue
-                if l > m:m = l
+                if l == 'nan':
+                    continue
+                if l > m:
+                    m = l
     except Exception as e:
-        print("WARNING: can't find "+ fnm)
+        print("WARNING: can't find " + fnm)
     return m

View File

@@ -1,25 +1,31 @@
 from configparser import ConfigParser
-import os,inspect
+import os
+import inspect

 CF = ConfigParser()
 __fnm = os.path.join(os.path.dirname(__file__), '../conf/sys.cnf')
-if not os.path.exists(__fnm):__fnm = os.path.join(os.path.dirname(__file__), '../../conf/sys.cnf')
-assert os.path.exists(__fnm), f"【EXCEPTION】can't find {__fnm}." + os.path.dirname(__file__)
-if not os.path.exists(__fnm): __fnm = "./sys.cnf"
+if not os.path.exists(__fnm):
+    __fnm = os.path.join(os.path.dirname(__file__), '../../conf/sys.cnf')
+assert os.path.exists(
+    __fnm), f"【EXCEPTION】can't find {__fnm}." + os.path.dirname(__file__)
+if not os.path.exists(__fnm):
+    __fnm = "./sys.cnf"
 CF.read(__fnm)


 class Config:
     def __init__(self, env):
         self.env = env
-        if env == "spark":CF.read("./cv.cnf")
+        if env == "spark":
+            CF.read("./cv.cnf")

     def get(self, key, default=None):
         global CF
-        return os.environ.get(key.upper(), \
+        return os.environ.get(key.upper(),
                               CF[self.env].get(key, default)
                               )


 def init(env):
     return Config(env)

View File

@ -3,6 +3,7 @@ import time
from util import config from util import config
import pandas as pd import pandas as pd
class Postgres(object): class Postgres(object):
def __init__(self, env, dbnm): def __init__(self, env, dbnm):
self.config = config.init(env) self.config = config.init(env)
@ -13,36 +14,42 @@ class Postgres(object):
def __open__(self): def __open__(self):
import psycopg2 import psycopg2
try: try:
if self.conn:self.__close__() if self.conn:
self.__close__()
del self.conn del self.conn
except Exception as e: except Exception as e:
pass pass
try: try:
self.conn = psycopg2.connect(f"dbname={self.dbnm} user={self.config.get('pgdb_usr')} password={self.config.get('pgdb_pwd')} host={self.config.get('pgdb_host')} port={self.config.get('pgdb_port')}") self.conn = psycopg2.connect(f"""dbname={self.dbnm}
user={self.config.get('postgres_user')}
password={self.config.get('postgres_password')}
host={self.config.get('postgres_host')}
port={self.config.get('postgres_port')}""")
except Exception as e: except Exception as e:
logging.error("Fail to connect %s "%self.config.get("pgdb_host") + str(e)) logging.error(
"Fail to connect %s " %
self.config.get("pgdb_host") + str(e))
def __close__(self): def __close__(self):
try: try:
self.conn.close() self.conn.close()
except Exception as e: except Exception as e:
logging.error("Fail to close %s "%self.config.get("pgdb_host") + str(e)) logging.error(
"Fail to close %s " %
self.config.get("pgdb_host") + str(e))
def select(self, sql): def select(self, sql):
for _ in range(10): for _ in range(10):
try: try:
return pd.read_sql(sql, self.conn) return pd.read_sql(sql, self.conn)
except Exception as e: except Exception as e:
logging.error(f"Fail to exec {sql} "+str(e)) logging.error(f"Fail to exec {sql} " + str(e))
self.__open__() self.__open__()
time.sleep(1) time.sleep(1)
return pd.DataFrame() return pd.DataFrame()
def update(self, sql): def update(self, sql):
for _ in range(10): for _ in range(10):
try: try:
@ -53,11 +60,11 @@ class Postgres(object):
cur.close() cur.close()
return updated_rows return updated_rows
except Exception as e: except Exception as e:
logging.error(f"Fail to exec {sql} "+str(e)) logging.error(f"Fail to exec {sql} " + str(e))
self.__open__() self.__open__()
time.sleep(1) time.sleep(1)
return 0 return 0
if __name__ == "__main__": if __name__ == "__main__":
Postgres("infiniflow", "docgpt") Postgres("infiniflow", "docgpt")

View File

@@ -228,7 +228,8 @@ class HuEs:
         return False

     def search(self, q, idxnm=None, src=False, timeout="2s"):
-        if not isinstance(q, dict): q = Search().query(q).to_dict()
+        if not isinstance(q, dict):
+            q = Search().query(q).to_dict()
         for i in range(3):
             try:
                 res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
@@ -274,9 +275,10 @@
         return False

     def updateScriptByQuery(self, q, scripts, idxnm=None):
-        ubq = UpdateByQuery(index=self.idxnm if not idxnm else idxnm).using(self.es).query(q)
+        ubq = UpdateByQuery(
+            index=self.idxnm if not idxnm else idxnm).using(
+            self.es).query(q)
         ubq = ubq.script(source=scripts)
         ubq = ubq.params(refresh=True)
         ubq = ubq.params(slices=5)
@@ -294,7 +296,6 @@
         return False

     def deleteByQuery(self, query, idxnm=""):
         for i in range(3):
             try:
@@ -392,7 +393,7 @@
         return rr

     def scrollIter(self, pagesize=100, scroll_time='2m', q={
             "query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
         for _ in range(100):
             try:
                 page = self.es.search(

View File

@ -4,6 +4,7 @@ from util import config
from minio import Minio from minio import Minio
from io import BytesIO from io import BytesIO
class HuMinio(object): class HuMinio(object):
def __init__(self, env): def __init__(self, env):
self.config = config.init(env) self.config = config.init(env)
@ -12,64 +13,62 @@ class HuMinio(object):
def __open__(self): def __open__(self):
try: try:
if self.conn:self.__close__() if self.conn:
self.__close__()
except Exception as e: except Exception as e:
pass pass
try: try:
self.conn = Minio(self.config.get("minio_host"), self.conn = Minio(self.config.get("minio_host"),
access_key=self.config.get("minio_usr"), access_key=self.config.get("minio_user"),
secret_key=self.config.get("minio_pwd"), secret_key=self.config.get("minio_password"),
secure=False secure=False
) )
except Exception as e: except Exception as e:
logging.error("Fail to connect %s "%self.config.get("minio_host") + str(e)) logging.error(
"Fail to connect %s " %
self.config.get("minio_host") + str(e))
def __close__(self): def __close__(self):
del self.conn del self.conn
self.conn = None self.conn = None
def put(self, bucket, fnm, binary): def put(self, bucket, fnm, binary):
for _ in range(10): for _ in range(10):
try: try:
if not self.conn.bucket_exists(bucket): if not self.conn.bucket_exists(bucket):
self.conn.make_bucket(bucket) self.conn.make_bucket(bucket)
r = self.conn.put_object(bucket, fnm, r = self.conn.put_object(bucket, fnm,
BytesIO(binary), BytesIO(binary),
len(binary) len(binary)
) )
return r return r
except Exception as e: except Exception as e:
logging.error(f"Fail put {bucket}/{fnm}: "+str(e)) logging.error(f"Fail put {bucket}/{fnm}: " + str(e))
self.__open__() self.__open__()
time.sleep(1) time.sleep(1)
def get(self, bucket, fnm): def get(self, bucket, fnm):
for _ in range(10): for _ in range(10):
try: try:
r = self.conn.get_object(bucket, fnm) r = self.conn.get_object(bucket, fnm)
return r.read() return r.read()
except Exception as e: except Exception as e:
logging.error(f"fail get {bucket}/{fnm}: "+str(e)) logging.error(f"fail get {bucket}/{fnm}: " + str(e))
self.__open__() self.__open__()
time.sleep(1) time.sleep(1)
return return
def get_presigned_url(self, bucket, fnm, expires): def get_presigned_url(self, bucket, fnm, expires):
for _ in range(10): for _ in range(10):
try: try:
return self.conn.get_presigned_url("GET", bucket, fnm, expires) return self.conn.get_presigned_url("GET", bucket, fnm, expires)
except Exception as e: except Exception as e:
logging.error(f"fail get {bucket}/{fnm}: "+str(e)) logging.error(f"fail get {bucket}/{fnm}: " + str(e))
self.__open__() self.__open__()
time.sleep(1) time.sleep(1)
return return
if __name__ == "__main__": if __name__ == "__main__":
@ -78,9 +77,8 @@ if __name__ == "__main__":
from PIL import Image from PIL import Image
img = Image.open(fnm) img = Image.open(fnm)
buff = BytesIO() buff = BytesIO()
img.save(buff, format='JPEG') img.save(buff, format='JPEG')
print(conn.put("test", "11-408.jpg", buff.getvalue())) print(conn.put("test", "11-408.jpg", buff.getvalue()))
bts = conn.get("test", "11-408.jpg") bts = conn.get("test", "11-408.jpg")
img = Image.open(BytesIO(bts)) img = Image.open(BytesIO(bts))
img.save("test.jpg") img.save("test.jpg")