add llm API (#19)

* add llm API

* refine llm API
KevinHuSh 2023-12-28 13:50:13 +08:00 committed by GitHub
parent cdd956568d
commit d0db329fef
17 changed files with 349 additions and 170 deletions

View File

@ -121,7 +121,6 @@
"match": "*_vec",
"mapping": {
"type": "dense_vector",
"dims": 1024,
"index": true,
"similarity": "cosine"
}

View File

@ -1,10 +1,9 @@
[infiniflow]
es=http://es01:9200
pgdb_usr=root
pgdb_pwd=infiniflow_docgpt
pgdb_host=postgres
pgdb_port=5432
postgres_user=root
postgres_password=infiniflow_docgpt
postgres_host=postgres
postgres_port=5432
minio_host=minio:9000
minio_usr=infiniflow
minio_pwd=infiniflow_docgpt
minio_user=infiniflow
minio_password=infiniflow_docgpt

View File

@ -1,2 +1,21 @@
from .embedding_model import HuEmbedding
from .chat_model import GptTurbo
import os
from .embedding_model import *
from .chat_model import *
from .cv_model import *
EmbeddingModel = None
ChatModel = None
CvModel = None
if os.environ.get("OPENAI_API_KEY"):
EmbeddingModel = GptEmbed()
ChatModel = GptTurbo()
CvModel = GptV4()
elif os.environ.get("DASHSCOPE_API_KEY"):
EmbeddingModel = QWenEmbd()
ChatModel = QWenChat()
CvModel = QWenCV()
else:
EmbeddingModel = HuEmbedding()
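
A minimal usage sketch (not part of the diff) of the selection logic above, assuming the package is importable as `llm` and the relevant API key is exported before import:

# Illustrative only: the backend is chosen once, at import time, from the environment.
# export OPENAI_API_KEY=... (or DASHSCOPE_API_KEY=...); otherwise the local HuEmbedding is used.
from llm import EmbeddingModel, ChatModel

vecs = EmbeddingModel.encode(["hello world"])   # always set (falls back to HuEmbedding)
if ChatModel is not None:                       # stays None when no API key is exported
    answer = ChatModel.chat("You are a helpful assistant.",
                            [{"role": "user", "content": "Hi"}],
                            {"temperature": 0.7})

Note that ChatModel and CvModel remain None in the fallback branch, so callers have to check them before use.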

View File

@ -1,7 +1,8 @@
from abc import ABC
import openapi
from openai import OpenAI
import os
class Base(ABC):
def chat(self, system, history, gen_conf):
raise NotImplementedError("Please implement encode method!")
@ -9,26 +10,27 @@ class Base(ABC):
class GptTurbo(Base):
def __init__(self):
openapi.api_key = os.environ["OPENAPI_KEY"]
self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
def chat(self, system, history, gen_conf):
history.insert(0, {"role": "system", "content": system})
res = openapi.ChatCompletion.create(model="gpt-3.5-turbo",
messages=history,
**gen_conf)
res = self.client.chat.completions.create(
model="gpt-3.5-turbo",
messages=history,
**gen_conf)
return res.choices[0].message.content.strip()
class QWen(Base):
class QWenChat(Base):
def chat(self, system, history, gen_conf):
from http import HTTPStatus
from dashscope import Generation
from dashscope.api_entities.dashscope_response import Role
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
history.insert(0, {"role": "system", "content": system})
response = Generation.call(
Generation.Models.qwen_turbo,
messages=messages,
result_format='message'
Generation.Models.qwen_turbo,
messages=history,
result_format='message'
)
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content']
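
For reference, a sketch of how the chat interface above might be called (illustrative; `gen_conf` is forwarded verbatim to the provider, so it should only contain keys the target API accepts, e.g. temperature or max_tokens):

# Illustrative only: GptTurbo needs OPENAI_API_KEY, QWenChat needs DASHSCOPE_API_KEY.
mdl = GptTurbo()
reply = mdl.chat(system="Answer concisely.",
                 history=[{"role": "user", "content": "What is RAG?"}],
                 gen_conf={"temperature": 0.2, "max_tokens": 256})
print(reply)

One side effect worth noting: chat() inserts the system prompt into the caller's history list in place.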

python/llm/cv_model.py Normal file
View File

@ -0,0 +1,66 @@
from abc import ABC
from openai import OpenAI
import os
import base64
from io import BytesIO
class Base(ABC):
def describe(self, image, max_tokens=300):
raise NotImplementedError("Please implement describe method!")
def image2base64(self, image):
if isinstance(image, BytesIO):
return base64.b64encode(image.getvalue()).decode("utf-8")
buffered = BytesIO()
try:
image.save(buffered, format="JPEG")
except Exception as e:
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def prompt(self, b64):
return [
{
"role": "user",
"content": [
{
"type": "text",
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。",
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{b64}"
},
},
],
}
]
class GptV4(Base):
def __init__(self):
self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
def describe(self, image, max_tokens=300):
b64 = self.image2base64(image)
res = self.client.chat.completions.create(
model="gpt-4-vision-preview",
messages=self.prompt(b64),
max_tokens=max_tokens,
)
return res.choices[0].message.content.strip()
class QWenCV(Base):
def describe(self, image, max_tokens=300):
from http import HTTPStatus
from dashscope import MultiModalConversation
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
response = MultiModalConversation.call(model=MultiModalConversation.Models.qwen_vl_chat_v1,
messages=self.prompt(self.image2base64(image)))
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content']
return response.message
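
A hedged usage sketch for the new vision models (illustrative; image2base64 accepts either a PIL image or a BytesIO object, and the built-in prompt asks for a detailed Chinese description of the image):

# Illustrative only: GptV4 needs OPENAI_API_KEY, QWenCV needs DASHSCOPE_API_KEY.
from PIL import Image

img = Image.open("photo.jpg")                   # hypothetical local file
caption = GptV4().describe(img, max_tokens=300)
print(caption)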

View File

@ -1,8 +1,11 @@
from abc import ABC
from openai import OpenAI
from FlagEmbedding import FlagModel
import torch
import os
import numpy as np
class Base(ABC):
def encode(self, texts: list, batch_size=32):
raise NotImplementedError("Please implement encode method!")
@ -22,11 +25,37 @@ class HuEmbedding(Base):
"""
self.model = FlagModel("BAAI/bge-large-zh-v1.5",
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=torch.cuda.is_available())
def encode(self, texts: list, batch_size=32):
res = []
for i in range(0, len(texts), batch_size):
res.extend(self.model.encode(texts[i:i+batch_size]).tolist())
res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
return np.array(res)
class GptEmbed(Base):
def __init__(self):
self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
def encode(self, texts: list, batch_size=32):
res = self.client.embeddings.create(input=texts,
model="text-embedding-ada-002")
return [d.embedding for d in res.data]
class QWenEmbd(Base):
def encode(self, texts: list, batch_size=32, text_type="document"):
# export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
import dashscope
from http import HTTPStatus
res = []
for txt in texts:
resp = dashscope.TextEmbedding.call(
model=dashscope.TextEmbedding.Models.text_embedding_v2,
input=txt[:2048],
text_type=text_type
)
res.append(resp["output"]["embeddings"][0]["embedding"])
return res
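
An illustrative sketch of the embedding interface (not part of the commit; HuEmbedding downloads BAAI/bge-large-zh-v1.5 on first use, while GptEmbed and QWenEmbd need their respective API keys):

# Illustrative only.
emb = HuEmbedding()                              # local FlagEmbedding model
vecs = emb.encode(["什么是RAG?", "retrieval augmented generation"], batch_size=32)
print(vecs.shape)                                # one 1024-dim vector per input text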

View File

@ -372,7 +372,9 @@ class PptChunker(HuChunker):
def __call__(self, fnm):
from pptx import Presentation
ppt = Presentation(fnm) if isinstance(fnm, str) else Presentation(BytesIO(fnm))
ppt = Presentation(fnm) if isinstance(
fnm, str) else Presentation(
BytesIO(fnm))
flds = self.Fields()
flds.text_chunks = []
for slide in ppt.slides:
@ -398,7 +400,8 @@ class TextChunker(HuChunker):
mime = magic.Magic(mime=True)
if isinstance(file_path, str):
file_type = mime.from_file(file_path)
else:file_type = mime.from_buffer(file_path)
else:
file_type = mime.from_buffer(file_path)
if 'text' in file_type:
return False
else:
@ -406,7 +409,8 @@ class TextChunker(HuChunker):
def __call__(self, fnm):
flds = self.Fields()
if self.is_binary_file(fnm):return flds
if self.is_binary_file(fnm):
return flds
with open(fnm, "r") as f:
txt = f.read()
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]

View File

@ -1,6 +1,6 @@
import re
from elasticsearch_dsl import Q,Search,A
from typing import List, Optional, Tuple,Dict, Union
from elasticsearch_dsl import Q, Search, A
from typing import List, Optional, Tuple, Dict, Union
from dataclasses import dataclass
from util import setup_logging, rmSpace
from nlp import huqie, query
@ -9,18 +9,24 @@ from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
import numpy as np
from copy import deepcopy
def index_name(uid):return f"docgpt_{uid}"
def index_name(uid): return f"docgpt_{uid}"
class Dealer:
def __init__(self, es, emb_mdl):
self.qryr = query.EsQueryer(es)
self.qryr.flds = ["title_tks^10", "title_sm_tks^5", "content_ltks^2", "content_sm_ltks"]
self.qryr.flds = [
"title_tks^10",
"title_sm_tks^5",
"content_ltks^2",
"content_sm_ltks"]
self.es = es
self.emb_mdl = emb_mdl
@dataclass
class SearchResult:
total:int
total: int
ids: List[str]
query_vector: List[float] = None
field: Optional[Dict] = None
@ -42,71 +48,78 @@ class Dealer:
keywords = []
qst = req.get("question", "")
bqry,keywords = self.qryr.question(qst)
if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
bqry, keywords = self.qryr.question(qst)
if req.get("kb_ids"):
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
bqry.filter.append(Q("exists", field="q_tks"))
bqry.boost = 0.05
print(bqry)
s = Search()
pg = int(req.get("page", 1))-1
pg = int(req.get("page", 1)) - 1
ps = int(req.get("size", 1000))
src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id",
"image_id", "doc_id", "q_vec"])
s = s.query(bqry)[pg*ps:(pg+1)*ps]
s = s.query(bqry)[pg * ps:(pg + 1) * ps]
s = s.highlight("content_ltks")
s = s.highlight("title_ltks")
if not qst: s = s.sort({"create_time":{"order":"desc", "unmapped_type":"date"}})
if not qst:
s = s.sort(
{"create_time": {"order": "desc", "unmapped_type": "date"}})
s = s.highlight_options(
fragment_size = 120,
number_of_fragments=5,
boundary_scanner_locale="zh-CN",
boundary_scanner="SENTENCE",
boundary_chars=",./;:\\!(),。?:!……()——、"
)
fragment_size=120,
number_of_fragments=5,
boundary_scanner_locale="zh-CN",
boundary_scanner="SENTENCE",
boundary_chars=",./;:\\!(),。?:!……()——、"
)
s = s.to_dict()
q_vec = []
if req.get("vector"):
if req.get("vector"):
s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
s["knn"]["filter"] = bqry.to_dict()
del s["highlight"]
q_vec = s["knn"]["query_vector"]
res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src)
res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
print("TOTAL: ", self.es.getTotal(res))
if self.es.getTotal(res) == 0 and "knn" in s:
bqry,_ = self.qryr.question(qst, min_match="10%")
if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
bqry, _ = self.qryr.question(qst, min_match="10%")
if req.get("kb_ids"):
bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
s["query"] = bqry.to_dict()
s["knn"]["filter"] = bqry.to_dict()
s["knn"]["similarity"] = 0.7
res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src)
res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
kwds = set([])
for k in keywords:
kwds.add(k)
for kk in huqie.qieqie(k).split(" "):
if len(kk) < 2:continue
if kk in kwds:continue
if len(kk) < 2:
continue
if kk in kwds:
continue
kwds.add(kk)
aggs = self.getAggregation(res, "docnm_kwd")
return self.SearchResult(
total = self.es.getTotal(res),
ids = self.es.getDocIds(res),
query_vector = q_vec,
aggregation = aggs,
highlight = self.getHighlight(res),
field = self.getFields(res, ["docnm_kwd", "content_ltks",
"kb_id","image_id", "doc_id", "q_vec"]),
keywords = list(kwds)
total=self.es.getTotal(res),
ids=self.es.getDocIds(res),
query_vector=q_vec,
aggregation=aggs,
highlight=self.getHighlight(res),
field=self.getFields(res, ["docnm_kwd", "content_ltks",
"kb_id", "image_id", "doc_id", "q_vec"]),
keywords=list(kwds)
)
def getAggregation(self, res, g):
if not "aggregations" in res or "aggs_"+g not in res["aggregations"]:return
bkts = res["aggregations"]["aggs_"+g]["buckets"]
if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
return
bkts = res["aggregations"]["aggs_" + g]["buckets"]
return [(b["key"], b["doc_count"]) for b in bkts]
def getHighlight(self, res):
@ -114,8 +127,11 @@ class Dealer:
eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
r = []
for t in line.split(" "):
if not t:continue
if len(r)>0 and len(t)>0 and r[-1][-1] in eng and t[0] in eng:r.append(" ")
if not t:
continue
if len(r) > 0 and len(
t) > 0 and r[-1][-1] in eng and t[0] in eng:
r.append(" ")
r.append(t)
r = "".join(r)
return r
@ -123,66 +139,76 @@ class Dealer:
ans = {}
for d in res["hits"]["hits"]:
hlts = d.get("highlight")
if not hlts:continue
if not hlts:
continue
ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
return ans
def getFields(self, sres, flds):
res = {}
if not flds:return {}
for d in self.es.getSource(sres):
m = {n:d.get(n) for n in flds if d.get(n) is not None}
for n,v in m.items():
if type(v) == type([]):
if not flds:
return {}
for d in self.es.getSource(sres):
m = {n: d.get(n) for n in flds if d.get(n) is not None}
for n, v in m.items():
if isinstance(v, type([])):
m[n] = "\t".join([str(vv) for vv in v])
continue
if type(v) != type(""):m[n] = str(m[n])
if not isinstance(v, type("")):
m[n] = str(m[n])
m[n] = rmSpace(m[n])
if m:res[d["id"]] = m
if m:
res[d["id"]] = m
return res
@staticmethod
def trans2floats(txt):
return [float(t) for t in txt.split("\t")]
def insert_citations(self, ans, top_idx, sres,
vfield="q_vec", cfield="content_ltks"):
def insert_citations(self, ans, top_idx, sres, vfield = "q_vec", cfield="content_ltks"):
ins_embd = [Dealer.trans2floats(sres.field[sres.ids[i]][vfield]) for i in top_idx]
ins_tw =[sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
ins_embd = [Dealer.trans2floats(
sres.field[sres.ids[i]][vfield]) for i in top_idx]
ins_tw = [sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
s = 0
e = 0
res = ""
def citeit():
nonlocal s, e, ans, res
if not ins_embd:return
if not ins_embd:
return
embd = self.emb_mdl.encode(ans[s: e])
sim = self.qryr.hybrid_similarity(embd,
ins_embd,
sim = self.qryr.hybrid_similarity(embd,
ins_embd,
huqie.qie(ans[s:e]).split(" "),
ins_tw)
print(ans[s: e], sim)
mx = np.max(sim)*0.99
if mx < 0.55:return
cita = list(set([top_idx[i] for i in range(len(ins_embd)) if sim[i] >mx]))[:4]
for i in cita: res += f"@?{i}?@"
mx = np.max(sim) * 0.99
if mx < 0.55:
return
cita = list(set([top_idx[i]
for i in range(len(ins_embd)) if sim[i] > mx]))[:4]
for i in cita:
res += f"@?{i}?@"
return cita
punct = set(";。?!")
if not self.qryr.isChinese(ans):
if not self.qryr.isChinese(ans):
punct.add("?")
punct.add(".")
while e < len(ans):
if e - s < 12 or ans[e] not in punct:
e += 1
continue
if ans[e] == "." and e+1<len(ans) and re.match(r"[0-9]", ans[e+1]):
if ans[e] == "." and e + \
1 < len(ans) and re.match(r"[0-9]", ans[e + 1]):
e += 1
continue
if ans[e] == "." and e-2>=0 and ans[e-2] == "\n":
if ans[e] == "." and e - 2 >= 0 and ans[e - 2] == "\n":
e += 1
continue
res += ans[s: e]
@ -191,33 +217,36 @@ class Dealer:
e += 1
s = e
if s< len(ans):
if s < len(ans):
res += ans[s:]
citeit()
return res
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, vfield="q_vec", cfield="content_ltks"):
ins_embd = [Dealer.trans2floats(sres.field[i]["q_vec"]) for i in sres.ids]
if not ins_embd: return []
ins_tw =[sres.field[i][cfield].split(" ") for i in sres.ids]
#return CosineSimilarity([sres.query_vector], ins_embd)[0]
sim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd,
def rerank(self, sres, query, tkweight=0.3, vtweight=0.7,
vfield="q_vec", cfield="content_ltks"):
ins_embd = [
Dealer.trans2floats(
sres.field[i]["q_vec"]) for i in sres.ids]
if not ins_embd:
return []
ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids]
# return CosineSimilarity([sres.query_vector], ins_embd)[0]
sim = self.qryr.hybrid_similarity(sres.query_vector,
ins_embd,
huqie.qie(query).split(" "),
ins_tw, tkweight, vtweight)
return sim
if __name__ == "__main__":
if __name__ == "__main__":
from util import es_conn
SE = Dealer(es_conn.HuEs("infiniflow"))
qs = [
"胡凯",
""
]
for q in qs:
for q in qs:
print(">>>>>>>>>>>>>>>>>>>>", q)
print(SE.search({"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))
print(SE.search(
{"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))

View File

@ -5,8 +5,10 @@ from io import BytesIO
class HuExcelParser:
def __call__(self, fnm):
if isinstance(fnm, str):wb = load_workbook(fnm)
else: wb = load_workbook(BytesIO(fnm))
if isinstance(fnm, str):
wb = load_workbook(fnm)
else:
wb = load_workbook(BytesIO(fnm))
res = []
for sheetname in wb.sheetnames:
ws = wb[sheetname]

View File

@ -53,7 +53,7 @@ class HuParser:
def _y_dis(
self, a, b):
return (
b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
def _match_proj(self, b):
proj_patt = [
@ -76,9 +76,9 @@ class HuParser:
tks_down = huqie.qie(down["text"][:LEN]).split(" ")
tks_up = huqie.qie(up["text"][-LEN:]).split(" ")
tks_all = up["text"][-LEN:].strip() \
+ (" " if re.match(r"[a-zA-Z0-9]+",
up["text"][-1] + down["text"][0]) else "") \
+ down["text"][:LEN].strip()
+ (" " if re.match(r"[a-zA-Z0-9]+",
up["text"][-1] + down["text"][0]) else "") \
+ down["text"][:LEN].strip()
tks_all = huqie.qie(tks_all).split(" ")
fea = [
up.get("R", -1) == down.get("R", -1),
@ -100,7 +100,7 @@ class HuParser:
True if re.search(r"[,][^。.]+$", up["text"]) else False,
True if re.search(r"[,][^。.]+$", up["text"]) else False,
True if re.search(r"[\(][^\)]+$", up["text"])
and re.search(r"[\)]", down["text"]) else False,
and re.search(r"[\)]", down["text"]) else False,
self._match_proj(down),
True if re.match(r"[A-Z]", down["text"]) else False,
True if re.match(r"[A-Z]", up["text"][-1]) else False,
@ -217,7 +217,7 @@ class HuParser:
assert tp_ <= btm_, "Fuckedup! T:{},B:{},X0:{},X1:{} => {}".format(
tp, btm, x0, x1, b)
ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
x0 != 0 and btm - tp != 0 else 0
x0 != 0 and btm - tp != 0 else 0
if ov > 0 and ratio:
ov /= (x1 - x0) * (btm - tp)
return ov
@ -382,7 +382,7 @@ class HuParser:
continue
for tb in tbls: # for table
left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
tb["x1"] + MARGIN, tb["bottom"] + MARGIN
tb["x1"] + MARGIN, tb["bottom"] + MARGIN
left *= ZM
top *= ZM
right *= ZM
@ -899,7 +899,7 @@ class HuParser:
lst_r = rows[-1]
if lst_r[-1].get("R", "") != b.get("R", "") \
or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
): # new row
): # new row
btm = b["bottom"]
b["rn"] += 1
rows.append([b])
@ -949,9 +949,9 @@ class HuParser:
j += 1
continue
f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
[j - 1][0].get("text")) or j == 0
[j - 1][0].get("text")) or j == 0
ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
[j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
[j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
if f and ff:
j += 1
continue
@ -1012,9 +1012,9 @@ class HuParser:
i += 1
continue
f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
[jj][0].get("text")) or i == 0
[jj][0].get("text")) or i == 0
ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
[jj][0].get("text")) or i + 1 >= len(tbl)
[jj][0].get("text")) or i + 1 >= len(tbl)
if f and ff:
i += 1
continue
@ -1169,8 +1169,8 @@ class HuParser:
else "") + headers[j - 1][k]
else:
headers[j][k] = headers[j - 1][k] \
+ ("" if headers[j - 1][k] else "") \
+ headers[j][k]
+ ("" if headers[j - 1][k] else "") \
+ headers[j][k]
logging.debug(
f">>>>>>>>>>>>>>>>>{cap}SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
@ -1247,7 +1247,7 @@ class HuParser:
i += 1
continue
lout_no = str(self.boxes[i]["page_number"]) + \
"-" + str(self.boxes[i]["layoutno"])
"-" + str(self.boxes[i]["layoutno"])
if self.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", "title",
"figure caption", "reference"]:
nomerge_lout_no.append(lst_lout_no)
@ -1526,7 +1526,8 @@ class HuParser:
return "\n\n".join(res)
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
self.pdf = pdfplumber.open(fnm) if isinstance(fnm, str) else pdfplumber.open(BytesIO(fnm))
self.pdf = pdfplumber.open(fnm) if isinstance(
fnm, str) else pdfplumber.open(BytesIO(fnm))
self.lefted_chars = []
self.mean_height = []
self.mean_width = []
@ -1601,7 +1602,7 @@ class HuParser:
self.page_images[pns[0]].crop((left * ZM, top * ZM,
right *
ZM, min(
bottom, self.page_images[pns[0]].size[1])
bottom, self.page_images[pns[0]].size[1])
))
)
bottom -= self.page_images[pns[0]].size[1]

View File

@ -16,11 +16,12 @@ from io import BytesIO
from util import config
from timeit import default_timer as timer
from collections import OrderedDict
from llm import ChatModel, EmbeddingModel
SE = None
CFIELD="content_ltks"
EMBEDDING = HuEmbedding()
LLM = GptTurbo()
EMBEDDING = EmbeddingModel
LLM = ChatModel
def get_QA_pairs(hists):
pa = []

View File

@ -1,4 +1,4 @@
import json, os, sys, hashlib, copy, time, random, re, logging, torch
import json, os, sys, hashlib, copy, time, random, re
from os.path import dirname, realpath
sys.path.append(dirname(realpath(__file__)) + "/../")
from util.es_conn import HuEs
@ -7,10 +7,10 @@ from util.minio_conn import HuMinio
from util import rmSpace, findMaxDt
from FlagEmbedding import FlagModel
from nlp import huchunk, huqie, search
import base64, hashlib
from io import BytesIO
import pandas as pd
from elasticsearch_dsl import Q
from PIL import Image
from parser import (
PdfParser,
DocxParser,
@ -40,6 +40,15 @@ def chuck_doc(name, binary):
if suff.find("doc") >= 0: return DOC(binary)
if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(binary)
if suff.find("ppt") >= 0: return PPT(binary)
if os.environ.get("PARSE_IMAGE") \
and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
name.lower()):
from llm import CvModel
txt = CvModel.describe(BytesIO(binary))  # image2base64 expects a BytesIO or PIL image
field = TextChunker.Fields()
field.text_chunks = [(txt, binary)]
field.table_chunks = []
return field
@ -119,7 +128,6 @@ def build(row):
set_progress(row["kb2doc_id"], -1, f"Internal system error: %s"%str(e).replace("'", ""))
return []
print(row["doc_name"], obj)
if not obj.text_chunks and not obj.table_chunks:
set_progress(row["kb2doc_id"], 1, "Nothing added! Mostly, file type unsupported yet.")
return []
@ -146,7 +154,10 @@ def build(row):
if not img:
docs.append(d)
continue
img.save(output_buffer, format='JPEG')
if isinstance(img, Image.Image): img.save(output_buffer, format='JPEG')
else: output_buffer = BytesIO(img)
MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
output_buffer.getvalue())
d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])

View File

@ -1,19 +1,24 @@
import re
def rmSpace(txt):
txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt)
return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt)
def findMaxDt(fnm):
m = "1970-01-01 00:00:00"
try:
with open(fnm, "r") as f:
while True:
l = f.readline()
if not l:break
if not l:
break
l = l.strip("\n")
if l == 'nan':continue
if l > m:m = l
if l == 'nan':
continue
if l > m:
m = l
except Exception as e:
print("WARNING: can't find "+ fnm)
print("WARNING: can't find " + fnm)
return m

View File

@ -1,25 +1,31 @@
from configparser import ConfigParser
import os,inspect
from configparser import ConfigParser
import os
import inspect
CF = ConfigParser()
__fnm = os.path.join(os.path.dirname(__file__), '../conf/sys.cnf')
if not os.path.exists(__fnm):__fnm = os.path.join(os.path.dirname(__file__), '../../conf/sys.cnf')
assert os.path.exists(__fnm), f"【EXCEPTION】can't find {__fnm}." + os.path.dirname(__file__)
if not os.path.exists(__fnm): __fnm = "./sys.cnf"
if not os.path.exists(__fnm):
__fnm = os.path.join(os.path.dirname(__file__), '../../conf/sys.cnf')
assert os.path.exists(
__fnm), f"【EXCEPTION】can't find {__fnm}." + os.path.dirname(__file__)
if not os.path.exists(__fnm):
__fnm = "./sys.cnf"
CF.read(__fnm)
class Config:
def __init__(self, env):
self.env = env
if env == "spark":CF.read("./cv.cnf")
if env == "spark":
CF.read("./cv.cnf")
def get(self, key, default=None):
global CF
return os.environ.get(key.upper(), \
CF[self.env].get(key, default)
)
return os.environ.get(key.upper(),
CF[self.env].get(key, default)
)
def init(env):
return Config(env)
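
A short sketch of the lookup order Config.get() implements above: an upper-cased environment variable wins, otherwise the value from the matching section of sys.cnf is returned (key names follow the renamed entries in conf/sys.cnf):

# Illustrative only.
import os
from util import config

cfg = config.init("infiniflow")
os.environ["POSTGRES_HOST"] = "db.internal"    # env var overrides the file value
print(cfg.get("postgres_host"))                # -> "db.internal"
print(cfg.get("minio_host", "minio:9000"))     # file value if present, else the default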

View File

@ -3,6 +3,7 @@ import time
from util import config
import pandas as pd
class Postgres(object):
def __init__(self, env, dbnm):
self.config = config.init(env)
@ -13,36 +14,42 @@ class Postgres(object):
def __open__(self):
import psycopg2
try:
if self.conn:self.__close__()
if self.conn:
self.__close__()
del self.conn
except Exception as e:
pass
try:
self.conn = psycopg2.connect(f"dbname={self.dbnm} user={self.config.get('pgdb_usr')} password={self.config.get('pgdb_pwd')} host={self.config.get('pgdb_host')} port={self.config.get('pgdb_port')}")
self.conn = psycopg2.connect(f"""dbname={self.dbnm}
user={self.config.get('postgres_user')}
password={self.config.get('postgres_password')}
host={self.config.get('postgres_host')}
port={self.config.get('postgres_port')}""")
except Exception as e:
logging.error("Fail to connect %s "%self.config.get("pgdb_host") + str(e))
logging.error(
"Fail to connect %s " %
self.config.get("pgdb_host") + str(e))
def __close__(self):
try:
self.conn.close()
except Exception as e:
logging.error("Fail to close %s "%self.config.get("pgdb_host") + str(e))
logging.error(
"Fail to close %s " %
self.config.get("pgdb_host") + str(e))
def select(self, sql):
for _ in range(10):
try:
return pd.read_sql(sql, self.conn)
except Exception as e:
logging.error(f"Fail to exec {sql} "+str(e))
logging.error(f"Fail to exec {sql} " + str(e))
self.__open__()
time.sleep(1)
return pd.DataFrame()
def update(self, sql):
for _ in range(10):
try:
@ -53,11 +60,11 @@ class Postgres(object):
cur.close()
return updated_rows
except Exception as e:
logging.error(f"Fail to exec {sql} "+str(e))
logging.error(f"Fail to exec {sql} " + str(e))
self.__open__()
time.sleep(1)
return 0
if __name__ == "__main__":
Postgres("infiniflow", "docgpt")

View File

@ -228,7 +228,8 @@ class HuEs:
return False
def search(self, q, idxnm=None, src=False, timeout="2s"):
if not isinstance(q, dict): q = Search().query(q).to_dict()
if not isinstance(q, dict):
q = Search().query(q).to_dict()
for i in range(3):
try:
res = self.es.search(index=(self.idxnm if not idxnm else idxnm),
@ -274,9 +275,10 @@ class HuEs:
return False
def updateScriptByQuery(self, q, scripts, idxnm=None):
ubq = UpdateByQuery(index=self.idxnm if not idxnm else idxnm).using(self.es).query(q)
ubq = UpdateByQuery(
index=self.idxnm if not idxnm else idxnm).using(
self.es).query(q)
ubq = ubq.script(source=scripts)
ubq = ubq.params(refresh=True)
ubq = ubq.params(slices=5)
@ -294,7 +296,6 @@ class HuEs:
return False
def deleteByQuery(self, query, idxnm=""):
for i in range(3):
try:
@ -392,7 +393,7 @@ class HuEs:
return rr
def scrollIter(self, pagesize=100, scroll_time='2m', q={
"query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
"query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
for _ in range(100):
try:
page = self.es.search(

View File

@ -4,6 +4,7 @@ from util import config
from minio import Minio
from io import BytesIO
class HuMinio(object):
def __init__(self, env):
self.config = config.init(env)
@ -12,64 +13,62 @@ class HuMinio(object):
def __open__(self):
try:
if self.conn:self.__close__()
if self.conn:
self.__close__()
except Exception as e:
pass
try:
self.conn = Minio(self.config.get("minio_host"),
access_key=self.config.get("minio_usr"),
secret_key=self.config.get("minio_pwd"),
access_key=self.config.get("minio_user"),
secret_key=self.config.get("minio_password"),
secure=False
)
)
except Exception as e:
logging.error("Fail to connect %s "%self.config.get("minio_host") + str(e))
logging.error(
"Fail to connect %s " %
self.config.get("minio_host") + str(e))
def __close__(self):
del self.conn
self.conn = None
def put(self, bucket, fnm, binary):
for _ in range(10):
try:
if not self.conn.bucket_exists(bucket):
self.conn.make_bucket(bucket)
r = self.conn.put_object(bucket, fnm,
r = self.conn.put_object(bucket, fnm,
BytesIO(binary),
len(binary)
)
)
return r
except Exception as e:
logging.error(f"Fail put {bucket}/{fnm}: "+str(e))
logging.error(f"Fail put {bucket}/{fnm}: " + str(e))
self.__open__()
time.sleep(1)
def get(self, bucket, fnm):
for _ in range(10):
try:
r = self.conn.get_object(bucket, fnm)
return r.read()
except Exception as e:
logging.error(f"fail get {bucket}/{fnm}: "+str(e))
logging.error(f"fail get {bucket}/{fnm}: " + str(e))
self.__open__()
time.sleep(1)
return
return
def get_presigned_url(self, bucket, fnm, expires):
for _ in range(10):
try:
return self.conn.get_presigned_url("GET", bucket, fnm, expires)
except Exception as e:
logging.error(f"fail get {bucket}/{fnm}: "+str(e))
logging.error(f"fail get {bucket}/{fnm}: " + str(e))
self.__open__()
time.sleep(1)
return
return
if __name__ == "__main__":
@ -78,9 +77,8 @@ if __name__ == "__main__":
from PIL import Image
img = Image.open(fnm)
buff = BytesIO()
img.save(buff, format='JPEG')
img.save(buff, format='JPEG')
print(conn.put("test", "11-408.jpg", buff.getvalue()))
bts = conn.get("test", "11-408.jpg")
img = Image.open(BytesIO(bts))
img.save("test.jpg")