use onnx models, new deepdoc (#68)

This commit is contained in:
KevinHuSh 2024-02-21 16:32:38 +08:00 committed by GitHub
parent 8c4ec9955e
commit cacd36c5e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 8730 additions and 136 deletions

View File

@ -198,7 +198,7 @@ def chat(dialog, messages, **kwargs):
return {"answer": prompt_config["empty_response"], "retrieval": kbinfos}
kwargs["knowledge"] = "\n".join(knowledges)
gen_conf = dialog.llm_setting[dialog.llm_setting_type]
gen_conf = dialog.llm_setting
msg = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]
used_token_count, msg = message_fit_in(msg, int(llm.max_tokens * 0.97))
if "max_tokens" in gen_conf:

View File

@ -33,38 +33,17 @@ def set_dialog():
name = req.get("name", "New Dialog")
description = req.get("description", "A helpful Dialog")
language = req.get("language", "Chinese")
llm_setting_type = req.get("llm_setting_type", "Precise")
top_n = req.get("top_n", 6)
similarity_threshold = req.get("similarity_threshold", 0.1)
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
llm_setting = req.get("llm_setting", {
"Creative": {
"temperature": 0.9,
"top_p": 0.9,
"frequency_penalty": 0.2,
"presence_penalty": 0.4,
"max_tokens": 512
},
"Precise": {
"temperature": 0.1,
"top_p": 0.3,
"frequency_penalty": 0.7,
"presence_penalty": 0.4,
"max_tokens": 215
},
"Evenly": {
"temperature": 0.5,
"top_p": 0.5,
"frequency_penalty": 0.7,
"presence_penalty": 0.4,
"max_tokens": 215
},
"Custom": {
"temperature": 0.2,
"top_p": 0.3,
"frequency_penalty": 0.6,
"presence_penalty": 0.3,
"max_tokens": 215
},
"temperature": 0.1,
"top_p": 0.3,
"frequency_penalty": 0.7,
"presence_penalty": 0.4,
"max_tokens": 215
})
prompt_config = req.get("prompt_config", {
default_prompt = {
"system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。
以下是知识库
{knowledge}
@ -74,30 +53,40 @@ def set_dialog():
{"key": "knowledge", "optional": False}
],
"empty_response": "Sorry! 知识库中未找到相关内容!"
})
}
prompt_config = req.get("prompt_config", default_prompt)
if len(prompt_config["parameters"]) < 1:
return get_data_error_result(retmsg="'knowledge' should be in parameters")
if not prompt_config["system"]: prompt_config["system"] = default_prompt["system"]
# if len(prompt_config["parameters"]) < 1:
# prompt_config["parameters"] = default_prompt["parameters"]
# for p in prompt_config["parameters"]:
# if p["key"] == "knowledge":break
# else: prompt_config["parameters"].append(default_prompt["parameters"][0])
for p in prompt_config["parameters"]:
if prompt_config["system"].find("{%s}"%p["key"]) < 0:
if p["optional"]: continue
if prompt_config["system"].find("{%s}" % p["key"]) < 0:
return get_data_error_result(retmsg="Parameter '{}' is not used".format(p["key"]))
try:
e, tenant = TenantService.get_by_id(current_user.id)
if not e:return get_data_error_result(retmsg="Tenant not found!")
if not e: return get_data_error_result(retmsg="Tenant not found!")
llm_id = req.get("llm_id", tenant.llm_id)
if not dialog_id:
if not req.get("kb_ids"):return get_data_error_result(retmsg="Fail! Please select knowledgebase!")
dia = {
"id": get_uuid(),
"tenant_id": current_user.id,
"name": name,
"kb_ids": req["kb_ids"],
"description": description,
"language": language,
"llm_id": llm_id,
"llm_setting_type": llm_setting_type,
"llm_setting": llm_setting,
"prompt_config": prompt_config
"prompt_config": prompt_config,
"top_n": top_n,
"similarity_threshold": similarity_threshold,
"vector_similarity_weight": vector_similarity_weight
}
if not DialogService.save(**dia): return get_data_error_result(retmsg="Fail to new a dialog!")
e, dia = DialogService.get_by_id(dia["id"])
@ -122,7 +111,7 @@ def set_dialog():
def get():
dialog_id = request.args["dialog_id"]
try:
e,dia = DialogService.get_by_id(dialog_id)
e, dia = DialogService.get_by_id(dialog_id)
if not e: return get_data_error_result(retmsg="Dialog not found!")
dia = dia.to_dict()
dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"])
@ -130,20 +119,22 @@ def get():
except Exception as e:
return server_error_response(e)
def get_kb_names(kb_ids):
ids, nms = [], []
for kid in kb_ids:
e, kb = KnowledgebaseService.get_by_id(kid)
if not e or kb.status != StatusEnum.VALID.value:continue
if not e or kb.status != StatusEnum.VALID.value: continue
ids.append(kid)
nms.append(kb.name)
return ids, nms
@manager.route('/list', methods=['GET'])
@login_required
def list():
try:
diags = DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value)
diags = DialogService.query(tenant_id=current_user.id, status=StatusEnum.VALID.value, reverse=True, order_by=DialogService.model.create_time)
diags = [d.to_dict() for d in diags]
for d in diags:
d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
@ -154,12 +145,11 @@ def list():
@manager.route('/rm', methods=['POST'])
@login_required
@validate_request("dialog_id")
@validate_request("dialog_ids")
def rm():
req = request.json
try:
if not DialogService.update_by_id(req["dialog_id"], {"status": StatusEnum.INVALID.value}):
return get_data_error_result(retmsg="Dialog not found!")
DialogService.update_many_by_id([{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
return get_json_result(data=True)
except Exception as e:
return server_error_response(e)
return server_error_response(e)

View File

@ -529,8 +529,6 @@ class Dialog(DataBaseModel):
icon = CharField(max_length=16, null=False, help_text="dialog icon")
language = CharField(max_length=32, null=True, default="Chinese", help_text="English|Chinese")
llm_id = CharField(max_length=32, null=False, help_text="default llm ID")
llm_setting_type = CharField(max_length=8, null=False, help_text="Creative|Precise|Evenly|Custom",
default="Creative")
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
"presence_penalty": 0.4, "max_tokens": 215})
prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced")

0
deepdoc/__init__.py Normal file
View File

View File

@ -1,4 +1,3 @@
import copy
import random
from .pdf_parser import HuParser as PdfParser
@ -10,7 +9,7 @@ import re
from nltk import word_tokenize
from rag.nlp import stemmer, huqie
from ..utils import num_tokens_from_string
from rag.utils import num_tokens_from_string
BULLET_PATTERN = [[
r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",

View File

@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
import os
import random
from functools import partial
import fitz
import requests
@ -15,6 +14,7 @@ from PIL import Image
import numpy as np
from api.db import ParserType
from deepdoc.visual import OCR, Recognizer
from rag.nlp import huqie
from collections import Counter
from copy import deepcopy
@ -26,13 +26,32 @@ logging.getLogger("pdfminer").setLevel(logging.WARNING)
class HuParser:
def __init__(self):
from paddleocr import PaddleOCR
logging.getLogger("ppocr").setLevel(logging.ERROR)
self.ocr = PaddleOCR(use_angle_cls=False, lang="ch")
self.ocr = OCR()
if not hasattr(self, "model_speciess"):
self.model_speciess = ParserType.GENERAL.value
self.layouter = partial(self.__remote_call, self.model_speciess)
self.tbl_det = partial(self.__remote_call, "table_component")
self.layout_labels = [
"_background_",
"Text",
"Title",
"Figure",
"Figure caption",
"Table",
"Table caption",
"Header",
"Footer",
"Reference",
"Equation",
]
self.tsr_labels = [
"table",
"table column",
"table row",
"table column header",
"table projected row header",
"table spanning cell",
]
self.layouter = Recognizer(self.layout_labels, "layout", "/data/newpeak/medical-gpt/res/ppdet/")
self.tbl_det = Recognizer(self.tsr_labels, "tsr", "/data/newpeak/medical-gpt/res/ppdet.tbl/")
self.updown_cnt_mdl = xgb.Booster()
if torch.cuda.is_available():
@ -56,7 +75,7 @@ class HuParser:
token = os.environ.get("INFINIFLOW_TOKEN")
if not url or not token:
logging.warning("INFINIFLOW_SERVER is not specified. To maximize the effectiveness, please visit https://github.com/infiniflow/ragflow, and sign in the our demo web site to get token. It's FREE! Using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN.")
return []
return [[] for _ in range(len(images))]
def convert_image_to_bytes(PILimage):
image = BytesIO()
@ -382,7 +401,7 @@ class HuParser:
return layouts
def __table_paddle(self, images):
def __table_tsr(self, images):
tbls = self.tbl_det(images, thr=0.5)
res = []
# align left&right for rows, align top&bottom for columns
@ -452,7 +471,7 @@ class HuParser:
assert len(self.page_images) == len(tbcnt) - 1
if not imgs:
return
recos = self.__table_paddle(imgs)
recos = self.__table_tsr(imgs)
tbcnt = np.cumsum(tbcnt)
for i in range(len(tbcnt) - 1): # for page
pg = []
@ -517,8 +536,8 @@ class HuParser:
b["H_right"] = spans[ii]["x1"]
b["SP"] = ii
def __ocr_paddle(self, pagenum, img, chars, ZM=3):
bxs = self.ocr.ocr(np.array(img), cls=True)[0]
def __ocr(self, pagenum, img, chars, ZM=3):
bxs = self.ocr(np.array(img))
if not bxs:
self.boxes.append([])
return
@ -557,11 +576,12 @@ class HuParser:
self.boxes.append(bxs)
def _layouts_paddle(self, ZM):
def _layouts_rec(self, ZM):
assert len(self.page_images) == len(self.boxes)
# Tag layout type
boxes = []
layouts = self.layouter(self.page_images)
#save_results(self.page_images, layouts, self.layout_labels, output_dir='output/', threshold=0.7)
assert len(self.page_images) == len(layouts)
for pn, lts in enumerate(layouts):
bxs = self.boxes[pn]
@ -1741,7 +1761,7 @@ class HuParser:
# else:
# self.page_cum_height.append(
# np.max([c["bottom"] for c in chars]))
self.__ocr_paddle(i + 1, img, chars, zoomin)
self.__ocr(i + 1, img, chars, zoomin)
if not self.is_english and not any([c for c in self.page_chars]) and self.boxes:
bxes = [b for bxs in self.boxes for b in bxs]
@ -1754,7 +1774,7 @@ class HuParser:
def __call__(self, fnm, need_image=True, zoomin=3, return_html=False):
self.__images__(fnm, zoomin)
self._layouts_paddle(zoomin)
self._layouts_rec(zoomin)
self._table_transformer_job(zoomin)
self._text_merge()
self._concat_downward()

View File

@ -0,0 +1,2 @@
from .ocr import OCR
from .recognizer import Recognizer

561
deepdoc/visual/ocr.py Normal file
View File

@ -0,0 +1,561 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import time
import os
from huggingface_hub import snapshot_download
from .operators import *
import numpy as np
import onnxruntime as ort
from api.utils.file_utils import get_project_base_directory
from .postprocess import build_post_process
from rag.settings import cron_logger
def transform(data, ops=None):
""" transform """
if ops is None:
ops = []
for op in ops:
data = op(data)
if data is None:
return None
return data
def create_operators(op_param_list, global_config=None):
"""
create operators based on the config
Args:
params(list): a dict list, used to create some operators
"""
assert isinstance(
op_param_list, list), ('operator config should be a list')
ops = []
for operator in op_param_list:
assert isinstance(operator,
dict) and len(operator) == 1, "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
if global_config is not None:
param.update(global_config)
op = eval(op_name)(**param)
ops.append(op)
return ops
def load_model(model_dir, nm):
model_file_path = os.path.join(model_dir, nm + ".onnx")
if not os.path.exists(model_file_path):
raise ValueError("not find model file path {}".format(
model_file_path))
sess = ort.InferenceSession(model_file_path)
return sess, sess.get_inputs()[0]
class TextRecognizer(object):
def __init__(self, model_dir):
self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
self.rec_batch_num = 16
postprocess_params = {
'name': 'CTCLabelDecode',
"character_dict_path": os.path.join(get_project_base_directory(), "rag/res", "ocr.res"),
"use_space_char": True
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor = load_model(model_dir, 'rec')
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[2]
imgW = int((imgH * max_wh_ratio))
w = self.input_tensor.shape[3:][0]
if isinstance(w, str):
pass
elif w is not None and w > 0:
imgW = w
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_vl(self, img, image_shape):
imgC, imgH, imgW = image_shape
img = img[:, :, ::-1] # bgr2rgb
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
return resized_image
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(
gsrm_slf_attn_bias1,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(
gsrm_slf_attn_bias2,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
encoder_word_pos = encoder_word_pos[np.newaxis, :]
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
norm_img = self.resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
self.srn_other_inputs(image_shape, num_heads, max_text_length)
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
encoder_word_pos = encoder_word_pos.astype(np.int64)
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
def resize_norm_img_sar(self, img, image_shape,
width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
valid_ratio = 1.0
# make sure new_width is an integral multiple of width_divisor.
width_divisor = int(1 / width_downsample_ratio)
# resize
ratio = w / float(h)
resize_w = math.ceil(imgH * ratio)
if resize_w % width_divisor != 0:
resize_w = round(resize_w / width_divisor) * width_divisor
if imgW_min is not None:
resize_w = max(imgW_min, resize_w)
if imgW_max is not None:
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
resized_image = cv2.resize(img, (resize_w, imgH))
resized_image = resized_image.astype('float32')
# norm
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
resize_shape = resized_image.shape
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
padding_im[:, :, 0:resize_w] = resized_image
pad_shape = padding_im.shape
return padding_im, resize_shape, pad_shape, valid_ratio
def resize_norm_img_spin(self, img):
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# return padding_im
img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
img = np.array(img, np.float32)
img = np.expand_dims(img, -1)
img = img.transpose((2, 0, 1))
mean = [127.5]
std = [127.5]
mean = np.array(mean, dtype=np.float32)
std = np.array(std, dtype=np.float32)
mean = np.float32(mean.reshape(1, -1))
stdinv = 1 / np.float32(std.reshape(1, -1))
img -= mean
img *= stdinv
return img
def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
return resized_image
def resize_norm_img_abinet(self, img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image / 255.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
resized_image = (
resized_image - mean[None, None, ...]) / std[None, None, ...]
resized_image = resized_image.transpose((2, 0, 1))
resized_image = resized_image.astype('float32')
return resized_image
def norm_img_can(self, img, image_shape):
img = cv2.cvtColor(
img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image
if self.rec_image_shape[0] == 1:
h, w = img.shape
_, imgH, imgW = self.rec_image_shape
if h < imgH or w < imgW:
padding_h = max(imgH - h, 0)
padding_w = max(imgW - w, 0)
img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
'constant',
constant_values=(255))
img = img_padded
img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w
img = img.astype('float32')
return img
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num
st = time.time()
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
imgC, imgH, imgW = self.rec_image_shape[:3]
max_wh_ratio = imgW / imgH
# max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
outputs = self.predictor.run(None, input_dict)
preds = outputs[0]
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
return rec_res, time.time() - st
class TextDetector(object):
def __init__(self, model_dir):
pre_process_list = [{
'DetResizeForTest': {
'limit_side_len': 960,
'limit_type': "max",
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
}
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image', 'shape']
}
}]
postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.6, "max_candidates": 1000,
"unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor = load_model(model_dir, 'det')
img_h, img_w = self.input_tensor.shape[2:]
if isinstance(img_h, str) or isinstance(img_w, str):
pass
elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
pre_process_list[0] = {
'DetResizeForTest': {
'image_shape': [img_h, img_w]
}
}
self.preprocess_op = create_operators(pre_process_list)
def order_points_clockwise(self, pts):
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
rect[2] = pts[np.argmax(s)]
tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
diff = np.diff(np.array(tmp), axis=1)
rect[1] = tmp[np.argmin(diff)]
rect[3] = tmp[np.argmax(diff)]
return rect
def clip_det_res(self, points, img_height, img_width):
for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points
def filter_tag_det_res(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if isinstance(box, list):
box = np.array(box)
box = self.order_points_clockwise(box)
box = self.clip_det_res(box, img_height, img_width)
rect_width = int(np.linalg.norm(box[0] - box[1]))
rect_height = int(np.linalg.norm(box[0] - box[3]))
if rect_width <= 3 or rect_height <= 3:
continue
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
if isinstance(box, list):
box = np.array(box)
box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
st = time.time()
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
input_dict = {}
input_dict[self.input_tensor.name] = img
outputs = self.predictor.run(None, input_dict)
post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
dt_boxes = post_result[0]['points']
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
return dt_boxes, time.time() - st
class OCR(object):
def __init__(self, model_dir=None):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
if not model_dir:
model_dir = snapshot_download(repo_id="InfiniFlow/ocr")
self.text_detector = TextDetector(model_dir)
self.text_recognizer = TextRecognizer(model_dir)
self.drop_score = 0.5
self.crop_image_res_index = 0
def get_rotate_crop_image(self, img, points):
'''
img_height, img_width = img.shape[0:2]
left = int(np.min(points[:, 0]))
right = int(np.max(points[:, 0]))
top = int(np.min(points[:, 1]))
bottom = int(np.max(points[:, 1]))
img_crop = img[top:bottom, left:right, :].copy()
points[:, 0] = points[:, 0] - left
points[:, 1] = points[:, 1] - top
'''
assert len(points) == 4, "shape of points must be 4*2"
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3])))
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
dst_img = cv2.warpPerspective(
img,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
def sorted_boxes(self, dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
for j in range(i, -1, -1):
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
tmp = _boxes[j]
_boxes[j] = _boxes[j + 1]
_boxes[j + 1] = tmp
else:
break
return _boxes
def __call__(self, img, cls=True):
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
if img is None:
return None, None, time_dict
start = time.time()
ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img)
time_dict['det'] = elapse
if dt_boxes is None:
end = time.time()
time_dict['all'] = end - start
return None, None, time_dict
else:
cron_logger.debug("dt_boxes num : {}, elapsed : {}".format(
len(dt_boxes), elapse))
img_crop_list = []
dt_boxes = self.sorted_boxes(dt_boxes)
for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop)
rec_res, elapse = self.text_recognizer(img_crop_list)
time_dict['rec'] = elapse
cron_logger.debug("rec_res num : {}, elapsed : {}".format(
len(rec_res), elapse))
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
end = time.time()
time_dict['all'] = end - start
#for bno in range(len(img_crop_list)):
# print(f"{bno}, {rec_res[bno]}")
return list(zip([a.tolist() for a in filter_boxes], filter_rec_res))

6623
deepdoc/visual/ocr.res Normal file

File diff suppressed because it is too large Load Diff

710
deepdoc/visual/operators.py Normal file
View File

@ -0,0 +1,710 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import six
import cv2
import numpy as np
import math
from PIL import Image
class DecodeImage(object):
""" decode image """
def __init__(self,
img_mode='RGB',
channel_first=False,
ignore_orientation=False,
**kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
self.ignore_orientation = ignore_orientation
def __call__(self, data):
img = data['image']
if six.PY2:
assert isinstance(img, str) and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert isinstance(img, bytes) and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
if self.ignore_orientation:
img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
cv2.IMREAD_COLOR)
else:
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class StandardizeImage(object):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
norm_type (str): type in ['mean_std', 'none']
"""
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.norm_type = norm_type
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
if self.is_scale:
scale = 1.0 / 255.0
im *= scale
if self.norm_type == 'mean_std':
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im -= mean
im /= std
return im, im_info
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
return data
class ToCHWImage(object):
""" convert hwc image to chw image
"""
def __init__(self, **kwargs):
pass
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
data['image'] = img.transpose((2, 0, 1))
return data
class Fasttext(object):
def __init__(self, path="None", **kwargs):
import fasttext
self.fast_model = fasttext.load_model(path)
def __call__(self, data):
label = data['label']
fast_label = self.fast_model[label]
data['fast_label'] = fast_label
return data
class KeepKeys(object):
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
def __call__(self, data):
data_list = []
for key in self.keep_keys:
data_list.append(data[key])
return data_list
class Pad(object):
def __init__(self, size=None, size_div=32, **kwargs):
if size is not None and not isinstance(size, (int, list, tuple)):
raise TypeError("Type of target_size is invalid. Now is {}".format(
type(size)))
if isinstance(size, int):
size = [size, size]
self.size = size
self.size_div = size_div
def __call__(self, data):
img = data['image']
img_h, img_w = img.shape[0], img.shape[1]
if self.size:
resize_h2, resize_w2 = self.size
assert (
img_h < resize_h2 and img_w < resize_w2
), '(h, w) of target size should be greater than (img_h, img_w)'
else:
resize_h2 = max(
int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
self.size_div)
resize_w2 = max(
int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
self.size_div)
img = cv2.copyMakeBorder(
img,
0,
resize_h2 - img_h,
0,
resize_w2 - img_w,
cv2.BORDER_CONSTANT,
value=0)
data['image'] = img
return data
class LinearResize(object):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
self.keep_ratio = keep_ratio
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array(
[im_scale_y, im_scale_x]).astype('float32')
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
class Resize(object):
def __init__(self, size=(640, 640), **kwargs):
self.size = size
def resize_image(self, img):
resize_h, resize_w = self.size
ori_h, ori_w = img.shape[:2] # (h, w, c)
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
return img, [ratio_h, ratio_w]
def __call__(self, data):
img = data['image']
if 'polys' in data:
text_polys = data['polys']
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
if 'polys' in data:
new_boxes = []
for box in text_polys:
new_box = []
for cord in box:
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
new_boxes.append(new_box)
data['polys'] = np.array(new_boxes, dtype=np.float32)
data['image'] = img_resize
return data
class DetResizeForTest(object):
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
self.keep_ratio = False
if 'image_shape' in kwargs:
self.image_shape = kwargs['image_shape']
self.resize_type = 1
if 'keep_ratio' in kwargs:
self.keep_ratio = kwargs['keep_ratio']
elif 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min')
elif 'resize_long' in kwargs:
self.resize_type = 2
self.resize_long = kwargs.get('resize_long', 960)
else:
self.limit_side_len = 736
self.limit_type = 'min'
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if sum([src_h, src_w]) < 64:
img = self.image_padding(img)
if self.resize_type == 0:
# img, shape = self.resize_image_type0(img)
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
elif self.resize_type == 2:
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
else:
# img, shape = self.resize_image_type1(img)
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
data['image'] = img
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def image_padding(self, im, value=0):
h, w, c = im.shape
im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
im_pad[:h, :w, :] = im
return im_pad
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
if self.keep_ratio is True:
resize_w = ori_w * resize_h / ori_h
N = math.ceil(resize_w / 32)
resize_w = N * 32
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
# return img, np.array([ori_h, ori_w])
return img, [ratio_h, ratio_w]
def resize_image_type0(self, img):
"""
resize image to a size multiple of 32 which is required by the network
args:
img(array): array with shape [h, w, c]
return(tuple):
img, (ratio_h, ratio_w)
"""
limit_side_len = self.limit_side_len
h, w, c = img.shape
# limit the max side
if self.limit_type == 'max':
if max(h, w) > limit_side_len:
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'min':
if min(h, w) < limit_side_len:
if h < w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
else:
ratio = 1.
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h, w)
else:
raise Exception('not support limit type, image ')
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = max(int(round(resize_h / 32) * 32), 32)
resize_w = max(int(round(resize_w / 32) * 32), 32)
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
img = cv2.resize(img, (int(resize_w), int(resize_h)))
except BaseException:
print(img.shape, resize_w, resize_h)
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
def resize_image_type2(self, img):
h, w, _ = img.shape
resize_w = w
resize_h = h
if resize_h > resize_w:
ratio = float(self.resize_long) / resize_h
else:
ratio = float(self.resize_long) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
img = cv2.resize(img, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return img, [ratio_h, ratio_w]
class E2EResizeForTest(object):
def __init__(self, **kwargs):
super(E2EResizeForTest, self).__init__()
self.max_side_len = kwargs['max_side_len']
self.valid_set = kwargs['valid_set']
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.valid_set == 'totaltext':
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
img, max_side_len=self.max_side_len)
else:
im_resized, (ratio_h, ratio_w) = self.resize_image(
img, max_side_len=self.max_side_len)
data['image'] = im_resized
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_for_totaltext(self, im, max_side_len=512):
h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image(self, im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
class KieResize(object):
def __init__(self, **kwargs):
super(KieResize, self).__init__()
self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
'img_scale'][1]
def __call__(self, data):
img = data['image']
points = data['points']
src_h, src_w, _ = img.shape
im_resized, scale_factor, [ratio_h, ratio_w
], [new_h, new_w] = self.resize_image(img)
resize_points = self.resize_boxes(img, points, scale_factor)
data['ori_image'] = img
data['ori_boxes'] = points
data['points'] = resize_points
data['image'] = im_resized
data['shape'] = np.array([new_h, new_w])
return data
def resize_image(self, img):
norm_img = np.zeros([1024, 1024, 3], dtype='float32')
scale = [512, 1024]
h, w = img.shape[:2]
max_long_edge = max(scale)
max_short_edge = min(scale)
scale_factor = min(max_long_edge / max(h, w),
max_short_edge / min(h, w))
resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
scale_factor) + 0.5)
max_stride = 32
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(img, (resize_w, resize_h))
new_h, new_w = im.shape[:2]
w_scale = new_w / w
h_scale = new_h / h
scale_factor = np.array(
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
norm_img[:new_h, :new_w, :] = im
return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
def resize_boxes(self, im, points, scale_factor):
points = points * scale_factor
img_shape = im.shape[:2]
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
return points
class SRResize(object):
def __init__(self,
imgH=32,
imgW=128,
down_sample_scale=4,
keep_ratio=False,
min_ratio=1,
mask=False,
infer_mode=False,
**kwargs):
self.imgH = imgH
self.imgW = imgW
self.keep_ratio = keep_ratio
self.min_ratio = min_ratio
self.down_sample_scale = down_sample_scale
self.mask = mask
self.infer_mode = infer_mode
def __call__(self, data):
imgH = self.imgH
imgW = self.imgW
images_lr = data["image_lr"]
transform2 = ResizeNormalize(
(imgW // self.down_sample_scale, imgH // self.down_sample_scale))
images_lr = transform2(images_lr)
data["img_lr"] = images_lr
if self.infer_mode:
return data
images_HR = data["image_hr"]
label_strs = data["label"]
transform = ResizeNormalize((imgW, imgH))
images_HR = transform(images_HR)
data["img_hr"] = images_HR
return data
class ResizeNormalize(object):
def __init__(self, size, interpolation=Image.BICUBIC):
self.size = size
self.interpolation = interpolation
def __call__(self, img):
img = img.resize(self.size, self.interpolation)
img_numpy = np.array(img).astype("float32")
img_numpy = img_numpy.transpose((2, 0, 1)) / 255
return img_numpy
class GrayImageChannelFormat(object):
"""
format gray scale image's channel: (3,h,w) -> (1,h,w)
Args:
inverse: inverse gray image
"""
def __init__(self, inverse=False, **kwargs):
self.inverse = inverse
def __call__(self, data):
img = data['image']
img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img_expanded = np.expand_dims(img_single_channel, 0)
if self.inverse:
data['image'] = np.abs(img_expanded - 1)
else:
data['image'] = img_expanded
data['src_image'] = img
return data
class Permute(object):
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(self, ):
super(Permute, self).__init__()
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.transpose((2, 0, 1)).copy()
return im, im_info
class PadStride(object):
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride <= 0:
return im, im_info
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
return padding_im, im_info
def decode_image(im_file, im_info):
"""read rgb image
Args:
im_file (str|np.ndarray): input can be image path or np.ndarray
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
if isinstance(im_file, str):
with open(im_file, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
else:
im = im_file
im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
return im, im_info
def preprocess(im, preprocess_ops):
# process image by preprocess_ops
im_info = {
'scale_factor': np.array(
[1., 1.], dtype=np.float32),
'im_shape': None,
}
im, im_info = decode_image(im, im_info)
for operator in preprocess_ops:
im, im_info = operator(im, im_info)
return im, im_info

View File

@ -0,0 +1,354 @@
import copy
import numpy as np
import cv2
import paddle
from shapely.geometry import Polygon
import pyclipper
def build_post_process(config, global_config=None):
support_dict = ['DBPostProcess', 'CTCLabelDecode']
config = copy.deepcopy(config)
module_name = config.pop('name')
if module_name == "None":
return
if global_config is not None:
config.update(global_config)
assert module_name in support_dict, Exception(
'post process only support {}'.format(support_dict))
module_class = eval(module_name)(**config)
return module_class
class DBPostProcess(object):
"""
The post process for Differentiable Binarization (DB).
"""
def __init__(self,
thresh=0.3,
box_thresh=0.7,
max_candidates=1000,
unclip_ratio=2.0,
use_dilation=False,
score_mode="fast",
box_type='quad',
**kwargs):
self.thresh = thresh
self.box_thresh = box_thresh
self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio
self.min_size = 3
self.score_mode = score_mode
self.box_type = box_type
assert score_mode in [
"slow", "fast"
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]])
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
boxes = []
scores = []
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours[:self.max_candidates]:
epsilon = 0.002 * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape((-1, 2))
if points.shape[0] < 4:
continue
score = self.box_score_fast(pred, points.reshape(-1, 2))
if self.box_thresh > score:
continue
if points.shape[0] > 2:
box = self.unclip(points, self.unclip_ratio)
if len(box) > 1:
continue
else:
continue
box = box.reshape(-1, 2)
_, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.tolist())
scores.append(score)
return boxes, scores
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
cv2.CHAIN_APPROX_SIMPLE)
if len(outs) == 3:
img, contours, _ = outs[0], outs[1], outs[2]
elif len(outs) == 2:
contours, _ = outs[0], outs[1]
num_contours = min(len(contours), self.max_candidates)
boxes = []
scores = []
for index in range(num_contours):
contour = contours[index]
points, sside = self.get_mini_boxes(contour)
if sside < self.min_size:
continue
points = np.array(points)
if self.score_mode == "fast":
score = self.box_score_fast(pred, points.reshape(-1, 2))
else:
score = self.box_score_slow(pred, contour)
if self.box_thresh > score:
continue
box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.astype("int32"))
scores.append(score)
return np.array(boxes, dtype="int32"), scores
def unclip(self, box, unclip_ratio):
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
def get_mini_boxes(self, contour):
bounding_box = cv2.minAreaRect(contour)
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_1 = 0
index_4 = 1
else:
index_1 = 1
index_4 = 0
if points[3][1] > points[2][1]:
index_2 = 2
index_3 = 3
else:
index_2 = 3
index_3 = 2
box = [
points[index_1], points[index_2], points[index_3], points[index_4]
]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
'''
box_score_fast: use bbox mean score as the mean score
'''
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
'''
box_score_slow: use polyon mean score as the mean score
'''
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
contour[:, 0] = contour[:, 0] - xmin
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, outs_dict, shape_list):
pred = outs_dict['maps']
if isinstance(pred, paddle.Tensor):
pred = pred.numpy()
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh
boxes_batch = []
for batch_index in range(pred.shape[0]):
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
if self.dilation_kernel is not None:
mask = cv2.dilate(
np.array(segmentation[batch_index]).astype(np.uint8),
self.dilation_kernel)
else:
mask = segmentation[batch_index]
if self.box_type == 'poly':
boxes, scores = self.polygons_from_bitmap(pred[batch_index],
mask, src_w, src_h)
elif self.box_type == 'quad':
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
src_w, src_h)
else:
raise ValueError(
"box_type can only be one of ['quad', 'poly']")
boxes_batch.append({'points': boxes})
return boxes_batch
class BaseRecLabelDecode(object):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False):
self.beg_str = "sos"
self.end_str = "eos"
self.reverse = False
self.character_str = []
if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
else:
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
self.character_str.append(line)
if use_space_char:
self.character_str.append(" ")
dict_character = list(self.character_str)
if 'arabic' in character_dict_path:
self.reverse = True
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
def pred_reverse(self, pred):
pred_re = []
c_current = ''
for c in pred:
if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
if c_current != '':
pred_re.append(c_current)
pred_re.append(c)
c_current = ''
else:
c_current += c
if c_current != '':
pred_re.append(c_current)
return ''.join(pred_re[::-1])
def add_special_char(self, dict_character):
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
if is_remove_duplicate:
selection[1:] = text_index[batch_idx][1:] != text_index[
batch_idx][:-1]
for ignored_token in ignored_tokens:
selection &= text_index[batch_idx] != ignored_token
char_list = [
self.character[text_id]
for text_id in text_index[batch_idx][selection]
]
if text_prob is not None:
conf_list = text_prob[batch_idx][selection]
else:
conf_list = [1] * len(selection)
if len(conf_list) == 0:
conf_list = [0]
text = ''.join(char_list)
if self.reverse: # for arabic rec
text = self.pred_reverse(text)
result_list.append((text, np.mean(conf_list).tolist()))
return result_list
def get_ignored_tokens(self):
return [0] # for ctc blank
class CTCLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(CTCLabelDecode, self).__init__(character_dict_path,
use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
if label is None:
return text
label = self.decode(label)
return text, label
def add_special_char(self, dict_character):
dict_character = ['blank'] + dict_character
return dict_character

View File

@ -0,0 +1,139 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import onnxruntime as ort
from huggingface_hub import snapshot_download
from .operators import *
from rag.settings import cron_logger
class Recognizer(object):
def __init__(self, label_list, task_name, model_dir=None):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
if not model_dir:
model_dir = snapshot_download(repo_id="InfiniFlow/ocr")
model_file_path = os.path.join(model_dir, task_name + ".onnx")
if not os.path.exists(model_file_path):
raise ValueError("not find model file path {}".format(
model_file_path))
if ort.get_device() == "GPU":
self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
else:
self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
self.label_list = label_list
def create_inputs(self, imgs, im_info):
"""generate input for different model type
Args:
imgs (list(numpy)): list of images (np.ndarray)
im_info (list(dict)): list of image info
Returns:
inputs (dict): input of model
"""
inputs = {}
im_shape = []
scale_factor = []
if len(imgs) == 1:
inputs['image'] = np.array((imgs[0],)).astype('float32')
inputs['im_shape'] = np.array(
(im_info[0]['im_shape'],)).astype('float32')
inputs['scale_factor'] = np.array(
(im_info[0]['scale_factor'],)).astype('float32')
return inputs
for e in im_info:
im_shape.append(np.array((e['im_shape'],)).astype('float32'))
scale_factor.append(np.array((e['scale_factor'],)).astype('float32'))
inputs['im_shape'] = np.concatenate(im_shape, axis=0)
inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)
imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
max_shape_h = max([e[0] for e in imgs_shape])
max_shape_w = max([e[1] for e in imgs_shape])
padding_imgs = []
for img in imgs:
im_c, im_h, im_w = img.shape[:]
padding_im = np.zeros(
(im_c, max_shape_h, max_shape_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = img
padding_imgs.append(padding_im)
inputs['image'] = np.stack(padding_imgs, axis=0)
return inputs
def preprocess(self, image_list):
preprocess_ops = []
for op_info in [
{'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
{'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
{'type': 'Permute'},
{'stride': 32, 'type': 'PadStride'}
]:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
preprocess_ops.append(eval(op_type)(**new_op_info))
inputs = []
for im_path in image_list:
im, im_info = preprocess(im_path, preprocess_ops)
inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
return inputs
def __call__(self, image_list, thr=0.7, batch_size=16):
res = []
imgs = []
for i in range(len(image_list)):
if not isinstance(image_list[i], np.ndarray):
imgs.append(np.array(image_list[i]))
else: imgs.append(image_list[i])
batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
for i in range(batch_loop_cnt):
start_index = i * batch_size
end_index = min((i + 1) * batch_size, len(imgs))
batch_image_list = imgs[start_index:end_index]
inputs = self.preprocess(batch_image_list)
for ins in inputs:
bb = []
for b in self.ort_sess.run(None, ins)[0]:
clsid, bbox, score = int(b[0]), b[2:], b[1]
if score < thr:
continue
if clsid >= len(self.label_list):
cron_logger.warning(f"bad category id")
continue
bb.append({
"type": self.label_list[clsid].lower(),
"bbox": [float(t) for t in bbox.tolist()],
"score": float(score)
})
res.append(bb)
#seeit.save_results(image_list, res, self.label_list, threshold=thr)
return res

83
deepdoc/visual/seeit.py Normal file
View File

@ -0,0 +1,83 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import PIL
from PIL import ImageDraw
def save_results(image_list, results, labels, output_dir='output/', threshold=0.5):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
for idx, im in enumerate(image_list):
im = draw_box(im, results[idx], labels, threshold=threshold)
out_path = os.path.join(output_dir, f"{idx}.jpg")
im.save(out_path, quality=95)
print("save result to: " + out_path)
def draw_box(im, result, lables, threshold=0.5):
draw_thickness = min(im.size) // 320
draw = ImageDraw.Draw(im)
color_list = get_color_map_list(len(lables))
clsid2color = {n.lower():color_list[i] for i,n in enumerate(lables)}
result = [r for r in result if r["score"] >= threshold]
for dt in result:
color = tuple(clsid2color[dt["type"]])
xmin, ymin, xmax, ymax = dt["bbox"]
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=draw_thickness,
fill=color)
# draw label
text = "{} {:.4f}".format(dt["type"], dt["score"])
tw, th = imagedraw_textsize_c(draw, text)
draw.rectangle(
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
return im
def get_color_map_list(num_classes):
"""
Args:
num_classes (int): number of class
Returns:
color_map (list): RGB color list
"""
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
return color_map
def imagedraw_textsize_c(draw, text):
if int(PIL.__version__.split('.')[0]) < 10:
tw, th = draw.textsize(text)
else:
left, top, right, bottom = draw.textbbox((0, 0), text)
tw, th = right - left, bottom - top
return tw, th

View File

@ -1,15 +1,24 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import random
import re
import numpy as np
from rag.parser import bullets_category, BULLET_PATTERN, is_english, tokenize, remove_contents_table, \
from deepdoc.parser import bullets_category, is_english, tokenize, remove_contents_table, \
hierarchical_merge, make_colon_as_title, naive_merge, random_choices
from rag.nlp import huqie
from rag.parser.docx_parser import HuDocxParser
from rag.parser.pdf_parser import HuParser
from deepdoc.parser import PdfParser, DocxParser
class Pdf(HuParser):
class Pdf(PdfParser):
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
self.__images__(
@ -21,7 +30,7 @@ class Pdf(HuParser):
from timeit import default_timer as timer
start = timer()
self._layouts_paddle(zoomin)
self._layouts_rec(zoomin)
callback(0.47, "Layout analysis finished")
print("paddle layouts:", timer() - start)
self._table_transformer_job(zoomin)
@ -53,7 +62,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
sections,tbls = [], []
if re.search(r"\.docx?$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
doc_parser = HuDocxParser()
doc_parser = DocxParser()
# TODO: table of contents need to be removed
sections, tbls = doc_parser(binary if binary else filename, from_page=from_page, to_page=to_page)
remove_contents_table(sections, eng=is_english(random_choices([t for t,_ in sections], k=200)))

View File

@ -1,16 +1,27 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import re
from io import BytesIO
from docx import Document
from rag.parser import bullets_category, is_english, tokenize, remove_contents_table, hierarchical_merge, \
from deepdoc.parser import bullets_category, is_english, tokenize, remove_contents_table, hierarchical_merge, \
make_colon_as_title
from rag.nlp import huqie
from rag.parser.docx_parser import HuDocxParser
from rag.parser.pdf_parser import HuParser
from deepdoc.parser import PdfParser, DocxParser
from rag.settings import cron_logger
class Docx(HuDocxParser):
class Docx(DocxParser):
def __init__(self):
pass
@ -35,7 +46,7 @@ class Docx(HuDocxParser):
return [l for l in lines if l]
class Pdf(HuParser):
class Pdf(PdfParser):
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
self.__images__(
@ -47,7 +58,7 @@ class Pdf(HuParser):
from timeit import default_timer as timer
start = timer()
self._layouts_paddle(zoomin)
self._layouts_rec(zoomin)
callback(0.77, "Layout analysis finished")
cron_logger.info("paddle layouts:".format((timer()-start)/(self.total_page+0.1)))
self._naive_vertical_merge()

View File

@ -1,12 +1,12 @@
import copy
import re
from rag.parser import tokenize
from deepdoc.parser import tokenize
from rag.nlp import huqie
from rag.parser.pdf_parser import HuParser
from deepdoc.parser import PdfParser
from rag.utils import num_tokens_from_string
class Pdf(HuParser):
class Pdf(PdfParser):
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
self.__images__(
@ -18,7 +18,7 @@ class Pdf(HuParser):
from timeit import default_timer as timer
start = timer()
self._layouts_paddle(zoomin)
self._layouts_rec(zoomin)
callback(0.5, "Layout analysis finished.")
print("paddle layouts:", timer() - start)
self._table_transformer_job(zoomin)

View File

@ -1,13 +1,25 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import re
from rag.app import laws
from rag.parser import is_english, tokenize, naive_merge
from deepdoc.parser import is_english, tokenize, naive_merge
from rag.nlp import huqie
from rag.parser.pdf_parser import HuParser
from deepdoc.parser import PdfParser
from rag.settings import cron_logger
class Pdf(HuParser):
class Pdf(PdfParser):
def __call__(self, filename, binary=None, from_page=0,
to_page=100000, zoomin=3, callback=None):
self.__images__(
@ -19,7 +31,7 @@ class Pdf(HuParser):
from timeit import default_timer as timer
start = timer()
self._layouts_paddle(zoomin)
self._layouts_rec(zoomin)
callback(0.77, "Layout analysis finished")
cron_logger.info("paddle layouts:".format((timer() - start) / (self.total_page + 0.1)))
self._naive_vertical_merge()

View File

@ -1,16 +1,28 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import re
from collections import Counter
from api.db import ParserType
from rag.parser import tokenize
from deepdoc.parser import tokenize
from rag.nlp import huqie
from rag.parser.pdf_parser import HuParser
from deepdoc.parser import PdfParser
import numpy as np
from rag.utils import num_tokens_from_string
class Pdf(HuParser):
class Pdf(PdfParser):
def __init__(self):
self.model_speciess = ParserType.PAPER.value
super().__init__()
@ -26,7 +38,7 @@ class Pdf(HuParser):
from timeit import default_timer as timer
start = timer()
self._layouts_paddle(zoomin)
self._layouts_rec(zoomin)
callback(0.47, "Layout analysis finished")
print("paddle layouts:", timer() - start)
self._table_transformer_job(zoomin)

View File

@ -1,11 +1,22 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import re
from io import BytesIO
from pptx import Presentation
from rag.parser import tokenize, is_english
from deepdoc.parser import tokenize, is_english
from rag.nlp import huqie
from rag.parser.pdf_parser import HuParser
from deepdoc.parser import PdfParser
class Ppt(object):
@ -58,7 +69,7 @@ class Ppt(object):
return [(txts[i], imgs[i]) for i in range(len(txts))]
class Pdf(HuParser):
class Pdf(PdfParser):
def __init__(self):
super().__init__()
@ -74,7 +85,7 @@ class Pdf(HuParser):
assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(len(self.boxes), len(self.page_images))
res = []
#################### More precisely ###################
# self._layouts_paddle(zoomin)
# self._layouts_rec(zoomin)
# self._text_merge()
# pages = {}
# for b in self.boxes:

View File

@ -1,13 +1,25 @@
import random
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from io import BytesIO
from nltk import word_tokenize
from openpyxl import load_workbook
from rag.parser import is_english, random_choices
from deepdoc.parser import is_english, random_choices
from rag.nlp import huqie, stemmer
from deepdoc.parser import ExcelParser
class Excel(object):
class Excel(ExcelParser):
def __call__(self, fnm, binary=None, callback=None):
if not binary:
wb = load_workbook(fnm)

View File

@ -1,59 +1,82 @@
import copy
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import datetime
import json
import os
import re
import pandas as pd
import requests
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.settings import stat_logger
from rag.nlp import huqie
from deepdoc.parser.resume import refactor
from deepdoc.parser.resume import step_one, step_two
from rag.settings import cron_logger
from rag.utils import rmSpace
forbidden_select_fields4resume = [
"name_pinyin_kwd", "edu_first_fea_kwd", "degree_kwd", "sch_rank_kwd", "edu_fea_kwd"
]
def remote_call(filename, binary):
q = {
"header": {
"uid": 1,
"user": "kevinhu",
"log_id": filename
},
"request": {
"p": {
"request_id": "1",
"encrypt_type": "base64",
"filename": filename,
"langtype": '',
"fileori": base64.b64encode(binary.stream.read()).decode('utf-8')
},
"c": "resume_parse_module",
"m": "resume_parse"
}
}
for _ in range(3):
try:
resume = requests.post("http://127.0.0.1:61670/tog", data=json.dumps(q))
resume = resume.json()["response"]["results"]
resume = refactor(resume)
for k in ["education", "work", "project", "training", "skill", "certificate", "language"]:
if not resume.get(k) and k in resume: del resume[k]
resume = step_one.refactor(pd.DataFrame([{"resume_content": json.dumps(resume), "tob_resume_id": "x",
"updated_at": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]))
resume = step_two.parse(resume)
return resume
except Exception as e:
cron_logger.error("Resume parser error: "+str(e))
return {}
def chunk(filename, binary=None, callback=None, **kwargs):
"""
The supported file formats are pdf, docx and txt.
To maximize the effectiveness, parse the resume correctly,
please visit https://github.com/infiniflow/ragflow, and sign in the our demo web-site
to get token. It's FREE!
Set INFINIFLOW_SERVER and INFINIFLOW_TOKEN in '.env' file or
using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN in docker container.
To maximize the effectiveness, parse the resume correctly, please concat us: https://github.com/infiniflow/ragflow
"""
if not re.search(r"\.(pdf|doc|docx|txt)$", filename, flags=re.IGNORECASE):
raise NotImplementedError("file type not supported yet(pdf supported)")
url = os.environ.get("INFINIFLOW_SERVER")
token = os.environ.get("INFINIFLOW_TOKEN")
if not url or not token:
stat_logger.warning(
"INFINIFLOW_SERVER is not specified. To maximize the effectiveness, please visit https://github.com/infiniflow/ragflow, and sign in the our demo web site to get token. It's FREE! Using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN.")
return []
if not binary:
with open(filename, "rb") as f:
binary = f.read()
def remote_call():
nonlocal filename, binary
for _ in range(3):
try:
res = requests.post(url + "/v1/layout/resume/", files=[(filename, binary)],
headers={"Authorization": token}, timeout=180)
res = res.json()
if res["retcode"] != 0:
raise RuntimeError(res["retmsg"])
return res["data"]
except RuntimeError as e:
raise e
except Exception as e:
cron_logger.error("resume parsing:" + str(e))
callback(0.2, "Resume parsing is going on...")
resume = remote_call()
resume = remote_call(filename, binary)
if len(resume.keys()) < 7:
callback(-1, "Resume is not successfully parsed.")
return []

View File

@ -1,3 +1,15 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import re
from io import BytesIO
@ -8,11 +20,12 @@ from openpyxl import load_workbook
from dateutil.parser import parse as datetime_parse
from api.db.services.knowledgebase_service import KnowledgebaseService
from rag.parser import is_english, tokenize
from rag.nlp import huqie, stemmer
from deepdoc.parser import is_english, tokenize
from rag.nlp import huqie
from deepdoc.parser import ExcelParser
class Excel(object):
class Excel(ExcelParser):
def __call__(self, fnm, binary=None, callback=None):
if not binary:
wb = load_workbook(fnm)

View File

@ -1,3 +1,15 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import os
import copy
@ -443,13 +455,13 @@ if __name__ == "__main__":
import sys
sys.path.append(os.path.dirname(__file__) + "/../")
if sys.argv[1].split(".")[-1].lower() == "pdf":
from parser import PdfParser
from deepdoc.parser import PdfParser
ckr = PdfChunker(PdfParser())
if sys.argv[1].split(".")[-1].lower().find("doc") >= 0:
from parser import DocxParser
from deepdoc.parser import DocxParser
ckr = DocxChunker(DocxParser())
if sys.argv[1].split(".")[-1].lower().find("xlsx") >= 0:
from parser import ExcelParser
from deepdoc.parser import ExcelParser
ckr = ExcelChunker(ExcelParser())
# ckr.html(sys.argv[1])

View File

@ -21,7 +21,7 @@ from datetime import datetime
from api.db.db_models import Task
from api.db.db_utils import bulk_insert_into_db
from api.db.services.task_service import TaskService
from rag.parser.pdf_parser import HuParser
from deepdoc.parser import HuParser
from rag.settings import cron_logger
from rag.utils import MINIO
from rag.utils import findMaxTm