parent cdd956568d
commit d0db329fef
@@ -121,7 +121,6 @@
       "match": "*_vec",
       "mapping": {
         "type": "dense_vector",
-        "dims": 1024,
         "index": true,
         "similarity": "cosine"
       }
@@ -1,10 +1,9 @@
 [infiniflow]
 es=http://es01:9200
-pgdb_usr=root
+postgres_user=root
-pgdb_pwd=infiniflow_docgpt
+postgres_password=infiniflow_docgpt
-pgdb_host=postgres
+postgres_host=postgres
-pgdb_port=5432
+postgres_port=5432
 minio_host=minio:9000
-minio_usr=infiniflow
+minio_user=infiniflow
-minio_pwd=infiniflow_docgpt
+minio_password=infiniflow_docgpt
@@ -1,2 +1,21 @@
-from .embedding_model import HuEmbedding
-from .chat_model import GptTurbo
+import os
+from .embedding_model import *
+from .chat_model import *
+from .cv_model import *
+
+EmbeddingModel = None
+ChatModel = None
+CvModel = None
+
+
+if os.environ.get("OPENAI_API_KEY"):
+    EmbeddingModel = GptEmbed()
+    ChatModel = GptTurbo()
+    CvModel = GptV4()
+
+elif os.environ.get("DASHSCOPE_API_KEY"):
+    EmbeddingModel = QWenEmbd()
+    ChatModel = QWenChat()
+    CvModel = QWenCV()
+else:
+    EmbeddingModel = HuEmbedding()
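Usage sketch (an illustration, not part of this commit): with the selection logic above, callers import the module-level singletons instead of constructing a backend themselves. EmbeddingModel always falls back to HuEmbedding, while ChatModel and CvModel stay None unless OPENAI_API_KEY or DASHSCOPE_API_KEY is exported, so callers should guard against None.

from llm import EmbeddingModel, ChatModel  # package path assumed from this commit

vecs = EmbeddingModel.encode(["What is RAGFlow?"])  # HuEmbedding fallback when no API key is set
if ChatModel is not None:  # None unless an API-key environment variable is exported
    reply = ChatModel.chat("You are a helpful assistant.",
                           [{"role": "user", "content": "Hello"}],
                           {"temperature": 0.7})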
@@ -1,7 +1,8 @@
 from abc import ABC
-import openapi
+from openai import OpenAI
 import os
 
 
 class Base(ABC):
     def chat(self, system, history, gen_conf):
         raise NotImplementedError("Please implement encode method!")

@@ -9,26 +10,27 @@ class Base(ABC):
 
 class GptTurbo(Base):
     def __init__(self):
-        openapi.api_key = os.environ["OPENAPI_KEY"]
+        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
 
     def chat(self, system, history, gen_conf):
         history.insert(0, {"role": "system", "content": system})
-        res = openapi.ChatCompletion.create(model="gpt-3.5-turbo",
-                                            messages=history,
-                                            **gen_conf)
+        res = self.client.chat.completions.create(
+            model="gpt-3.5-turbo",
+            messages=history,
+            **gen_conf)
         return res.choices[0].message.content.strip()
 
 
-class QWen(Base):
+class QWenChat(Base):
     def chat(self, system, history, gen_conf):
         from http import HTTPStatus
         from dashscope import Generation
-        from dashscope.api_entities.dashscope_response import Role
         # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
+        history.insert(0, {"role": "system", "content": system})
         response = Generation.call(
             Generation.Models.qwen_turbo,
-            messages=messages,
+            messages=history,
             result_format='message'
         )
         if response.status_code == HTTPStatus.OK:
             return response.output.choices[0]['message']['content']
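A minimal driving example for the migrated OpenAI client (a sketch, assuming openai>=1.x is installed and OPENAI_API_KEY is exported; gen_conf is forwarded verbatim to chat.completions.create):

from llm.chat_model import GptTurbo  # module path assumed from this commit

mdl = GptTurbo()
print(mdl.chat("You are a terse assistant.",
               [{"role": "user", "content": "Summarize RAG in one sentence."}],
               {"temperature": 0.2, "max_tokens": 64}))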
66  python/llm/cv_model.py  Normal file
@@ -0,0 +1,66 @@
+from abc import ABC
+from openai import OpenAI
+import os
+import base64
+from io import BytesIO
+
+
+class Base(ABC):
+    def describe(self, image, max_tokens=300):
+        raise NotImplementedError("Please implement encode method!")
+
+    def image2base64(self, image):
+        if isinstance(image, BytesIO):
+            return base64.b64encode(image.getvalue()).decode("utf-8")
+        buffered = BytesIO()
+        try:
+            image.save(buffered, format="JPEG")
+        except Exception as e:
+            image.save(buffered, format="PNG")
+        return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+    def prompt(self, b64):
+        return [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。",
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{b64}"
+                        },
+                    },
+                ],
+            }
+        ]
+
+
+class GptV4(Base):
+    def __init__(self):
+        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+
+    def describe(self, image, max_tokens=300):
+        b64 = self.image2base64(image)
+
+        res = self.client.chat.completions.create(
+            model="gpt-4-vision-preview",
+            messages=self.prompt(b64),
+            max_tokens=max_tokens,
+        )
+        return res.choices[0].message.content.strip()
+
+
+class QWenCV(Base):
+    def describe(self, image, max_tokens=300):
+        from http import HTTPStatus
+        from dashscope import MultiModalConversation
+        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
+        response = MultiModalConversation.call(model=MultiModalConversation.Models.qwen_vl_chat_v1,
+                                               messages=self.prompt(self.image2base64(image)))
+        if response.status_code == HTTPStatus.OK:
+            return response.output.choices[0]['message']['content']
+        return response.message
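A usage sketch for the new vision wrappers (assumptions: Pillow is installed, a local figure.png exists, and the matching API key is exported; image2base64 above accepts either a BytesIO or any object with a PIL-style save method):

import os
from io import BytesIO
from PIL import Image
from llm.cv_model import GptV4, QWenCV  # path taken from the file header above

if os.environ.get("OPENAI_API_KEY"):
    print(GptV4().describe(Image.open("figure.png")))       # PIL image path
elif os.environ.get("DASHSCOPE_API_KEY"):
    with open("figure.png", "rb") as f:
        print(QWenCV().describe(BytesIO(f.read())))          # raw-bytes path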
@@ -1,8 +1,11 @@
 from abc import ABC
+from openai import OpenAI
 from FlagEmbedding import FlagModel
 import torch
+import os
 import numpy as np
 
 
 class Base(ABC):
     def encode(self, texts: list, batch_size=32):
         raise NotImplementedError("Please implement encode method!")

@@ -22,11 +25,37 @@ class HuEmbedding(Base):
 
         """
         self.model = FlagModel("BAAI/bge-large-zh-v1.5",
                                query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                                use_fp16=torch.cuda.is_available())
 
     def encode(self, texts: list, batch_size=32):
         res = []
         for i in range(0, len(texts), batch_size):
-            res.extend(self.model.encode(texts[i:i+batch_size]).tolist())
+            res.extend(self.model.encode(texts[i:i + batch_size]).tolist())
         return np.array(res)
+
+
+class GptEmbed(Base):
+    def __init__(self):
+        self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+
+    def encode(self, texts: list, batch_size=32):
+        res = self.client.embeddings.create(input=texts,
+                                            model="text-embedding-ada-002")
+        return [d["embedding"] for d in res["data"]]
+
+
+class QWenEmbd(Base):
+    def encode(self, texts: list, batch_size=32, text_type="document"):
+        # export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
+        import dashscope
+        from http import HTTPStatus
+        res = []
+        for txt in texts:
+            resp = dashscope.TextEmbedding.call(
+                model=dashscope.TextEmbedding.Models.text_embedding_v2,
+                input=txt[:2048],
+                text_type=text_type
+            )
+            res.append(resp["output"]["embeddings"][0]["embedding"])
+        return res
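The embedders share the encode(texts) interface, so retrieval code can score a query against documents the same way regardless of backend. A sketch (assuming the FlagEmbedding weights BAAI/bge-large-zh-v1.5 can be downloaded; the cosine-similarity helper mirrors the sklearn import already used in search.py):

from sklearn.metrics.pairwise import cosine_similarity
from llm.embedding_model import HuEmbedding  # module path assumed from this commit

mdl = HuEmbedding()
doc_vecs = mdl.encode(["RAGFlow indexes chunks into Elasticsearch.",
                       "Extracted images are stored in MinIO."])
qry_vec = mdl.encode(["Where are images stored?"])
print(cosine_similarity(qry_vec, doc_vecs)[0])  # one similarity score per document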
@@ -372,7 +372,9 @@ class PptChunker(HuChunker):
 
     def __call__(self, fnm):
         from pptx import Presentation
-        ppt = Presentation(fnm) if isinstance(fnm, str) else Presentation(BytesIO(fnm))
+        ppt = Presentation(fnm) if isinstance(
+            fnm, str) else Presentation(
+            BytesIO(fnm))
         flds = self.Fields()
         flds.text_chunks = []
         for slide in ppt.slides:

@@ -398,7 +400,8 @@ class TextChunker(HuChunker):
         mime = magic.Magic(mime=True)
         if isinstance(file_path, str):
             file_type = mime.from_file(file_path)
-        else:file_type = mime.from_buffer(file_path)
+        else:
+            file_type = mime.from_buffer(file_path)
         if 'text' in file_type:
             return False
         else:

@@ -406,7 +409,8 @@ class TextChunker(HuChunker):
 
     def __call__(self, fnm):
         flds = self.Fields()
-        if self.is_binary_file(fnm):return flds
+        if self.is_binary_file(fnm):
+            return flds
         with open(fnm, "r") as f:
             txt = f.read()
         flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
@@ -1,6 +1,6 @@
 import re
-from elasticsearch_dsl import Q,Search,A
+from elasticsearch_dsl import Q, Search, A
-from typing import List, Optional, Tuple,Dict, Union
+from typing import List, Optional, Tuple, Dict, Union
 from dataclasses import dataclass
 from util import setup_logging, rmSpace
 from nlp import huqie, query
@@ -9,18 +9,24 @@ from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
 import numpy as np
 from copy import deepcopy
 
-def index_name(uid):return f"docgpt_{uid}"
+
+def index_name(uid): return f"docgpt_{uid}"
+
+
 class Dealer:
     def __init__(self, es, emb_mdl):
         self.qryr = query.EsQueryer(es)
-        self.qryr.flds = ["title_tks^10", "title_sm_tks^5", "content_ltks^2", "content_sm_ltks"]
+        self.qryr.flds = [
+            "title_tks^10",
+            "title_sm_tks^5",
+            "content_ltks^2",
+            "content_sm_ltks"]
         self.es = es
         self.emb_mdl = emb_mdl
 
     @dataclass
     class SearchResult:
-        total:int
+        total: int
         ids: List[str]
         query_vector: List[float] = None
         field: Optional[Dict] = None
@@ -42,71 +48,78 @@ class Dealer:
         keywords = []
         qst = req.get("question", "")
 
-        bqry,keywords = self.qryr.question(qst)
-        if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
+        bqry, keywords = self.qryr.question(qst)
+        if req.get("kb_ids"):
+            bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
         bqry.filter.append(Q("exists", field="q_tks"))
         bqry.boost = 0.05
         print(bqry)
 
         s = Search()
-        pg = int(req.get("page", 1))-1
+        pg = int(req.get("page", 1)) - 1
         ps = int(req.get("size", 1000))
         src = req.get("field", ["docnm_kwd", "content_ltks", "kb_id",
                                 "image_id", "doc_id", "q_vec"])
 
-        s = s.query(bqry)[pg*ps:(pg+1)*ps]
+        s = s.query(bqry)[pg * ps:(pg + 1) * ps]
         s = s.highlight("content_ltks")
         s = s.highlight("title_ltks")
-        if not qst: s = s.sort({"create_time":{"order":"desc", "unmapped_type":"date"}})
+        if not qst:
+            s = s.sort(
+                {"create_time": {"order": "desc", "unmapped_type": "date"}})
 
         s = s.highlight_options(
-            fragment_size = 120,
+            fragment_size=120,
             number_of_fragments=5,
             boundary_scanner_locale="zh-CN",
             boundary_scanner="SENTENCE",
             boundary_chars=",./;:\\!(),。?:!……()——、"
         )
         s = s.to_dict()
         q_vec = []
         if req.get("vector"):
             s["knn"] = self._vector(qst, req.get("similarity", 0.4), ps)
             s["knn"]["filter"] = bqry.to_dict()
             del s["highlight"]
             q_vec = s["knn"]["query_vector"]
-        res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src)
+        res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
         print("TOTAL: ", self.es.getTotal(res))
         if self.es.getTotal(res) == 0 and "knn" in s:
-            bqry,_ = self.qryr.question(qst, min_match="10%")
-            if req.get("kb_ids"): bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
+            bqry, _ = self.qryr.question(qst, min_match="10%")
+            if req.get("kb_ids"):
+                bqry.filter.append(Q("terms", kb_id=req["kb_ids"]))
             s["query"] = bqry.to_dict()
             s["knn"]["filter"] = bqry.to_dict()
             s["knn"]["similarity"] = 0.7
-            res = self.es.search(s, idxnm=idxnm, timeout="600s",src=src)
+            res = self.es.search(s, idxnm=idxnm, timeout="600s", src=src)
 
         kwds = set([])
         for k in keywords:
             kwds.add(k)
             for kk in huqie.qieqie(k).split(" "):
-                if len(kk) < 2:continue
-                if kk in kwds:continue
+                if len(kk) < 2:
+                    continue
+                if kk in kwds:
+                    continue
                 kwds.add(kk)
 
         aggs = self.getAggregation(res, "docnm_kwd")
 
         return self.SearchResult(
-            total = self.es.getTotal(res),
-            ids = self.es.getDocIds(res),
-            query_vector = q_vec,
-            aggregation = aggs,
-            highlight = self.getHighlight(res),
-            field = self.getFields(res, ["docnm_kwd", "content_ltks",
-                "kb_id","image_id", "doc_id", "q_vec"]),
-            keywords = list(kwds)
+            total=self.es.getTotal(res),
+            ids=self.es.getDocIds(res),
+            query_vector=q_vec,
+            aggregation=aggs,
+            highlight=self.getHighlight(res),
+            field=self.getFields(res, ["docnm_kwd", "content_ltks",
                                       "kb_id", "image_id", "doc_id", "q_vec"]),
+            keywords=list(kwds)
         )
 
     def getAggregation(self, res, g):
-        if not "aggregations" in res or "aggs_"+g not in res["aggregations"]:return
-        bkts = res["aggregations"]["aggs_"+g]["buckets"]
+        if not "aggregations" in res or "aggs_" + g not in res["aggregations"]:
+            return
+        bkts = res["aggregations"]["aggs_" + g]["buckets"]
         return [(b["key"], b["doc_count"]) for b in bkts]
 
     def getHighlight(self, res):
@@ -114,8 +127,11 @@ class Dealer:
         eng = set(list("qwertyuioplkjhgfdsazxcvbnm"))
         r = []
         for t in line.split(" "):
-            if not t:continue
-            if len(r)>0 and len(t)>0 and r[-1][-1] in eng and t[0] in eng:r.append(" ")
+            if not t:
+                continue
+            if len(r) > 0 and len(
+                    t) > 0 and r[-1][-1] in eng and t[0] in eng:
+                r.append(" ")
             r.append(t)
         r = "".join(r)
         return r
@@ -123,66 +139,76 @@ class Dealer:
         ans = {}
         for d in res["hits"]["hits"]:
             hlts = d.get("highlight")
-            if not hlts:continue
+            if not hlts:
+                continue
             ans[d["_id"]] = "".join([a for a in list(hlts.items())[0][1]])
         return ans
 
     def getFields(self, sres, flds):
         res = {}
-        if not flds:return {}
+        if not flds:
+            return {}
         for d in self.es.getSource(sres):
-            m = {n:d.get(n) for n in flds if d.get(n) is not None}
-            for n,v in m.items():
-                if type(v) == type([]):
+            m = {n: d.get(n) for n in flds if d.get(n) is not None}
+            for n, v in m.items():
+                if isinstance(v, type([])):
                     m[n] = "\t".join([str(vv) for vv in v])
                     continue
-                if type(v) != type(""):m[n] = str(m[n])
+                if not isinstance(v, type("")):
+                    m[n] = str(m[n])
                 m[n] = rmSpace(m[n])
 
-            if m:res[d["id"]] = m
+            if m:
+                res[d["id"]] = m
         return res
 
     @staticmethod
     def trans2floats(txt):
         return [float(t) for t in txt.split("\t")]
 
-    def insert_citations(self, ans, top_idx, sres, vfield = "q_vec", cfield="content_ltks"):
-        ins_embd = [Dealer.trans2floats(sres.field[sres.ids[i]][vfield]) for i in top_idx]
-        ins_tw =[sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
+    def insert_citations(self, ans, top_idx, sres,
+                         vfield="q_vec", cfield="content_ltks"):
+        ins_embd = [Dealer.trans2floats(
+            sres.field[sres.ids[i]][vfield]) for i in top_idx]
+        ins_tw = [sres.field[sres.ids[i]][cfield].split(" ") for i in top_idx]
         s = 0
         e = 0
         res = ""
 
         def citeit():
             nonlocal s, e, ans, res
-            if not ins_embd:return
+            if not ins_embd:
+                return
             embd = self.emb_mdl.encode(ans[s: e])
             sim = self.qryr.hybrid_similarity(embd,
                                               ins_embd,
                                               huqie.qie(ans[s:e]).split(" "),
                                               ins_tw)
             print(ans[s: e], sim)
-            mx = np.max(sim)*0.99
-            if mx < 0.55:return
-            cita = list(set([top_idx[i] for i in range(len(ins_embd)) if sim[i] >mx]))[:4]
-            for i in cita: res += f"@?{i}?@"
+            mx = np.max(sim) * 0.99
+            if mx < 0.55:
+                return
+            cita = list(set([top_idx[i]
+                        for i in range(len(ins_embd)) if sim[i] > mx]))[:4]
+            for i in cita:
+                res += f"@?{i}?@"
 
             return cita
 
         punct = set(";。?!!")
         if not self.qryr.isChinese(ans):
             punct.add("?")
             punct.add(".")
         while e < len(ans):
             if e - s < 12 or ans[e] not in punct:
                 e += 1
                 continue
-            if ans[e] == "." and e+1<len(ans) and re.match(r"[0-9]", ans[e+1]):
+            if ans[e] == "." and e + \
+                    1 < len(ans) and re.match(r"[0-9]", ans[e + 1]):
                 e += 1
                 continue
-            if ans[e] == "." and e-2>=0 and ans[e-2] == "\n":
+            if ans[e] == "." and e - 2 >= 0 and ans[e - 2] == "\n":
                 e += 1
                 continue
             res += ans[s: e]
@@ -191,33 +217,36 @@ class Dealer:
             e += 1
             s = e
 
-        if s< len(ans):
+        if s < len(ans):
             res += ans[s:]
             citeit()
 
         return res
 
-    def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, vfield="q_vec", cfield="content_ltks"):
-        ins_embd = [Dealer.trans2floats(sres.field[i]["q_vec"]) for i in sres.ids]
-        if not ins_embd: return []
-        ins_tw =[sres.field[i][cfield].split(" ") for i in sres.ids]
-        #return CosineSimilarity([sres.query_vector], ins_embd)[0]
-        sim = self.qryr.hybrid_similarity(sres.query_vector,
-                                          ins_embd,
+    def rerank(self, sres, query, tkweight=0.3, vtweight=0.7,
+               vfield="q_vec", cfield="content_ltks"):
+        ins_embd = [
+            Dealer.trans2floats(
+                sres.field[i]["q_vec"]) for i in sres.ids]
+        if not ins_embd:
+            return []
+        ins_tw = [sres.field[i][cfield].split(" ") for i in sres.ids]
+        # return CosineSimilarity([sres.query_vector], ins_embd)[0]
+        sim = self.qryr.hybrid_similarity(sres.query_vector,
+                                          ins_embd,
                                           huqie.qie(query).split(" "),
                                           ins_tw, tkweight, vtweight)
         return sim
 
 
 if __name__ == "__main__":
     from util import es_conn
     SE = Dealer(es_conn.HuEs("infiniflow"))
     qs = [
         "胡凯",
         ""
     ]
     for q in qs:
         print(">>>>>>>>>>>>>>>>>>>>", q)
-        print(SE.search({"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))
+        print(SE.search(
+            {"question": q, "kb_ids": "64f072a75f3b97c865718c4a"}, "infiniflow_*"))
@@ -5,8 +5,10 @@ from io import BytesIO
 
 class HuExcelParser:
     def __call__(self, fnm):
-        if isinstance(fnm, str):wb = load_workbook(fnm)
-        else: wb = load_workbook(BytesIO(fnm))
+        if isinstance(fnm, str):
+            wb = load_workbook(fnm)
+        else:
+            wb = load_workbook(BytesIO(fnm))
         res = []
         for sheetname in wb.sheetnames:
             ws = wb[sheetname]
@@ -53,7 +53,7 @@ class HuParser:
     def _y_dis(
             self, a, b):
         return (
             b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
 
     def _match_proj(self, b):
         proj_patt = [

@@ -76,9 +76,9 @@ class HuParser:
         tks_down = huqie.qie(down["text"][:LEN]).split(" ")
         tks_up = huqie.qie(up["text"][-LEN:]).split(" ")
         tks_all = up["text"][-LEN:].strip() \
             + (" " if re.match(r"[a-zA-Z0-9]+",
                                up["text"][-1] + down["text"][0]) else "") \
             + down["text"][:LEN].strip()
         tks_all = huqie.qie(tks_all).split(" ")
         fea = [
             up.get("R", -1) == down.get("R", -1),

@@ -100,7 +100,7 @@ class HuParser:
             True if re.search(r"[,,][^。.]+$", up["text"]) else False,
             True if re.search(r"[,,][^。.]+$", up["text"]) else False,
             True if re.search(r"[\((][^\))]+$", up["text"])
             and re.search(r"[\))]", down["text"]) else False,
             self._match_proj(down),
             True if re.match(r"[A-Z]", down["text"]) else False,
             True if re.match(r"[A-Z]", up["text"][-1]) else False,

@@ -217,7 +217,7 @@ class HuParser:
         assert tp_ <= btm_, "Fuckedup! T:{},B:{},X0:{},X1:{} => {}".format(
             tp, btm, x0, x1, b)
         ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
             x0 != 0 and btm - tp != 0 else 0
         if ov > 0 and ratio:
             ov /= (x1 - x0) * (btm - tp)
         return ov

@@ -382,7 +382,7 @@ class HuParser:
                 continue
             for tb in tbls:  # for table
                 left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
                     tb["x1"] + MARGIN, tb["bottom"] + MARGIN
                 left *= ZM
                 top *= ZM
                 right *= ZM

@@ -899,7 +899,7 @@ class HuParser:
             lst_r = rows[-1]
             if lst_r[-1].get("R", "") != b.get("R", "") \
                     or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
                         ):  # new row
                 btm = b["bottom"]
                 b["rn"] += 1
                 rows.append([b])

@@ -949,9 +949,9 @@ class HuParser:
                 j += 1
                 continue
             f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
                  [j - 1][0].get("text")) or j == 0
             ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
                   [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
             if f and ff:
                 j += 1
                 continue

@@ -1012,9 +1012,9 @@ class HuParser:
                 i += 1
                 continue
             f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
                  [jj][0].get("text")) or i == 0
             ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
                   [jj][0].get("text")) or i + 1 >= len(tbl)
             if f and ff:
                 i += 1
                 continue

@@ -1169,8 +1169,8 @@ class HuParser:
                         else "") + headers[j - 1][k]
                 else:
                     headers[j][k] = headers[j - 1][k] \
                         + ("的" if headers[j - 1][k] else "") \
                         + headers[j][k]
 
         logging.debug(
             f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")

@@ -1247,7 +1247,7 @@ class HuParser:
                 i += 1
                 continue
             lout_no = str(self.boxes[i]["page_number"]) + \
                 "-" + str(self.boxes[i]["layoutno"])
             if self.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", "title",
                                                                                   "figure caption", "reference"]:
                 nomerge_lout_no.append(lst_lout_no)

@@ -1526,7 +1526,8 @@ class HuParser:
         return "\n\n".join(res)
 
     def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
-        self.pdf = pdfplumber.open(fnm) if isinstance(fnm, str) else pdfplumber.open(BytesIO(fnm))
+        self.pdf = pdfplumber.open(fnm) if isinstance(
+            fnm, str) else pdfplumber.open(BytesIO(fnm))
         self.lefted_chars = []
         self.mean_height = []
         self.mean_width = []

@@ -1601,7 +1602,7 @@ class HuParser:
                 self.page_images[pns[0]].crop((left * ZM, top * ZM,
                                                right *
                                                ZM, min(
                                                    bottom, self.page_images[pns[0]].size[1])
                                                ))
             )
             bottom -= self.page_images[pns[0]].size[1]
@@ -16,11 +16,12 @@ from io import BytesIO
 from util import config
 from timeit import default_timer as timer
 from collections import OrderedDict
+from llm import ChatModel, EmbeddingModel
 
 SE = None
 CFIELD="content_ltks"
-EMBEDDING = HuEmbedding()
-LLM = GptTurbo()
+EMBEDDING = EmbeddingModel
+LLM = ChatModel
 
 def get_QA_pairs(hists):
     pa = []
@@ -1,4 +1,4 @@
-import json, os, sys, hashlib, copy, time, random, re, logging, torch
+import json, os, sys, hashlib, copy, time, random, re
 from os.path import dirname, realpath
 sys.path.append(dirname(realpath(__file__)) + "/../")
 from util.es_conn import HuEs

@@ -7,10 +7,10 @@ from util.minio_conn import HuMinio
 from util import rmSpace, findMaxDt
 from FlagEmbedding import FlagModel
 from nlp import huchunk, huqie, search
-import base64, hashlib
 from io import BytesIO
 import pandas as pd
 from elasticsearch_dsl import Q
+from PIL import Image
 from parser import (
     PdfParser,
     DocxParser,

@@ -40,6 +40,15 @@ def chuck_doc(name, binary):
     if suff.find("doc") >= 0: return DOC(binary)
     if re.match(r"(xlsx|xlsm|xltx|xltm)", suff): return EXC(binary)
     if suff.find("ppt") >= 0: return PPT(binary)
+    if os.environ.get("PARSE_IMAGE") \
+            and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
+                          name.lower()):
+        from llm import CvModel
+        txt = CvModel.describe(binary)
+        field = TextChunker.Fields()
+        field.text_chunks = [(txt, binary)]
+        field.table_chunks = []
+
 
     return TextChunker()(binary)

@@ -119,7 +128,6 @@ def build(row):
         set_progress(row["kb2doc_id"], -1, f"Internal system error: %s"%str(e).replace("'", ""))
         return []
 
-    print(row["doc_name"], obj)
     if not obj.text_chunks and not obj.table_chunks:
         set_progress(row["kb2doc_id"], 1, "Nothing added! Mostly, file type unsupported yet.")
         return []

@@ -146,7 +154,10 @@ def build(row):
         if not img:
             docs.append(d)
             continue
-        img.save(output_buffer, format='JPEG')
+
+        if isinstance(img, Image): img.save(output_buffer, format='JPEG')
+        else: output_buffer = BytesIO(img)
+
         MINIO.put("{}-{}".format(row["uid"], row["kb_id"]), d["_id"],
                   output_buffer.getvalue())
         d["img_id"] = "{}-{}".format(row["uid"], row["kb_id"])
@@ -1,19 +1,24 @@
 import re
 
 
 def rmSpace(txt):
     txt = re.sub(r"([^a-z0-9.,]) +([^ ])", r"\1\2", txt)
     return re.sub(r"([^ ]) +([^a-z0-9.,])", r"\1\2", txt)
 
 
 def findMaxDt(fnm):
     m = "1970-01-01 00:00:00"
     try:
         with open(fnm, "r") as f:
             while True:
                 l = f.readline()
-                if not l:break
+                if not l:
+                    break
                 l = l.strip("\n")
-                if l == 'nan':continue
-                if l > m:m = l
+                if l == 'nan':
+                    continue
+                if l > m:
+                    m = l
     except Exception as e:
-        print("WARNING: can't find "+ fnm)
+        print("WARNING: can't find " + fnm)
     return m
@@ -1,25 +1,31 @@
 from configparser import ConfigParser
-import os,inspect
+import os
+import inspect
 
 CF = ConfigParser()
 __fnm = os.path.join(os.path.dirname(__file__), '../conf/sys.cnf')
-if not os.path.exists(__fnm):__fnm = os.path.join(os.path.dirname(__file__), '../../conf/sys.cnf')
-assert os.path.exists(__fnm), f"【EXCEPTION】can't find {__fnm}." + os.path.dirname(__file__)
-if not os.path.exists(__fnm): __fnm = "./sys.cnf"
+if not os.path.exists(__fnm):
+    __fnm = os.path.join(os.path.dirname(__file__), '../../conf/sys.cnf')
+assert os.path.exists(
+    __fnm), f"【EXCEPTION】can't find {__fnm}." + os.path.dirname(__file__)
+if not os.path.exists(__fnm):
+    __fnm = "./sys.cnf"
 
 CF.read(__fnm)
 
 
 class Config:
     def __init__(self, env):
         self.env = env
-        if env == "spark":CF.read("./cv.cnf")
+        if env == "spark":
+            CF.read("./cv.cnf")
 
     def get(self, key, default=None):
         global CF
-        return os.environ.get(key.upper(), \
+        return os.environ.get(key.upper(),
                               CF[self.env].get(key, default)
                               )
 
 
 def init(env):
     return Config(env)
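A sketch of the lookup order Config.get implements above (the key names come from the renamed sys.cnf earlier in this commit; the override value is a hypothetical example): an environment variable named like the upper-cased key wins over the conf/sys.cnf entry.

import os
from util import config  # import path as used elsewhere in this commit

cfg = config.init("infiniflow")
print(cfg.get("postgres_host"))            # "postgres", read from the [infiniflow] section
os.environ["POSTGRES_HOST"] = "127.0.0.1"
print(cfg.get("postgres_host"))            # "127.0.0.1", the environment variable takes precedence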
@@ -3,6 +3,7 @@ import time
 from util import config
 import pandas as pd
 
+
 class Postgres(object):
     def __init__(self, env, dbnm):
         self.config = config.init(env)

@@ -13,36 +14,42 @@ class Postgres(object):
     def __open__(self):
         import psycopg2
         try:
-            if self.conn:self.__close__()
+            if self.conn:
+                self.__close__()
             del self.conn
         except Exception as e:
             pass
 
         try:
-            self.conn = psycopg2.connect(f"dbname={self.dbnm} user={self.config.get('pgdb_usr')} password={self.config.get('pgdb_pwd')} host={self.config.get('pgdb_host')} port={self.config.get('pgdb_port')}")
+            self.conn = psycopg2.connect(f"""dbname={self.dbnm}
+                                         user={self.config.get('postgres_user')}
+                                         password={self.config.get('postgres_password')}
+                                         host={self.config.get('postgres_host')}
+                                         port={self.config.get('postgres_port')}""")
         except Exception as e:
-            logging.error("Fail to connect %s "%self.config.get("pgdb_host") + str(e))
+            logging.error(
+                "Fail to connect %s " %
+                self.config.get("pgdb_host") + str(e))
 
     def __close__(self):
         try:
             self.conn.close()
         except Exception as e:
-            logging.error("Fail to close %s "%self.config.get("pgdb_host") + str(e))
+            logging.error(
+                "Fail to close %s " %
+                self.config.get("pgdb_host") + str(e))
 
     def select(self, sql):
         for _ in range(10):
             try:
                 return pd.read_sql(sql, self.conn)
             except Exception as e:
-                logging.error(f"Fail to exec {sql} "+str(e))
+                logging.error(f"Fail to exec {sql} " + str(e))
                 self.__open__()
                 time.sleep(1)
 
         return pd.DataFrame()
 
     def update(self, sql):
         for _ in range(10):
             try:

@@ -53,11 +60,11 @@ class Postgres(object):
                 cur.close()
                 return updated_rows
             except Exception as e:
-                logging.error(f"Fail to exec {sql} "+str(e))
+                logging.error(f"Fail to exec {sql} " + str(e))
                 self.__open__()
                 time.sleep(1)
         return 0
 
 
 if __name__ == "__main__":
     Postgres("infiniflow", "docgpt")
@@ -228,7 +228,8 @@ class HuEs:
         return False
 
     def search(self, q, idxnm=None, src=False, timeout="2s"):
-        if not isinstance(q, dict): q = Search().query(q).to_dict()
+        if not isinstance(q, dict):
+            q = Search().query(q).to_dict()
         for i in range(3):
             try:
                 res = self.es.search(index=(self.idxnm if not idxnm else idxnm),

@@ -274,9 +275,10 @@ class HuEs:
 
         return False
 
+
     def updateScriptByQuery(self, q, scripts, idxnm=None):
-        ubq = UpdateByQuery(index=self.idxnm if not idxnm else idxnm).using(self.es).query(q)
+        ubq = UpdateByQuery(
+            index=self.idxnm if not idxnm else idxnm).using(
+            self.es).query(q)
         ubq = ubq.script(source=scripts)
         ubq = ubq.params(refresh=True)
         ubq = ubq.params(slices=5)

@@ -294,7 +296,6 @@ class HuEs:
 
         return False
 
-
     def deleteByQuery(self, query, idxnm=""):
         for i in range(3):
             try:

@@ -392,7 +393,7 @@ class HuEs:
         return rr
 
     def scrollIter(self, pagesize=100, scroll_time='2m', q={
         "query": {"match_all": {}}, "sort": [{"updated_at": {"order": "desc"}}]}):
         for _ in range(100):
             try:
                 page = self.es.search(
@@ -4,6 +4,7 @@ from util import config
 from minio import Minio
 from io import BytesIO
 
+
 class HuMinio(object):
     def __init__(self, env):
         self.config = config.init(env)

@@ -12,64 +13,62 @@ class HuMinio(object):
 
     def __open__(self):
         try:
-            if self.conn:self.__close__()
+            if self.conn:
+                self.__close__()
         except Exception as e:
             pass
 
         try:
             self.conn = Minio(self.config.get("minio_host"),
-                              access_key=self.config.get("minio_usr"),
-                              secret_key=self.config.get("minio_pwd"),
+                              access_key=self.config.get("minio_user"),
+                              secret_key=self.config.get("minio_password"),
                               secure=False
                               )
         except Exception as e:
-            logging.error("Fail to connect %s "%self.config.get("minio_host") + str(e))
+            logging.error(
+                "Fail to connect %s " %
+                self.config.get("minio_host") + str(e))
 
     def __close__(self):
         del self.conn
         self.conn = None
 
     def put(self, bucket, fnm, binary):
         for _ in range(10):
             try:
                 if not self.conn.bucket_exists(bucket):
                     self.conn.make_bucket(bucket)
 
                 r = self.conn.put_object(bucket, fnm,
                                          BytesIO(binary),
                                          len(binary)
                                          )
                 return r
             except Exception as e:
-                logging.error(f"Fail put {bucket}/{fnm}: "+str(e))
+                logging.error(f"Fail put {bucket}/{fnm}: " + str(e))
                 self.__open__()
                 time.sleep(1)
 
     def get(self, bucket, fnm):
         for _ in range(10):
             try:
                 r = self.conn.get_object(bucket, fnm)
                 return r.read()
             except Exception as e:
-                logging.error(f"fail get {bucket}/{fnm}: "+str(e))
+                logging.error(f"fail get {bucket}/{fnm}: " + str(e))
                 self.__open__()
                 time.sleep(1)
         return
 
     def get_presigned_url(self, bucket, fnm, expires):
         for _ in range(10):
             try:
                 return self.conn.get_presigned_url("GET", bucket, fnm, expires)
             except Exception as e:
-                logging.error(f"fail get {bucket}/{fnm}: "+str(e))
+                logging.error(f"fail get {bucket}/{fnm}: " + str(e))
                 self.__open__()
                 time.sleep(1)
         return
 
 
 if __name__ == "__main__":

@@ -78,9 +77,8 @@ if __name__ == "__main__":
     from PIL import Image
     img = Image.open(fnm)
     buff = BytesIO()
     img.save(buff, format='JPEG')
     print(conn.put("test", "11-408.jpg", buff.getvalue()))
     bts = conn.get("test", "11-408.jpg")
     img = Image.open(BytesIO(bts))
     img.save("test.jpg")