rename vision, add layour and tsr recognizer (#70)

* rename vision, add layour and tsr recognizer

* trivial fixing
This commit is contained in:
KevinHuSh 2024-02-22 19:11:37 +08:00 committed by GitHub
parent 5a0f1d2b84
commit d32322c081
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 1055 additions and 1039 deletions

View File

@ -34,7 +34,6 @@ from rag.utils import num_tokens_from_string, encoder, rmSpace
@manager.route('/set', methods=['POST'])
@login_required
@validate_request("dialog_id")
def set_conversation():
req = request.json
conv_id = req.get("conversation_id")
@ -145,7 +144,7 @@ def message_fit_in(msg, max_length=4000):
@manager.route('/completion', methods=['POST'])
@login_required
@validate_request("dialog_id", "messages")
@validate_request("conversation_id", "messages")
def completion():
req = request.json
msg = []
@ -154,12 +153,20 @@ def completion():
if m["role"] == "assistant" and not msg: continue
msg.append({"role": m["role"], "content": m["content"]})
try:
e, dia = DialogService.get_by_id(req["dialog_id"])
e, conv = ConversationService.get_by_id(req["conversation_id"])
if not e:
return get_data_error_result(retmsg="Conversation not found!")
conv.message.append(msg[-1])
e, dia = DialogService.get_by_id(conv.dialog_id)
if not e:
return get_data_error_result(retmsg="Dialog not found!")
del req["dialog_id"]
del req["conversation_id"]
del req["messages"]
return get_json_result(data=chat(dia, msg, **req))
ans = chat(dia, msg, **req)
conv.reference.append(ans["reference"])
conv.message.append({"role": "assistant", "content": ans["answer"]})
ConversationService.update_by_id(conv.id, conv.to_dict())
return get_json_result(data=ans)
except Exception as e:
return server_error_response(e)
@ -194,8 +201,8 @@ def chat(dialog, messages, **kwargs):
dialog.vector_similarity_weight, top=1024, aggs=False)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
if not knowledges and prompt_config["empty_response"]:
return {"answer": prompt_config["empty_response"], "retrieval": kbinfos}
if not knowledges and prompt_config.get("empty_response"):
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
kwargs["knowledge"] = "\n".join(knowledges)
gen_conf = dialog.llm_setting
@ -205,7 +212,8 @@ def chat(dialog, messages, **kwargs):
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
answer = retrievaler.insert_citations(answer,
if knowledges:
answer = retrievaler.insert_citations(answer,
[ck["content_ltks"] for ck in kbinfos["chunks"]],
[ck["vector"] for ck in kbinfos["chunks"]],
embd_mdl,
@ -213,7 +221,7 @@ def chat(dialog, messages, **kwargs):
vtweight=dialog.vector_similarity_weight)
for c in kbinfos["chunks"]:
if c.get("vector"): del c["vector"]
return {"answer": answer, "retrieval": kbinfos}
return {"answer": answer, "reference": kbinfos}
def use_sql(question, field_map, tenant_id, chat_mdl):

View File

@ -94,11 +94,11 @@ def list():
model_type = request.args.get("model_type")
try:
objs = TenantLLMService.query(tenant_id=current_user.id)
mdlnms = set([o.to_dict()["llm_name"] for o in objs if o.api_key])
facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
llms = LLMService.get_all()
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
for m in llms:
m["available"] = m["llm_name"] in mdlnms
m["available"] = m["fid"] in facts
res = {}
for m in llms:

View File

@ -500,7 +500,7 @@ class Document(DataBaseModel):
token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
progress = FloatField(default=0)
progress_msg = CharField(max_length=512, null=True, help_text="process message", default="")
progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
process_begin_at = DateTimeField(null=True)
process_duation = FloatField(default=0)
run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0")
@ -518,7 +518,7 @@ class Task(DataBaseModel):
begin_at = DateTimeField(null=True)
process_duation = FloatField(default=0)
progress = FloatField(default=0)
progress_msg = CharField(max_length=255, null=True, help_text="process message", default="")
progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
class Dialog(DataBaseModel):
@ -561,6 +561,7 @@ class Conversation(DataBaseModel):
dialog_id = CharField(max_length=32, null=False, index=True)
name = CharField(max_length=255, null=True, help_text="converastion name")
message = JSONField(null=True)
reference = JSONField(null=True, default=[])
class Meta:
db_table = "conversation"

View File

@ -75,7 +75,7 @@ class TenantLLMService(CommonService):
model_config = cls.get_api_key(tenant_id, mdlnm)
if not model_config:
raise LookupError("Model({}) not found".format(mdlnm))
raise LookupError("Model({}) not authorized".format(mdlnm))
model_config = model_config.to_dict()
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel:

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,4 @@
from .ocr import OCR
from .recognizer import Recognizer
from .layout_recognizer import LayoutRecognizer
from .table_structure_recognizer import TableStructureRecognizer

View File

@ -0,0 +1,119 @@
import os
import re
from collections import Counter
from copy import deepcopy
import numpy as np
from api.utils.file_utils import get_project_base_directory
from .recognizer import Recognizer
class LayoutRecognizer(Recognizer):
def __init__(self, domain):
self.layout_labels = [
"_background_",
"Text",
"Title",
"Figure",
"Figure caption",
"Table",
"Table caption",
"Header",
"Footer",
"Reference",
"Equation",
]
super().__init__(self.layout_labels, domain,
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
def __is_garbage(b):
patt = [r"^•+$", r"(版权归©|免责条款|地址[:])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
"(资料|数据)来源[:]", "[0-9a-z._-]+@[a-z0-9-]+\\.[a-z]{2,3}",
"\\(cid *: *[0-9]+ *\\)"
]
return any([re.search(p, b["text"]) for p in patt])
layouts = super().__call__(image_list, thr, batch_size)
# save_results(image_list, layouts, self.layout_labels, output_dir='output/', threshold=0.7)
assert len(image_list) == len(ocr_res)
# Tag layout type
boxes = []
assert len(image_list) == len(layouts)
garbages = {}
page_layout = []
for pn, lts in enumerate(layouts):
bxs = ocr_res[pn]
lts = [{"type": b["type"],
"score": float(b["score"]),
"x0": b["bbox"][0] / scale_factor, "x1": b["bbox"][2] / scale_factor,
"top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
"page_number": pn,
} for b in lts]
lts = self.sort_Y_firstly(lts, np.mean([l["bottom"]-l["top"] for l in lts]) / 2)
lts = self.layouts_cleanup(bxs, lts)
page_layout.append(lts)
# Tag layout type, layouts are ready
def findLayout(ty):
nonlocal bxs, lts, self
lts_ = [lt for lt in lts if lt["type"] == ty]
i = 0
while i < len(bxs):
if bxs[i].get("layout_type"):
i += 1
continue
if __is_garbage(bxs[i]):
bxs.pop(i)
continue
ii = self.find_overlapped_with_threashold(bxs[i], lts_,
thr=0.4)
if ii is None: # belong to nothing
bxs[i]["layout_type"] = ""
i += 1
continue
lts_[ii]["visited"] = True
if lts_[ii]["type"] in ["footer", "header", "reference"]:
if lts_[ii]["type"] not in garbages:
garbages[lts_[ii]["type"]] = []
garbages[lts_[ii]["type"]].append(bxs[i]["text"])
bxs.pop(i)
continue
bxs[i]["layoutno"] = f"{ty}-{ii}"
bxs[i]["layout_type"] = lts_[ii]["type"]
i += 1
for lt in ["footer", "header", "reference", "figure caption",
"table caption", "title", "text", "table", "figure", "equation"]:
findLayout(lt)
# add box to figure layouts which has not text box
for i, lt in enumerate(
[lt for lt in lts if lt["type"] == "figure"]):
if lt.get("visited"):
continue
lt = deepcopy(lt)
del lt["type"]
lt["text"] = ""
lt["layout_type"] = "figure"
lt["layoutno"] = f"figure-{i}"
bxs.append(lt)
boxes.extend(bxs)
ocr_res = boxes
garbag_set = set()
for k in garbages.keys():
garbages[k] = Counter(garbages[k])
for g, c in garbages[k].items():
if c > 1:
garbag_set.add(g)
ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
return ocr_res, page_layout

View File

@ -74,7 +74,7 @@ class TextRecognizer(object):
self.rec_batch_num = 16
postprocess_params = {
'name': 'CTCLabelDecode',
"character_dict_path": os.path.join(get_project_base_directory(), "rag/res", "ocr.res"),
"character_dict_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "ocr.res"),
"use_space_char": True
}
self.postprocess_op = build_post_process(postprocess_params)
@ -450,7 +450,7 @@ class OCR(object):
"""
if not model_dir:
model_dir = snapshot_download(repo_id="InfiniFlow/ocr")
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
self.text_detector = TextDetector(model_dir)
self.text_recognizer = TextRecognizer(model_dir)

View File

@ -0,0 +1,327 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from copy import deepcopy
import onnxruntime as ort
from huggingface_hub import snapshot_download
from . import seeit
from .operators import *
from rag.settings import cron_logger
class Recognizer(object):
def __init__(self, label_list, task_name, model_dir=None):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
if not model_dir:
model_dir = snapshot_download(repo_id="InfiniFlow/ocr")
model_file_path = os.path.join(model_dir, task_name + ".onnx")
if not os.path.exists(model_file_path):
raise ValueError("not find model file path {}".format(
model_file_path))
if ort.get_device() == "GPU":
self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
else:
self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
self.label_list = label_list
@staticmethod
def sort_Y_firstly(arr, threashold):
# sort using y1 first and then x1
arr = sorted(arr, key=lambda r: (r["top"], r["x0"]))
for i in range(len(arr) - 1):
for j in range(i, -1, -1):
# restore the order using th
if abs(arr[j + 1]["top"] - arr[j]["top"]) < threashold \
and arr[j + 1]["x0"] < arr[j]["x0"]:
tmp = deepcopy(arr[j])
arr[j] = deepcopy(arr[j + 1])
arr[j + 1] = deepcopy(tmp)
return arr
@staticmethod
def sort_X_firstly(arr, threashold, copy=True):
# sort using y1 first and then x1
arr = sorted(arr, key=lambda r: (r["x0"], r["top"]))
for i in range(len(arr) - 1):
for j in range(i, -1, -1):
# restore the order using th
if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threashold \
and arr[j + 1]["top"] < arr[j]["top"]:
tmp = deepcopy(arr[j]) if copy else arr[j]
arr[j] = deepcopy(arr[j + 1]) if copy else arr[j + 1]
arr[j + 1] = deepcopy(tmp) if copy else tmp
return arr
@staticmethod
def sort_C_firstly(arr, thr=0):
# sort using y1 first and then x1
# sorted(arr, key=lambda r: (r["x0"], r["top"]))
arr = Recognizer.sort_X_firstly(arr, thr)
for i in range(len(arr) - 1):
for j in range(i, -1, -1):
# restore the order using th
if "C" not in arr[j] or "C" not in arr[j + 1]:
continue
if arr[j + 1]["C"] < arr[j]["C"] \
or (
arr[j + 1]["C"] == arr[j]["C"]
and arr[j + 1]["top"] < arr[j]["top"]
):
tmp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = tmp
return arr
return sorted(arr, key=lambda r: (r.get("C", r["x0"]), r["top"]))
@staticmethod
def sort_R_firstly(arr, thr=0):
# sort using y1 first and then x1
# sorted(arr, key=lambda r: (r["top"], r["x0"]))
arr = Recognizer.sort_Y_firstly(arr, thr)
for i in range(len(arr) - 1):
for j in range(i, -1, -1):
if "R" not in arr[j] or "R" not in arr[j + 1]:
continue
if arr[j + 1]["R"] < arr[j]["R"] \
or (
arr[j + 1]["R"] == arr[j]["R"]
and arr[j + 1]["x0"] < arr[j]["x0"]
):
tmp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = tmp
return arr
@staticmethod
def overlapped_area(a, b, ratio=True):
tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"]
if b["x0"] > x1 or b["x1"] < x0:
return 0
if b["bottom"] < tp or b["top"] > btm:
return 0
x0_ = max(b["x0"], x0)
x1_ = min(b["x1"], x1)
assert x0_ <= x1_, "Fuckedup! T:{},B:{},X0:{},X1:{} ==> {}".format(
tp, btm, x0, x1, b)
tp_ = max(b["top"], tp)
btm_ = min(b["bottom"], btm)
assert tp_ <= btm_, "Fuckedup! T:{},B:{},X0:{},X1:{} => {}".format(
tp, btm, x0, x1, b)
ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
x0 != 0 and btm - tp != 0 else 0
if ov > 0 and ratio:
ov /= (x1 - x0) * (btm - tp)
return ov
@staticmethod
def layouts_cleanup(boxes, layouts, far=2, thr=0.7):
def notOverlapped(a, b):
return any([a["x1"] < b["x0"],
a["x0"] > b["x1"],
a["bottom"] < b["top"],
a["top"] > b["bottom"]])
i = 0
while i + 1 < len(layouts):
j = i + 1
while j < min(i + far, len(layouts)) \
and (layouts[i].get("type", "") != layouts[j].get("type", "")
or notOverlapped(layouts[i], layouts[j])):
j += 1
if j >= min(i + far, len(layouts)):
i += 1
continue
if Recognizer.overlapped_area(layouts[i], layouts[j]) < thr \
and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr:
i += 1
continue
if layouts[i].get("score") and layouts[j].get("score"):
if layouts[i]["score"] > layouts[j]["score"]:
layouts.pop(j)
else:
layouts.pop(i)
continue
area_i, area_i_1 = 0, 0
for b in boxes:
if not notOverlapped(b, layouts[i]):
area_i += Recognizer.overlapped_area(b, layouts[i], False)
if not notOverlapped(b, layouts[j]):
area_i_1 += Recognizer.overlapped_area(b, layouts[j], False)
if area_i > area_i_1:
layouts.pop(j)
else:
layouts.pop(i)
return layouts
def create_inputs(self, imgs, im_info):
"""generate input for different model type
Args:
imgs (list(numpy)): list of images (np.ndarray)
im_info (list(dict)): list of image info
Returns:
inputs (dict): input of model
"""
inputs = {}
im_shape = []
scale_factor = []
if len(imgs) == 1:
inputs['image'] = np.array((imgs[0],)).astype('float32')
inputs['im_shape'] = np.array(
(im_info[0]['im_shape'],)).astype('float32')
inputs['scale_factor'] = np.array(
(im_info[0]['scale_factor'],)).astype('float32')
return inputs
for e in im_info:
im_shape.append(np.array((e['im_shape'],)).astype('float32'))
scale_factor.append(np.array((e['scale_factor'],)).astype('float32'))
inputs['im_shape'] = np.concatenate(im_shape, axis=0)
inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)
imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
max_shape_h = max([e[0] for e in imgs_shape])
max_shape_w = max([e[1] for e in imgs_shape])
padding_imgs = []
for img in imgs:
im_c, im_h, im_w = img.shape[:]
padding_im = np.zeros(
(im_c, max_shape_h, max_shape_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = img
padding_imgs.append(padding_im)
inputs['image'] = np.stack(padding_imgs, axis=0)
return inputs
@staticmethod
def find_overlapped(box, boxes_sorted_by_y, naive=False):
if not boxes_sorted_by_y:
return
bxs = boxes_sorted_by_y
s, e, ii = 0, len(bxs), 0
while s < e and not naive:
ii = (e + s) // 2
pv = bxs[ii]
if box["bottom"] < pv["top"]:
e = ii
continue
if box["top"] > pv["bottom"]:
s = ii + 1
continue
break
while s < ii:
if box["top"] > bxs[s]["bottom"]:
s += 1
break
while e - 1 > ii:
if box["bottom"] < bxs[e - 1]["top"]:
e -= 1
break
max_overlaped_i, max_overlaped = None, 0
for i in range(s, e):
ov = Recognizer.overlapped_area(bxs[i], box)
if ov <= max_overlaped:
continue
max_overlaped_i = i
max_overlaped = ov
return max_overlaped_i
@staticmethod
def find_overlapped_with_threashold(box, boxes, thr=0.3):
if not boxes:
return
max_overlaped_i, max_overlaped, _max_overlaped = None, thr, 0
s, e = 0, len(boxes)
for i in range(s, e):
ov = Recognizer.overlapped_area(box, boxes[i])
_ov = Recognizer.overlapped_area(boxes[i], box)
if (ov, _ov) < (max_overlaped, _max_overlaped):
continue
max_overlaped_i = i
max_overlaped = ov
_max_overlaped = _ov
return max_overlaped_i
def preprocess(self, image_list):
preprocess_ops = []
for op_info in [
{'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
{'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
{'type': 'Permute'},
{'stride': 32, 'type': 'PadStride'}
]:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
preprocess_ops.append(eval(op_type)(**new_op_info))
inputs = []
for im_path in image_list:
im, im_info = preprocess(im_path, preprocess_ops)
inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
return inputs
def __call__(self, image_list, thr=0.7, batch_size=16):
res = []
imgs = []
for i in range(len(image_list)):
if not isinstance(image_list[i], np.ndarray):
imgs.append(np.array(image_list[i]))
else: imgs.append(image_list[i])
batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
for i in range(batch_loop_cnt):
start_index = i * batch_size
end_index = min((i + 1) * batch_size, len(imgs))
batch_image_list = imgs[start_index:end_index]
inputs = self.preprocess(batch_image_list)
for ins in inputs:
bb = []
for b in self.ort_sess.run(None, ins)[0]:
clsid, bbox, score = int(b[0]), b[2:], b[1]
if score < thr:
continue
if clsid >= len(self.label_list):
cron_logger.warning(f"bad category id")
continue
bb.append({
"type": self.label_list[clsid].lower(),
"bbox": [float(t) for t in bbox.tolist()],
"score": float(score)
})
res.append(bb)
#seeit.save_results(image_list, res, self.label_list, threshold=thr)
return res

View File

@ -0,0 +1,556 @@
import logging
import os
import re
from collections import Counter
from copy import deepcopy
import numpy as np
from api.utils.file_utils import get_project_base_directory
from rag.nlp import huqie
from .recognizer import Recognizer
class TableStructureRecognizer(Recognizer):
def __init__(self):
self.labels = [
"table",
"table column",
"table row",
"table column header",
"table projected row header",
"table spanning cell",
]
super().__init__(self.labels, "tsr",
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
def __call__(self, images, thr=0.5):
tbls = super().__call__(images, thr)
res = []
# align left&right for rows, align top&bottom for columns
for tbl in tbls:
lts = [{"label": b["type"],
"score": b["score"],
"x0": b["bbox"][0], "x1": b["bbox"][2],
"top": b["bbox"][1], "bottom": b["bbox"][-1]
} for b in tbl]
if not lts:
continue
left = [b["x0"] for b in lts if b["label"].find(
"row") > 0 or b["label"].find("header") > 0]
right = [b["x1"] for b in lts if b["label"].find(
"row") > 0 or b["label"].find("header") > 0]
if not left:
continue
left = np.median(left) if len(left) > 4 else np.min(left)
right = np.median(right) if len(right) > 4 else np.max(right)
for b in lts:
if b["label"].find("row") > 0 or b["label"].find("header") > 0:
if b["x0"] > left:
b["x0"] = left
if b["x1"] < right:
b["x1"] = right
top = [b["top"] for b in lts if b["label"] == "table column"]
bottom = [b["bottom"] for b in lts if b["label"] == "table column"]
if not top:
res.append(lts)
continue
top = np.median(top) if len(top) > 4 else np.min(top)
bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom)
for b in lts:
if b["label"] == "table column":
if b["top"] > top:
b["top"] = top
if b["bottom"] < bottom:
b["bottom"] = bottom
res.append(lts)
return res
@staticmethod
def is_caption(bx):
patt = [
r"[图表]+[ 0-9:]{2,}"
]
if any([re.match(p, bx["text"].strip()) for p in patt]) \
or bx["layout_type"].find("caption") >= 0:
return True
return False
def __blockType(self, b):
patt = [
("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
(r"^(20|19)[0-9]{2}年$", "Dt"),
(r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
(r"^第*[一二三四1-4]季度$", "Dt"),
(r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
(r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
("^[0-9.,+%/ -]+$", "Nu"),
(r"^[0-9A-Z/\._~-]+$", "Ca"),
(r"^[A-Z]*[a-z' -]+$", "En"),
(r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()' -]+$", "NE"),
(r"^.{1}$", "Sg")
]
for p, n in patt:
if re.search(p, b["text"].strip()):
return n
tks = [t for t in huqie.qie(b["text"]).split(" ") if len(t) > 1]
if len(tks) > 3:
if len(tks) < 12:
return "Tx"
else:
return "Lx"
if len(tks) == 1 and huqie.tag(tks[0]) == "nr":
return "Nr"
return "Ot"
def construct_table(self, boxes, is_english=False, html=False):
cap = ""
i = 0
while i < len(boxes):
if self.is_caption(boxes[i]):
cap += boxes[i]["text"]
boxes.pop(i)
i -= 1
i += 1
if not boxes:
return []
for b in boxes:
b["btype"] = self.__blockType(b)
max_type = Counter([b["btype"] for b in boxes]).items()
max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
logging.debug("MAXTYPE: " + max_type)
rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
rowh = np.min(rowh) if rowh else 0
boxes = self.sort_R_firstly(boxes, rowh / 2)
boxes[0]["rn"] = 0
rows = [[boxes[0]]]
btm = boxes[0]["bottom"]
for b in boxes[1:]:
b["rn"] = len(rows) - 1
lst_r = rows[-1]
if lst_r[-1].get("R", "") != b.get("R", "") \
or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
): # new row
btm = b["bottom"]
b["rn"] += 1
rows.append([b])
continue
btm = (btm + b["bottom"]) / 2.
rows[-1].append(b)
colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
colwm = np.min(colwm) if colwm else 0
crosspage = len(set([b["page_number"] for b in boxes])) > 1
if crosspage:
boxes = self.sort_X_firstly(boxes, colwm / 2, False)
else:
boxes = self.sort_C_firstly(boxes, colwm / 2)
boxes[0]["cn"] = 0
cols = [[boxes[0]]]
right = boxes[0]["x1"]
for b in boxes[1:]:
b["cn"] = len(cols) - 1
lst_c = cols[-1]
if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][
"page_number"]) \
or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")): # new col
right = b["x1"]
b["cn"] += 1
cols.append([b])
continue
right = (right + b["x1"]) / 2.
cols[-1].append(b)
tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
for b in boxes:
tbl[b["rn"]][b["cn"]].append(b)
if len(rows) >= 4:
# remove single in column
j = 0
while j < len(tbl[0]):
e, ii = 0, 0
for i in range(len(tbl)):
if tbl[i][j]:
e += 1
ii = i
if e > 1:
break
if e > 1:
j += 1
continue
f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
[j - 1][0].get("text")) or j == 0
ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
[j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
if f and ff:
j += 1
continue
bx = tbl[ii][j][0]
logging.debug("Relocate column single: " + bx["text"])
# j column only has one value
left, right = 100000, 100000
if j > 0 and not f:
for i in range(len(tbl)):
if tbl[i][j - 1]:
left = min(left, np.min(
[bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
if j + 1 < len(tbl[0]) and not ff:
for i in range(len(tbl)):
if tbl[i][j + 1]:
right = min(right, np.min(
[a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
assert left < 100000 or right < 100000
if left < right:
for jj in range(j, len(tbl[0])):
for i in range(len(tbl)):
for a in tbl[i][jj]:
a["cn"] -= 1
if tbl[ii][j - 1]:
tbl[ii][j - 1].extend(tbl[ii][j])
else:
tbl[ii][j - 1] = tbl[ii][j]
for i in range(len(tbl)):
tbl[i].pop(j)
else:
for jj in range(j + 1, len(tbl[0])):
for i in range(len(tbl)):
for a in tbl[i][jj]:
a["cn"] -= 1
if tbl[ii][j + 1]:
tbl[ii][j + 1].extend(tbl[ii][j])
else:
tbl[ii][j + 1] = tbl[ii][j]
for i in range(len(tbl)):
tbl[i].pop(j)
cols.pop(j)
assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (
len(cols), len(tbl[0]))
if len(cols) >= 4:
# remove single in row
i = 0
while i < len(tbl):
e, jj = 0, 0
for j in range(len(tbl[i])):
if tbl[i][j]:
e += 1
jj = j
if e > 1:
break
if e > 1:
i += 1
continue
f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
[jj][0].get("text")) or i == 0
ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
[jj][0].get("text")) or i + 1 >= len(tbl)
if f and ff:
i += 1
continue
bx = tbl[i][jj][0]
logging.debug("Relocate row single: " + bx["text"])
# i row only has one value
up, down = 100000, 100000
if i > 0 and not f:
for j in range(len(tbl[i - 1])):
if tbl[i - 1][j]:
up = min(up, np.min(
[bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
if i + 1 < len(tbl) and not ff:
for j in range(len(tbl[i + 1])):
if tbl[i + 1][j]:
down = min(down, np.min(
[a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
assert up < 100000 or down < 100000
if up < down:
for ii in range(i, len(tbl)):
for j in range(len(tbl[ii])):
for a in tbl[ii][j]:
a["rn"] -= 1
if tbl[i - 1][jj]:
tbl[i - 1][jj].extend(tbl[i][jj])
else:
tbl[i - 1][jj] = tbl[i][jj]
tbl.pop(i)
else:
for ii in range(i + 1, len(tbl)):
for j in range(len(tbl[ii])):
for a in tbl[ii][j]:
a["rn"] -= 1
if tbl[i + 1][jj]:
tbl[i + 1][jj].extend(tbl[i][jj])
else:
tbl[i + 1][jj] = tbl[i][jj]
tbl.pop(i)
rows.pop(i)
# which rows are headers
hdset = set([])
for i in range(len(tbl)):
cnt, h = 0, 0
for j, arr in enumerate(tbl[i]):
if not arr:
continue
cnt += 1
if max_type == "Nu" and arr[0]["btype"] == "Nu":
continue
if any([a.get("H") for a in arr]) \
or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
h += 1
if h / cnt > 0.5:
hdset.add(i)
if html:
return [self.__html_table(cap, hdset,
self.__cal_spans(boxes, rows,
cols, tbl, True)
)]
return self.__desc_table(cap, hdset,
self.__cal_spans(boxes, rows, cols, tbl, False),
is_english)
def __html_table(self, cap, hdset, tbl):
# constrcut HTML
html = "<table>"
if cap:
html += f"<caption>{cap}</caption>"
for i in range(len(tbl)):
row = "<tr>"
txts = []
for j, arr in enumerate(tbl[i]):
if arr is None:
continue
if not arr:
row += "<td></td>" if i not in hdset else "<th></th>"
continue
txt = ""
if arr:
h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
txt = "".join([c["text"]
for c in self.sort_Y_firstly(arr, h)])
txts.append(txt)
sp = ""
if arr[0].get("colspan"):
sp = "colspan={}".format(arr[0]["colspan"])
if arr[0].get("rowspan"):
sp += " rowspan={}".format(arr[0]["rowspan"])
if i in hdset:
row += f"<th {sp} >" + txt + "</th>"
else:
row += f"<td {sp} >" + txt + "</td>"
if i in hdset:
if all([t in hdset for t in txts]):
continue
for t in txts:
hdset.add(t)
if row != "<tr>":
row += "</tr>"
else:
row = ""
html += "\n" + row
html += "\n</table>"
return html
def __desc_table(self, cap, hdr_rowno, tbl, is_english):
# get text of every colomn in header row to become header text
clmno = len(tbl[0])
rowno = len(tbl)
headers = {}
hdrset = set()
lst_hdr = []
de = "" if not is_english else " for "
for r in sorted(list(hdr_rowno)):
headers[r] = ["" for _ in range(clmno)]
for i in range(clmno):
if not tbl[r][i]:
continue
txt = "".join([a["text"].strip() for a in tbl[r][i]])
headers[r][i] = txt
hdrset.add(txt)
if all([not t for t in headers[r]]):
del headers[r]
hdr_rowno.remove(r)
continue
for j in range(clmno):
if headers[r][j]:
continue
if j >= len(lst_hdr):
break
headers[r][j] = lst_hdr[j]
lst_hdr = headers[r]
for i in range(rowno):
if i not in hdr_rowno:
continue
for j in range(i + 1, rowno):
if j not in hdr_rowno:
break
for k in range(clmno):
if not headers[j - 1][k]:
continue
if headers[j][k].find(headers[j - 1][k]) >= 0:
continue
if len(headers[j][k]) > len(headers[j - 1][k]):
headers[j][k] += (de if headers[j][k]
else "") + headers[j - 1][k]
else:
headers[j][k] = headers[j - 1][k] \
+ (de if headers[j - 1][k] else "") \
+ headers[j][k]
logging.debug(
f">>>>>>>>>>>>>>>>>{cap}SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
row_txt = []
for i in range(rowno):
if i in hdr_rowno:
continue
rtxt = []
def append(delimer):
nonlocal rtxt, row_txt
rtxt = delimer.join(rtxt)
if row_txt and len(row_txt[-1]) + len(rtxt) < 64:
row_txt[-1] += "\n" + rtxt
else:
row_txt.append(rtxt)
r = 0
if len(headers.items()):
_arr = [(i - r, r) for r, _ in headers.items() if r < i]
if _arr:
_, r = min(_arr, key=lambda x: x[0])
if r not in headers and clmno <= 2:
for j in range(clmno):
if not tbl[i][j]:
continue
txt = "".join([a["text"].strip() for a in tbl[i][j]])
if txt:
rtxt.append(txt)
if rtxt:
append("")
continue
for j in range(clmno):
if not tbl[i][j]:
continue
txt = "".join([a["text"].strip() for a in tbl[i][j]])
if not txt:
continue
ctt = headers[r][j] if r in headers else ""
if ctt:
ctt += ""
ctt += txt
if ctt:
rtxt.append(ctt)
if rtxt:
row_txt.append("; ".join(rtxt))
if cap:
if is_english:
from_ = " in "
else:
from_ = "来自"
row_txt = [t + f"\t——{from_}{cap}" for t in row_txt]
return row_txt
def __cal_spans(self, boxes, rows, cols, tbl, html=True):
# caculate span
clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
for cln in cols]
crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
for cln in cols]
rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
for row in rows]
rbtm = [np.mean([c.get("R_btm", c["bottom"])
for c in row]) for row in rows]
for b in boxes:
if "SP" not in b:
continue
b["colspan"] = [b["cn"]]
b["rowspan"] = [b["rn"]]
# col span
for j in range(0, len(clft)):
if j == b["cn"]:
continue
if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
continue
if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
continue
b["colspan"].append(j)
# row span
for j in range(0, len(rtop)):
if j == b["rn"]:
continue
if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
continue
if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
continue
b["rowspan"].append(j)
def join(arr):
if not arr:
return ""
return "".join([t["text"] for t in arr])
# rm the spaning cells
for i in range(len(tbl)):
for j, arr in enumerate(tbl[i]):
if not arr:
continue
if all(["rowspan" not in a and "colspan" not in a for a in arr]):
continue
rowspan, colspan = [], []
for a in arr:
if isinstance(a.get("rowspan", 0), list):
rowspan.extend(a["rowspan"])
if isinstance(a.get("colspan", 0), list):
colspan.extend(a["colspan"])
rowspan, colspan = set(rowspan), set(colspan)
if len(rowspan) < 2 and len(colspan) < 2:
for a in arr:
if "rowspan" in a:
del a["rowspan"]
if "colspan" in a:
del a["colspan"]
continue
rowspan, colspan = sorted(rowspan), sorted(colspan)
rowspan = list(range(rowspan[0], rowspan[-1] + 1))
colspan = list(range(colspan[0], colspan[-1] + 1))
assert i in rowspan, rowspan
assert j in colspan, colspan
arr = []
for r in rowspan:
for c in colspan:
arr_txt = join(arr)
if tbl[r][c] and join(tbl[r][c]) != arr_txt:
arr.extend(tbl[r][c])
tbl[r][c] = None if html else arr
for a in arr:
if len(rowspan) > 1:
a["rowspan"] = len(rowspan)
elif "rowspan" in a:
del a["rowspan"]
if len(colspan) > 1:
a["colspan"] = len(colspan)
elif "colspan" in a:
del a["colspan"]
tbl[rowspan[0]][colspan[0]] = arr
return tbl

View File

@ -1,2 +0,0 @@
from .ocr import OCR
from .recognizer import Recognizer

View File

@ -1,139 +0,0 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import onnxruntime as ort
from huggingface_hub import snapshot_download
from .operators import *
from rag.settings import cron_logger
class Recognizer(object):
def __init__(self, label_list, task_name, model_dir=None):
"""
If you have trouble downloading HuggingFace models, -_^ this might help!!
For Linux:
export HF_ENDPOINT=https://hf-mirror.com
For Windows:
Good luck
^_-
"""
if not model_dir:
model_dir = snapshot_download(repo_id="InfiniFlow/ocr")
model_file_path = os.path.join(model_dir, task_name + ".onnx")
if not os.path.exists(model_file_path):
raise ValueError("not find model file path {}".format(
model_file_path))
if ort.get_device() == "GPU":
self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
else:
self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
self.label_list = label_list
def create_inputs(self, imgs, im_info):
"""generate input for different model type
Args:
imgs (list(numpy)): list of images (np.ndarray)
im_info (list(dict)): list of image info
Returns:
inputs (dict): input of model
"""
inputs = {}
im_shape = []
scale_factor = []
if len(imgs) == 1:
inputs['image'] = np.array((imgs[0],)).astype('float32')
inputs['im_shape'] = np.array(
(im_info[0]['im_shape'],)).astype('float32')
inputs['scale_factor'] = np.array(
(im_info[0]['scale_factor'],)).astype('float32')
return inputs
for e in im_info:
im_shape.append(np.array((e['im_shape'],)).astype('float32'))
scale_factor.append(np.array((e['scale_factor'],)).astype('float32'))
inputs['im_shape'] = np.concatenate(im_shape, axis=0)
inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)
imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
max_shape_h = max([e[0] for e in imgs_shape])
max_shape_w = max([e[1] for e in imgs_shape])
padding_imgs = []
for img in imgs:
im_c, im_h, im_w = img.shape[:]
padding_im = np.zeros(
(im_c, max_shape_h, max_shape_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = img
padding_imgs.append(padding_im)
inputs['image'] = np.stack(padding_imgs, axis=0)
return inputs
def preprocess(self, image_list):
preprocess_ops = []
for op_info in [
{'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
{'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
{'type': 'Permute'},
{'stride': 32, 'type': 'PadStride'}
]:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
preprocess_ops.append(eval(op_type)(**new_op_info))
inputs = []
for im_path in image_list:
im, im_info = preprocess(im_path, preprocess_ops)
inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
return inputs
def __call__(self, image_list, thr=0.7, batch_size=16):
res = []
imgs = []
for i in range(len(image_list)):
if not isinstance(image_list[i], np.ndarray):
imgs.append(np.array(image_list[i]))
else: imgs.append(image_list[i])
batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
for i in range(batch_loop_cnt):
start_index = i * batch_size
end_index = min((i + 1) * batch_size, len(imgs))
batch_image_list = imgs[start_index:end_index]
inputs = self.preprocess(batch_image_list)
for ins in inputs:
bb = []
for b in self.ort_sess.run(None, ins)[0]:
clsid, bbox, score = int(b[0]), b[2:], b[1]
if score < thr:
continue
if clsid >= len(self.label_list):
cron_logger.warning(f"bad category id")
continue
bb.append({
"type": self.label_list[clsid].lower(),
"bbox": [float(t) for t in bbox.tolist()],
"score": float(score)
})
res.append(bb)
#seeit.save_results(image_list, res, self.label_list, threshold=thr)
return res

View File

@ -21,7 +21,7 @@ from datetime import datetime
from api.db.db_models import Task
from api.db.db_utils import bulk_insert_into_db
from api.db.services.task_service import TaskService
from deepdoc.parser import HuParser
from deepdoc.parser import PdfParser
from rag.settings import cron_logger
from rag.utils import MINIO
from rag.utils import findMaxTm
@ -80,7 +80,7 @@ def dispatch():
tsks = []
if r["type"] == FileType.PDF.value:
pages = HuParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
pages = PdfParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
for s,e in r["parser_config"].get("pages", [(0,100000)]):
e = min(e, pages)
for p in range(s, e, 10):