diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index e6e33d063..0d392bccf 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -34,7 +34,6 @@ from rag.utils import num_tokens_from_string, encoder, rmSpace
@manager.route('/set', methods=['POST'])
@login_required
-@validate_request("dialog_id")
def set_conversation():
req = request.json
conv_id = req.get("conversation_id")
@@ -145,7 +144,7 @@ def message_fit_in(msg, max_length=4000):
@manager.route('/completion', methods=['POST'])
@login_required
-@validate_request("dialog_id", "messages")
+@validate_request("conversation_id", "messages")
def completion():
req = request.json
msg = []
@@ -154,12 +153,20 @@ def completion():
if m["role"] == "assistant" and not msg: continue
msg.append({"role": m["role"], "content": m["content"]})
try:
- e, dia = DialogService.get_by_id(req["dialog_id"])
+ e, conv = ConversationService.get_by_id(req["conversation_id"])
+ if not e:
+ return get_data_error_result(retmsg="Conversation not found!")
+ conv.message.append(msg[-1])
+ e, dia = DialogService.get_by_id(conv.dialog_id)
if not e:
return get_data_error_result(retmsg="Dialog not found!")
- del req["dialog_id"]
+ del req["conversation_id"]
del req["messages"]
- return get_json_result(data=chat(dia, msg, **req))
+ ans = chat(dia, msg, **req)
+ conv.reference.append(ans["reference"])
+ conv.message.append({"role": "assistant", "content": ans["answer"]})
+ ConversationService.update_by_id(conv.id, conv.to_dict())
+ return get_json_result(data=ans)
except Exception as e:
return server_error_response(e)
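The hunk above re-keys the endpoint on `conversation_id`, loads the dialog through the conversation, and persists both the user turn and the assistant's answer (plus its retrieval `reference`) back onto the conversation row. A minimal client sketch of the new request shape follows; only the JSON body is confirmed by this diff, while the base URL and the pre-authenticated session are assumptions.

```python
# Hypothetical client call; only the JSON body ("conversation_id" +
# "messages") is taken from this diff. The base URL is an assumption and
# the session must already be logged in (@login_required).
import requests

s = requests.Session()
# ... assume s has authenticated against the login endpoint ...
resp = s.post(
    "http://localhost:9380/v1/conversation/completion",  # assumed base URL
    json={
        "conversation_id": "b3f1e0c2deadbeef",  # invented; replaces the old dialog_id field
        "messages": [{"role": "user", "content": "Summarize the contract."}],
    },
)
data = resp.json()["data"]
print(data["answer"])     # LLM answer
print(data["reference"])  # retrieved chunks, also appended to conv.reference
```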
@@ -194,8 +201,8 @@ def chat(dialog, messages, **kwargs):
dialog.vector_similarity_weight, top=1024, aggs=False)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
- if not knowledges and prompt_config["empty_response"]:
- return {"answer": prompt_config["empty_response"], "retrieval": kbinfos}
+ if not knowledges and prompt_config.get("empty_response"):
+ return {"answer": prompt_config["empty_response"], "reference": kbinfos}
kwargs["knowledge"] = "\n".join(knowledges)
gen_conf = dialog.llm_setting
@@ -205,7 +212,8 @@ def chat(dialog, messages, **kwargs):
gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
- answer = retrievaler.insert_citations(answer,
+ if knowledges:
+ answer = retrievaler.insert_citations(answer,
[ck["content_ltks"] for ck in kbinfos["chunks"]],
[ck["vector"] for ck in kbinfos["chunks"]],
embd_mdl,
@@ -213,7 +221,7 @@ def chat(dialog, messages, **kwargs):
vtweight=dialog.vector_similarity_weight)
for c in kbinfos["chunks"]:
if c.get("vector"): del c["vector"]
- return {"answer": answer, "retrieval": kbinfos}
+ return {"answer": answer, "reference": kbinfos}
def use_sql(question, field_map, tenant_id, chat_mdl):
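Within `chat()`, the response key is renamed from `retrieval` to `reference`, `empty_response` is read with `.get()` so prompt configs without it no longer raise `KeyError`, and `insert_citations` runs only when chunks were actually retrieved. A hedged sketch of consuming the new return value, where `dia` and `msg` stand in for a Dialog row and an OpenAI-style message list:

```python
# Sketch of chat()'s new return contract; the keys are from this diff,
# the surrounding setup (dia, msg) is assumed.
ans = chat(dia, msg)
print(ans["answer"])                   # empty_response text when nothing matched
for ck in ans["reference"]["chunks"]:  # formerly ans["retrieval"]["chunks"]
    assert "vector" not in ck          # vectors are stripped before returning
    print(ck["content_with_weight"][:80])
```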
diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py
index c70f7ea44..0a1167b6f 100644
--- a/api/apps/llm_app.py
+++ b/api/apps/llm_app.py
@@ -94,11 +94,11 @@ def list():
model_type = request.args.get("model_type")
try:
objs = TenantLLMService.query(tenant_id=current_user.id)
- mdlnms = set([o.to_dict()["llm_name"] for o in objs if o.api_key])
+ facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
llms = LLMService.get_all()
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
for m in llms:
- m["available"] = m["llm_name"] in mdlnms
+ m["available"] = m["fid"] in facts
res = {}
for m in llms:
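`list()` now derives availability from the factory: a model becomes usable as soon as its factory (`m["fid"]`) has any API key configured, instead of requiring an exact `llm_name` match. A toy illustration of the rule, with invented factories and models:

```python
# Toy illustration of the new availability rule; all values are invented.
facts = {"OpenAI", "ZHIPU-AI"}  # factories holding an api_key for this tenant
llms = [
    {"fid": "OpenAI", "llm_name": "gpt-4"},
    {"fid": "OpenAI", "llm_name": "gpt-3.5-turbo"},
    {"fid": "Moonshot", "llm_name": "moonshot-v1-8k"},
]
for m in llms:
    m["available"] = m["fid"] in facts
# One OpenAI key unlocks both OpenAI models; moonshot-v1-8k stays unavailable.
```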
diff --git a/api/db/db_models.py b/api/db/db_models.py
index 0e032fcaa..49aa169cf 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -500,7 +500,7 @@ class Document(DataBaseModel):
token_num = IntegerField(default=0)
chunk_num = IntegerField(default=0)
progress = FloatField(default=0)
- progress_msg = CharField(max_length=512, null=True, help_text="process message", default="")
+ progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
process_begin_at = DateTimeField(null=True)
process_duation = FloatField(default=0)
run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0")
@@ -518,7 +518,7 @@ class Task(DataBaseModel):
begin_at = DateTimeField(null=True)
process_duation = FloatField(default=0)
progress = FloatField(default=0)
- progress_msg = CharField(max_length=255, null=True, help_text="process message", default="")
+ progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
class Dialog(DataBaseModel):
@@ -561,6 +561,7 @@ class Conversation(DataBaseModel):
dialog_id = CharField(max_length=32, null=False, index=True)
name = CharField(max_length=255, null=True, help_text="conversation name")
message = JSONField(null=True)
+ reference = JSONField(null=True, default=[])
class Meta:
db_table = "conversation"
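The new `reference` column stores one `kbinfos` dict per assistant turn, so citations can be re-rendered without re-querying. One caveat: `default=[]` is a shared mutable default in Python, so `default=list` would be the safer spelling. A hedged sketch of the per-turn bookkeeping `completion()` now performs, with an invented ID and `dia`/`msg` assumed in scope:

```python
# Sketch of the bookkeeping completion() does per turn, following this diff.
e, conv = ConversationService.get_by_id("conv-0001")   # invented id
conv.message.append({"role": "user", "content": "hi"})
ans = chat(dia, msg)                                   # dia, msg assumed
conv.reference.append(ans["reference"])                # one kbinfos per answer
conv.message.append({"role": "assistant", "content": ans["answer"]})
ConversationService.update_by_id(conv.id, conv.to_dict())
```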
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 0fb10b0a1..6bc1150d7 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -75,7 +75,7 @@ class TenantLLMService(CommonService):
model_config = cls.get_api_key(tenant_id, mdlnm)
if not model_config:
- raise LookupError("Model({}) not found".format(mdlnm))
+ raise LookupError("Model({}) not authorized".format(mdlnm))
model_config = model_config.to_dict()
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel:
diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py
index 576687b18..c5d2e6a24 100644
--- a/deepdoc/parser/pdf_parser.py
+++ b/deepdoc/parser/pdf_parser.py
@@ -1,9 +1,7 @@
# -*- coding: utf-8 -*-
-import os
import random
import fitz
-import requests
import xgboost as xgb
from io import BytesIO
import torch
@@ -14,9 +12,8 @@ from PIL import Image
import numpy as np
from api.db import ParserType
-from deepdoc.visual import OCR, Recognizer
+from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
from rag.nlp import huqie
-from collections import Counter
from copy import deepcopy
from huggingface_hub import hf_hub_download
@@ -29,29 +26,8 @@ class HuParser:
self.ocr = OCR()
if not hasattr(self, "model_speciess"):
self.model_speciess = ParserType.GENERAL.value
- self.layout_labels = [
- "_background_",
- "Text",
- "Title",
- "Figure",
- "Figure caption",
- "Table",
- "Table caption",
- "Header",
- "Footer",
- "Reference",
- "Equation",
- ]
- self.tsr_labels = [
- "table",
- "table column",
- "table row",
- "table column header",
- "table projected row header",
- "table spanning cell",
- ]
- self.layouter = Recognizer(self.layout_labels, "layout", "/data/newpeak/medical-gpt/res/ppdet/")
- self.tbl_det = Recognizer(self.tsr_labels, "tsr", "/data/newpeak/medical-gpt/res/ppdet.tbl/")
+ self.layouter = LayoutRecognizer("layout."+self.model_speciess)
+ self.tbl_det = TableStructureRecognizer()
self.updown_cnt_mdl = xgb.Booster()
if torch.cuda.is_available():
@@ -70,39 +46,6 @@ class HuParser:
"""
- def __remote_call(self, species, images, thr=0.7):
- url = os.environ.get("INFINIFLOW_SERVER")
- token = os.environ.get("INFINIFLOW_TOKEN")
- if not url or not token:
- logging.warning("INFINIFLOW_SERVER is not specified. To maximize the effectiveness, please visit https://github.com/infiniflow/ragflow, and sign in the our demo web site to get token. It's FREE! Using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN.")
- return [[] for _ in range(len(images))]
-
- def convert_image_to_bytes(PILimage):
- image = BytesIO()
- PILimage.save(image, format='png')
- image.seek(0)
- return image.getvalue()
-
- images = [convert_image_to_bytes(img) for img in images]
-
- def remote_call():
- nonlocal images, thr
- res = requests.post(url+"/v1/layout/detect/"+species, files=[("image", img) for img in images], data={"threashold": thr},
- headers={"Authorization": token}, timeout=len(images) * 10)
- res = res.json()
- if res["retcode"] != 0: raise RuntimeError(res["retmsg"])
- return res["data"]
-
- for _ in range(3):
- try:
- return remote_call()
- except RuntimeError as e:
- raise e
- except Exception as e:
- logging.error("layout_predict:"+str(e))
- return remote_call()
-
-
def __char_width(self, c):
return (c["x1"] - c["x0"]) // len(c["text"])
@@ -188,20 +131,6 @@ class HuParser:
]
return fea
- @staticmethod
- def sort_Y_firstly(arr, threashold):
- # sort using y1 first and then x1
- arr = sorted(arr, key=lambda r: (r["top"], r["x0"]))
- for i in range(len(arr) - 1):
- for j in range(i, -1, -1):
- # restore the order using th
- if abs(arr[j + 1]["top"] - arr[j]["top"]) < threashold \
- and arr[j + 1]["x0"] < arr[j]["x0"]:
- tmp = deepcopy(arr[j])
- arr[j] = deepcopy(arr[j + 1])
- arr[j + 1] = deepcopy(tmp)
- return arr
-
@staticmethod
def sort_X_by_page(arr, threashold):
# sort using y1 first and then x1
@@ -217,61 +146,6 @@ class HuParser:
arr[j + 1] = tmp
return arr
- @staticmethod
- def sort_R_firstly(arr, thr=0):
- # sort using y1 first and then x1
- # sorted(arr, key=lambda r: (r["top"], r["x0"]))
- arr = HuParser.sort_Y_firstly(arr, thr)
- for i in range(len(arr) - 1):
- for j in range(i, -1, -1):
- if "R" not in arr[j] or "R" not in arr[j + 1]:
- continue
- if arr[j + 1]["R"] < arr[j]["R"] \
- or (
- arr[j + 1]["R"] == arr[j]["R"]
- and arr[j + 1]["x0"] < arr[j]["x0"]
- ):
- tmp = arr[j]
- arr[j] = arr[j + 1]
- arr[j + 1] = tmp
- return arr
-
- @staticmethod
- def sort_X_firstly(arr, threashold, copy=True):
- # sort using y1 first and then x1
- arr = sorted(arr, key=lambda r: (r["x0"], r["top"]))
- for i in range(len(arr) - 1):
- for j in range(i, -1, -1):
- # restore the order using th
- if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threashold \
- and arr[j + 1]["top"] < arr[j]["top"]:
- tmp = deepcopy(arr[j]) if copy else arr[j]
- arr[j] = deepcopy(arr[j + 1]) if copy else arr[j + 1]
- arr[j + 1] = deepcopy(tmp) if copy else tmp
- return arr
-
- @staticmethod
- def sort_C_firstly(arr, thr=0):
- # sort using y1 first and then x1
- # sorted(arr, key=lambda r: (r["x0"], r["top"]))
- arr = HuParser.sort_X_firstly(arr, thr)
- for i in range(len(arr) - 1):
- for j in range(i, -1, -1):
- # restore the order using th
- if "C" not in arr[j] or "C" not in arr[j + 1]:
- continue
- if arr[j + 1]["C"] < arr[j]["C"] \
- or (
- arr[j + 1]["C"] == arr[j]["C"]
- and arr[j + 1]["top"] < arr[j]["top"]
- ):
- tmp = arr[j]
- arr[j] = arr[j + 1]
- arr[j + 1] = tmp
- return arr
-
- return sorted(arr, key=lambda r: (r.get("C", r["x0"]), r["top"]))
-
def _has_color(self, o):
if o.get("ncs", "") == "DeviceGray":
if o["stroking_color"] and o["stroking_color"][0] == 1 and o["non_stroking_color"] and \
@@ -280,172 +154,6 @@ class HuParser:
return False
return True
- def __overlapped_area(self, a, b, ratio=True):
- tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"]
- if b["x0"] > x1 or b["x1"] < x0:
- return 0
- if b["bottom"] < tp or b["top"] > btm:
- return 0
- x0_ = max(b["x0"], x0)
- x1_ = min(b["x1"], x1)
- assert x0_ <= x1_, "Fuckedup! T:{},B:{},X0:{},X1:{} ==> {}".format(
- tp, btm, x0, x1, b)
- tp_ = max(b["top"], tp)
- btm_ = min(b["bottom"], btm)
- assert tp_ <= btm_, "Fuckedup! T:{},B:{},X0:{},X1:{} => {}".format(
- tp, btm, x0, x1, b)
- ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
- x0 != 0 and btm - tp != 0 else 0
- if ov > 0 and ratio:
- ov /= (x1 - x0) * (btm - tp)
- return ov
-
- def __find_overlapped_with_threashold(self, box, boxes, thr=0.3):
- if not boxes:
- return
- max_overlaped_i, max_overlaped, _max_overlaped = None, thr, 0
- s, e = 0, len(boxes)
- for i in range(s, e):
- ov = self.__overlapped_area(box, boxes[i])
- _ov = self.__overlapped_area(boxes[i], box)
- if (ov, _ov) < (max_overlaped, _max_overlaped):
- continue
- max_overlaped_i = i
- max_overlaped = ov
- _max_overlaped = _ov
-
- return max_overlaped_i
-
- def __find_overlapped(self, box, boxes_sorted_by_y, naive=False):
- if not boxes_sorted_by_y:
- return
- bxs = boxes_sorted_by_y
- s, e, ii = 0, len(bxs), 0
- while s < e and not naive:
- ii = (e + s) // 2
- pv = bxs[ii]
- if box["bottom"] < pv["top"]:
- e = ii
- continue
- if box["top"] > pv["bottom"]:
- s = ii + 1
- continue
- break
- while s < ii:
- if box["top"] > bxs[s]["bottom"]:
- s += 1
- break
- while e - 1 > ii:
- if box["bottom"] < bxs[e - 1]["top"]:
- e -= 1
- break
-
- max_overlaped_i, max_overlaped = None, 0
- for i in range(s, e):
- ov = self.__overlapped_area(bxs[i], box)
- if ov <= max_overlaped:
- continue
- max_overlaped_i = i
- max_overlaped = ov
-
- return max_overlaped_i
-
- def _is_garbage(self, b):
- patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
- r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
- "(资料|数据)来源[::]", "[0-9a-z._-]+@[a-z0-9-]+\\.[a-z]{2,3}",
- "\\(cid *: *[0-9]+ *\\)"
- ]
- return any([re.search(p, b["text"]) for p in patt])
-
- def __layouts_cleanup(self, boxes, layouts, far=2, thr=0.7):
- def notOverlapped(a, b):
- return any([a["x1"] < b["x0"],
- a["x0"] > b["x1"],
- a["bottom"] < b["top"],
- a["top"] > b["bottom"]])
-
- i = 0
- while i + 1 < len(layouts):
- j = i + 1
- while j < min(i + far, len(layouts)) \
- and (layouts[i].get("type", "") != layouts[j].get("type", "")
- or notOverlapped(layouts[i], layouts[j])):
- j += 1
- if j >= min(i + far, len(layouts)):
- i += 1
- continue
- if self.__overlapped_area(layouts[i], layouts[j]) < thr \
- and self.__overlapped_area(layouts[j], layouts[i]) < thr:
- i += 1
- continue
-
- if layouts[i].get("score") and layouts[j].get("score"):
- if layouts[i]["score"] > layouts[j]["score"]:
- layouts.pop(j)
- else:
- layouts.pop(i)
- continue
-
- area_i, area_i_1 = 0, 0
- for b in boxes:
- if not notOverlapped(b, layouts[i]):
- area_i += self.__overlapped_area(b, layouts[i], False)
- if not notOverlapped(b, layouts[j]):
- area_i_1 += self.__overlapped_area(b, layouts[j], False)
-
- if area_i > area_i_1:
- layouts.pop(j)
- else:
- layouts.pop(i)
-
- return layouts
-
- def __table_tsr(self, images):
- tbls = self.tbl_det(images, thr=0.5)
- res = []
- # align left&right for rows, align top&bottom for columns
- for tbl in tbls:
- lts = [{"label": b["type"],
- "score": b["score"],
- "x0": b["bbox"][0], "x1": b["bbox"][2],
- "top": b["bbox"][1], "bottom": b["bbox"][-1]
- } for b in tbl]
- if not lts:
- continue
-
- left = [b["x0"] for b in lts if b["label"].find(
- "row") > 0 or b["label"].find("header") > 0]
- right = [b["x1"] for b in lts if b["label"].find(
- "row") > 0 or b["label"].find("header") > 0]
- if not left:
- continue
- left = np.median(left) if len(left) > 4 else np.min(left)
- right = np.median(right) if len(right) > 4 else np.max(right)
- for b in lts:
- if b["label"].find("row") > 0 or b["label"].find("header") > 0:
- if b["x0"] > left:
- b["x0"] = left
- if b["x1"] < right:
- b["x1"] = right
-
- top = [b["top"] for b in lts if b["label"] == "table column"]
- bottom = [b["bottom"] for b in lts if b["label"] == "table column"]
- if not top:
- res.append(lts)
- continue
- top = np.median(top) if len(top) > 4 else np.min(top)
- bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom)
- for b in lts:
- if b["label"] == "table column":
- if b["top"] > top:
- b["top"] = top
- if b["bottom"] < bottom:
- b["bottom"] = bottom
-
- res.append(lts)
- return res
-
def _table_transformer_job(self, ZM):
logging.info("Table processing...")
imgs, pos = [], []
@@ -471,7 +179,7 @@ class HuParser:
assert len(self.page_images) == len(tbcnt) - 1
if not imgs:
return
- recos = self.__table_tsr(imgs)
+ recos = self.tbl_det(imgs)
tbcnt = np.cumsum(tbcnt)
for i in range(len(tbcnt) - 1): # for page
pg = []
@@ -493,10 +201,10 @@ class HuParser:
self.tb_cpns.extend(pg)
def gather(kwd, fzy=10, ption=0.6):
- eles = self.sort_Y_firstly(
+ eles = Recognizer.sort_Y_firstly(
[r for r in self.tb_cpns if re.match(kwd, r["label"])], fzy)
- eles = self.__layouts_cleanup(self.boxes, eles, 5, ption)
- return self.sort_Y_firstly(eles, 0)
+ eles = Recognizer.layouts_cleanup(self.boxes, eles, 5, ption)
+ return Recognizer.sort_Y_firstly(eles, 0)
# add R,H,C,SP tag to boxes within table layout
headers = gather(r".*header$")
@@ -504,17 +212,17 @@ class HuParser:
spans = gather(r".*spanning")
clmns = sorted([r for r in self.tb_cpns if re.match(
r"table column$", r["label"])], key=lambda x: (x["pn"], x["layoutno"], x["x0"]))
- clmns = self.__layouts_cleanup(self.boxes, clmns, 5, 0.5)
+ clmns = Recognizer.layouts_cleanup(self.boxes, clmns, 5, 0.5)
for b in self.boxes:
if b.get("layout_type", "") != "table":
continue
- ii = self.__find_overlapped_with_threashold(b, rows, thr=0.3)
+ ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
if ii is not None:
b["R"] = ii
b["R_top"] = rows[ii]["top"]
b["R_bott"] = rows[ii]["bottom"]
- ii = self.__find_overlapped_with_threashold(b, headers, thr=0.3)
+ ii = Recognizer.find_overlapped_with_threashold(b, headers, thr=0.3)
if ii is not None:
b["H_top"] = headers[ii]["top"]
b["H_bott"] = headers[ii]["bottom"]
@@ -522,13 +230,13 @@ class HuParser:
b["H_right"] = headers[ii]["x1"]
b["H"] = ii
- ii = self.__find_overlapped_with_threashold(b, clmns, thr=0.3)
+ ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
if ii is not None:
b["C"] = ii
b["C_left"] = clmns[ii]["x0"]
b["C_right"] = clmns[ii]["x1"]
- ii = self.__find_overlapped_with_threashold(b, spans, thr=0.3)
+ ii = Recognizer.find_overlapped_with_threashold(b, spans, thr=0.3)
if ii is not None:
b["H_top"] = spans[ii]["top"]
b["H_bott"] = spans[ii]["bottom"]
@@ -542,7 +250,7 @@ class HuParser:
self.boxes.append([])
return
bxs = [(line[0], line[1][0]) for line in bxs]
- bxs = self.sort_Y_firstly(
+ bxs = Recognizer.sort_Y_firstly(
[{"x0": b[0][0] / ZM, "x1": b[1][0] / ZM,
"top": b[0][1] / ZM, "text": "", "txt": t,
"bottom": b[-1][1] / ZM,
@@ -551,8 +259,8 @@ class HuParser:
)
# merge chars in the same rect
- for c in self.sort_X_firstly(chars, self.mean_width[pagenum - 1] // 4):
- ii = self.__find_overlapped(c, bxs)
+ for c in Recognizer.sort_X_firstly(chars, self.mean_width[pagenum - 1] // 4):
+ ii = Recognizer.find_overlapped(c, bxs)
if ii is None:
self.lefted_chars.append(c)
continue
@@ -573,91 +281,11 @@ class HuParser:
if self.mean_height[-1] == 0:
self.mean_height[-1] = np.median([b["bottom"] - b["top"]
for b in bxs])
-
self.boxes.append(bxs)
def _layouts_rec(self, ZM):
assert len(self.page_images) == len(self.boxes)
- # Tag layout type
- boxes = []
- layouts = self.layouter(self.page_images)
- #save_results(self.page_images, layouts, self.layout_labels, output_dir='output/', threshold=0.7)
- assert len(self.page_images) == len(layouts)
- for pn, lts in enumerate(layouts):
- bxs = self.boxes[pn]
- lts = [{"type": b["type"],
- "score": float(b["score"]),
- "x0": b["bbox"][0] / ZM, "x1": b["bbox"][2] / ZM,
- "top": b["bbox"][1] / ZM, "bottom": b["bbox"][-1] / ZM,
- "page_number": pn,
- } for b in lts]
- lts = self.sort_Y_firstly(lts, self.mean_height[pn] / 2)
- lts = self.__layouts_cleanup(bxs, lts)
- self.page_layout.append(lts)
-
- # Tag layout type, layouts are ready
- def findLayout(ty):
- nonlocal bxs, lts
- lts_ = [lt for lt in lts if lt["type"] == ty]
- i = 0
- while i < len(bxs):
- if bxs[i].get("layout_type"):
- i += 1
- continue
- if self._is_garbage(bxs[i]):
- logging.debug("GARBAGE: " + bxs[i]["text"])
- bxs.pop(i)
- continue
-
- ii = self.__find_overlapped_with_threashold(bxs[i], lts_,
- thr=0.4)
- if ii is None: # belong to nothing
- bxs[i]["layout_type"] = ""
- i += 1
- continue
- lts_[ii]["visited"] = True
- if lts_[ii]["type"] in ["footer", "header", "reference"]:
- if lts_[ii]["type"] not in self.garbages:
- self.garbages[lts_[ii]["type"]] = []
- self.garbages[lts_[ii]["type"]].append(bxs[i]["text"])
- logging.debug("GARBAGE: " + bxs[i]["text"])
- bxs.pop(i)
- continue
-
- bxs[i]["layoutno"] = f"{ty}-{ii}"
- bxs[i]["layout_type"] = lts_[ii]["type"]
- i += 1
-
- for lt in ["footer", "header", "reference", "figure caption",
- "table caption", "title", "text", "table", "figure"]:
- findLayout(lt)
-
- # add box to figure layouts which have no text box
- for i, lt in enumerate(
- [lt for lt in lts if lt["type"] == "figure"]):
- if lt.get("visited"):
- continue
- lt = deepcopy(lt)
- del lt["type"]
- lt["text"] = ""
- lt["layout_type"] = "figure"
- lt["layoutno"] = f"figure-{i}"
- bxs.append(lt)
-
- boxes.extend(bxs)
-
- self.boxes = boxes
-
- garbage = set()
- for k in self.garbages.keys():
- self.garbages[k] = Counter(self.garbages[k])
- for g, c in self.garbages[k].items():
- if c > 1:
- garbage.add(g)
-
- logging.debug("GARBAGE:" + ",".join(garbage))
- self.boxes = [b for b in self.boxes if b["text"].strip() not in garbage]
-
+ self.boxes, self.page_layout = self.layouter(self.page_images, self.boxes, ZM)
# cumulative Y
for i in range(len(self.boxes)):
self.boxes[i]["top"] += \
@@ -710,7 +338,7 @@ class HuParser:
self.boxes = bxs
def _naive_vertical_merge(self):
- bxs = self.sort_Y_firstly(self.boxes, np.median(self.mean_height) / 3)
+ bxs = Recognizer.sort_Y_firstly(self.boxes, np.median(self.mean_height) / 3)
i = 0
while i + 1 < len(bxs):
b = bxs[i]
@@ -850,7 +478,7 @@ class HuParser:
t["layout_type"] = c["layout_type"]
boxes.append(t)
- self.boxes = self.sort_Y_firstly(boxes, 0)
+ self.boxes = Recognizer.sort_Y_firstly(boxes, 0)
def _filter_forpages(self):
if not self.boxes:
@@ -916,492 +544,6 @@ class HuParser:
b_["top"] = b["top"]
self.boxes.pop(i)
- def _blockType(self, b):
- patt = [
- ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
- (r"^(20|19)[0-9]{2}年$", "Dt"),
- (r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
- ("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
- (r"^第*[一二三四1-4]季度$", "Dt"),
- (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
- (r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
- ("^[0-9.,+%/ -]+$", "Nu"),
- (r"^[0-9A-Z/\._~-]+$", "Ca"),
- (r"^[A-Z]*[a-z' -]+$", "En"),
- (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
- (r"^.{1}$", "Sg")
- ]
- for p, n in patt:
- if re.search(p, b["text"].strip()):
- return n
- tks = [t for t in huqie.qie(b["text"]).split(" ") if len(t) > 1]
- if len(tks) > 3:
- if len(tks) < 12:
- return "Tx"
- else:
- return "Lx"
-
- if len(tks) == 1 and huqie.tag(tks[0]) == "nr":
- return "Nr"
-
- return "Ot"
-
- def __cal_spans(self, boxes, rows, cols, tbl, html=True):
- # calculate span
- clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
- for cln in cols]
- crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
- for cln in cols]
- rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
- for row in rows]
- rbtm = [np.mean([c.get("R_btm", c["bottom"])
- for c in row]) for row in rows]
- for b in boxes:
- if "SP" not in b:
- continue
- b["colspan"] = [b["cn"]]
- b["rowspan"] = [b["rn"]]
- # col span
- for j in range(0, len(clft)):
- if j == b["cn"]:
- continue
- if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
- continue
- if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
- continue
- b["colspan"].append(j)
- # row span
- for j in range(0, len(rtop)):
- if j == b["rn"]:
- continue
- if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
- continue
- if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
- continue
- b["rowspan"].append(j)
-
- def join(arr):
- if not arr:
- return ""
- return "".join([t["text"] for t in arr])
-
- # rm the spaning cells
- for i in range(len(tbl)):
- for j, arr in enumerate(tbl[i]):
- if not arr:
- continue
- if all(["rowspan" not in a and "colspan" not in a for a in arr]):
- continue
- rowspan, colspan = [], []
- for a in arr:
- if isinstance(a.get("rowspan", 0), list):
- rowspan.extend(a["rowspan"])
- if isinstance(a.get("colspan", 0), list):
- colspan.extend(a["colspan"])
- rowspan, colspan = set(rowspan), set(colspan)
- if len(rowspan) < 2 and len(colspan) < 2:
- for a in arr:
- if "rowspan" in a:
- del a["rowspan"]
- if "colspan" in a:
- del a["colspan"]
- continue
- rowspan, colspan = sorted(rowspan), sorted(colspan)
- rowspan = list(range(rowspan[0], rowspan[-1] + 1))
- colspan = list(range(colspan[0], colspan[-1] + 1))
- assert i in rowspan, rowspan
- assert j in colspan, colspan
- arr = []
- for r in rowspan:
- for c in colspan:
- arr_txt = join(arr)
- if tbl[r][c] and join(tbl[r][c]) != arr_txt:
- arr.extend(tbl[r][c])
- tbl[r][c] = None if html else arr
- for a in arr:
- if len(rowspan) > 1:
- a["rowspan"] = len(rowspan)
- elif "rowspan" in a:
- del a["rowspan"]
- if len(colspan) > 1:
- a["colspan"] = len(colspan)
- elif "colspan" in a:
- del a["colspan"]
- tbl[rowspan[0]][colspan[0]] = arr
-
- return tbl
-
- def __construct_table(self, boxes, html=False):
- cap = ""
- i = 0
- while i < len(boxes):
- if self.is_caption(boxes[i]):
- cap += boxes[i]["text"]
- boxes.pop(i)
- i -= 1
- i += 1
-
- if not boxes:
- return []
- for b in boxes:
- b["btype"] = self._blockType(b)
- max_type = Counter([b["btype"] for b in boxes]).items()
- max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
- logging.debug("MAXTYPE: " + max_type)
-
- rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
- rowh = np.min(rowh) if rowh else 0
- # boxes = self.sort_Y_firstly(boxes, rowh/5)
- boxes = self.sort_R_firstly(boxes, rowh / 2)
- boxes[0]["rn"] = 0
- rows = [[boxes[0]]]
- btm = boxes[0]["bottom"]
- for b in boxes[1:]:
- b["rn"] = len(rows) - 1
- lst_r = rows[-1]
- if lst_r[-1].get("R", "") != b.get("R", "") \
- or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
- ): # new row
- btm = b["bottom"]
- b["rn"] += 1
- rows.append([b])
- continue
- btm = (btm + b["bottom"]) / 2.
- rows[-1].append(b)
-
- colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
- colwm = np.min(colwm) if colwm else 0
- crosspage = len(set([b["page_number"] for b in boxes])) > 1
- if crosspage:
- boxes = self.sort_X_firstly(boxes, colwm / 2, False)
- else:
- boxes = self.sort_C_firstly(boxes, colwm / 2)
- boxes[0]["cn"] = 0
- cols = [[boxes[0]]]
- right = boxes[0]["x1"]
- for b in boxes[1:]:
- b["cn"] = len(cols) - 1
- lst_c = cols[-1]
- if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][
- "page_number"]) \
- or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")): # new col
- right = b["x1"]
- b["cn"] += 1
- cols.append([b])
- continue
- right = (right + b["x1"]) / 2.
- cols[-1].append(b)
-
- tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
- for b in boxes:
- tbl[b["rn"]][b["cn"]].append(b)
-
- if len(rows) >= 4:
- # remove single in column
- j = 0
- while j < len(tbl[0]):
- e, ii = 0, 0
- for i in range(len(tbl)):
- if tbl[i][j]:
- e += 1
- ii = i
- if e > 1:
- break
- if e > 1:
- j += 1
- continue
- f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
- [j - 1][0].get("text")) or j == 0
- ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
- [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
- if f and ff:
- j += 1
- continue
- bx = tbl[ii][j][0]
- logging.debug("Relocate column single: " + bx["text"])
- # j column only has one value
- left, right = 100000, 100000
- if j > 0 and not f:
- for i in range(len(tbl)):
- if tbl[i][j - 1]:
- left = min(left, np.min(
- [bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
- if j + 1 < len(tbl[0]) and not ff:
- for i in range(len(tbl)):
- if tbl[i][j + 1]:
- right = min(right, np.min(
- [a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
- assert left < 100000 or right < 100000
- if left < right:
- for jj in range(j, len(tbl[0])):
- for i in range(len(tbl)):
- for a in tbl[i][jj]:
- a["cn"] -= 1
- if tbl[ii][j - 1]:
- tbl[ii][j - 1].extend(tbl[ii][j])
- else:
- tbl[ii][j - 1] = tbl[ii][j]
- for i in range(len(tbl)):
- tbl[i].pop(j)
-
- else:
- for jj in range(j + 1, len(tbl[0])):
- for i in range(len(tbl)):
- for a in tbl[i][jj]:
- a["cn"] -= 1
- if tbl[ii][j + 1]:
- tbl[ii][j + 1].extend(tbl[ii][j])
- else:
- tbl[ii][j + 1] = tbl[ii][j]
- for i in range(len(tbl)):
- tbl[i].pop(j)
- cols.pop(j)
- assert len(cols) == len(tbl[0]), "Column number mismatched: %d vs %d" % (
- len(cols), len(tbl[0]))
-
- if len(cols) >= 4:
- # remove single in row
- i = 0
- while i < len(tbl):
- e, jj = 0, 0
- for j in range(len(tbl[i])):
- if tbl[i][j]:
- e += 1
- jj = j
- if e > 1:
- break
- if e > 1:
- i += 1
- continue
- f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
- [jj][0].get("text")) or i == 0
- ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
- [jj][0].get("text")) or i + 1 >= len(tbl)
- if f and ff:
- i += 1
- continue
-
- bx = tbl[i][jj][0]
- logging.debug("Relocate row single: " + bx["text"])
- # i row only has one value
- up, down = 100000, 100000
- if i > 0 and not f:
- for j in range(len(tbl[i - 1])):
- if tbl[i - 1][j]:
- up = min(up, np.min(
- [bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
- if i + 1 < len(tbl) and not ff:
- for j in range(len(tbl[i + 1])):
- if tbl[i + 1][j]:
- down = min(down, np.min(
- [a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
- assert up < 100000 or down < 100000
- if up < down:
- for ii in range(i, len(tbl)):
- for j in range(len(tbl[ii])):
- for a in tbl[ii][j]:
- a["rn"] -= 1
- if tbl[i - 1][jj]:
- tbl[i - 1][jj].extend(tbl[i][jj])
- else:
- tbl[i - 1][jj] = tbl[i][jj]
- tbl.pop(i)
-
- else:
- for ii in range(i + 1, len(tbl)):
- for j in range(len(tbl[ii])):
- for a in tbl[ii][j]:
- a["rn"] -= 1
- if tbl[i + 1][jj]:
- tbl[i + 1][jj].extend(tbl[i][jj])
- else:
- tbl[i + 1][jj] = tbl[i][jj]
- tbl.pop(i)
- rows.pop(i)
-
- # which rows are headers
- hdset = set([])
- for i in range(len(tbl)):
- cnt, h = 0, 0
- for j, arr in enumerate(tbl[i]):
- if not arr:
- continue
- cnt += 1
- if max_type == "Nu" and arr[0]["btype"] == "Nu":
- continue
- if any([a.get("H") for a in arr]) \
- or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
- h += 1
- if h / cnt > 0.5:
- hdset.add(i)
-
- if html:
- return [self.__html_table(cap, hdset,
- self.__cal_spans(boxes, rows,
- cols, tbl, True)
- )]
-
- return self.__desc_table(cap, hdset,
- self.__cal_spans(boxes, rows, cols, tbl, False))
-
- def __html_table(self, cap, hdset, tbl):
- # construct HTML
- html = "<table>"
- if cap:
- html += f"{cap}"
- for i in range(len(tbl)):
- row = ""
- txts = []
- for j, arr in enumerate(tbl[i]):
- if arr is None:
- continue
- if not arr:
- row += " | " if i not in hdset else " | "
- continue
- txt = ""
- if arr:
- h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2,
- self.mean_height[arr[0]["page_number"] - 1] / 2)
- txt = "".join([c["text"]
- for c in self.sort_Y_firstly(arr, h)])
- txts.append(txt)
- sp = ""
- if arr[0].get("colspan"):
- sp = "colspan={}".format(arr[0]["colspan"])
- if arr[0].get("rowspan"):
- sp += " rowspan={}".format(arr[0]["rowspan"])
- if i in hdset:
- row += f"" + txt + " | "
- else:
- row += f"" + txt + " | "
-
- if i in hdset:
- if all([t in hdset for t in txts]):
- continue
- for t in txts:
- hdset.add(t)
-
- if row != "<tr>":
- row += "</tr>"
- else:
- row = ""
- html += "\n" + row
- html += "\n
"
- return html
-
- def __desc_table(self, cap, hdr_rowno, tbl):
- # get text of every column in header row to become header text
- clmno = len(tbl[0])
- rowno = len(tbl)
- headers = {}
- hdrset = set()
- lst_hdr = []
- de = "的" if not self.is_english else " for "
- for r in sorted(list(hdr_rowno)):
- headers[r] = ["" for _ in range(clmno)]
- for i in range(clmno):
- if not tbl[r][i]:
- continue
- txt = "".join([a["text"].strip() for a in tbl[r][i]])
- headers[r][i] = txt
- hdrset.add(txt)
- if all([not t for t in headers[r]]):
- del headers[r]
- hdr_rowno.remove(r)
- continue
- for j in range(clmno):
- if headers[r][j]:
- continue
- if j >= len(lst_hdr):
- break
- headers[r][j] = lst_hdr[j]
- lst_hdr = headers[r]
- for i in range(rowno):
- if i not in hdr_rowno:
- continue
- for j in range(i + 1, rowno):
- if j not in hdr_rowno:
- break
- for k in range(clmno):
- if not headers[j - 1][k]:
- continue
- if headers[j][k].find(headers[j - 1][k]) >= 0:
- continue
- if len(headers[j][k]) > len(headers[j - 1][k]):
- headers[j][k] += (de if headers[j][k]
- else "") + headers[j - 1][k]
- else:
- headers[j][k] = headers[j - 1][k] \
- + (de if headers[j - 1][k] else "") \
- + headers[j][k]
-
- logging.debug(
- f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
- row_txt = []
- for i in range(rowno):
- if i in hdr_rowno:
- continue
- rtxt = []
-
- def append(delimer):
- nonlocal rtxt, row_txt
- rtxt = delimer.join(rtxt)
- if row_txt and len(row_txt[-1]) + len(rtxt) < 64:
- row_txt[-1] += "\n" + rtxt
- else:
- row_txt.append(rtxt)
-
- r = 0
- if len(headers.items()):
- _arr = [(i - r, r) for r, _ in headers.items() if r < i]
- if _arr:
- _, r = min(_arr, key=lambda x: x[0])
-
- if r not in headers and clmno <= 2:
- for j in range(clmno):
- if not tbl[i][j]:
- continue
- txt = "".join([a["text"].strip() for a in tbl[i][j]])
- if txt:
- rtxt.append(txt)
- if rtxt:
- append(":")
- continue
-
- for j in range(clmno):
- if not tbl[i][j]:
- continue
- txt = "".join([a["text"].strip() for a in tbl[i][j]])
- if not txt:
- continue
- ctt = headers[r][j] if r in headers else ""
- if ctt:
- ctt += ":"
- ctt += txt
- if ctt:
- rtxt.append(ctt)
-
- if rtxt:
- row_txt.append("; ".join(rtxt))
-
- if cap:
- if self.is_english:
- from_ = " in "
- else:
- from_ = "来自"
- row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
- return row_txt
-
- @staticmethod
- def is_caption(bx):
- patt = [
- r"[图表]+[ 0-9::]{2,}"
- ]
- if any([re.match(p, bx["text"].strip()) for p in patt]) \
- or bx["layout_type"].find("caption") >= 0:
- return True
- return False
-
def _extract_table_figure(self, need_image, ZM, return_html):
tables = {}
figures = {}
@@ -1415,7 +557,7 @@ class HuParser:
continue
lout_no = str(self.boxes[i]["page_number"]) + \
"-" + str(self.boxes[i]["layoutno"])
- if self.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", "title",
+ if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", "title",
"figure caption", "reference"]:
nomerge_lout_no.append(lst_lout_no)
if self.boxes[i]["layout_type"] == "table":
@@ -1470,7 +612,7 @@ class HuParser:
while i < len(self.boxes):
c = self.boxes[i]
# mh = self.mean_height[c["page_number"]-1]
- if not self.is_caption(c):
+ if not TableStructureRecognizer.is_caption(c):
i += 1
continue
@@ -1529,7 +671,7 @@ class HuParser:
"bottom": np.max([b["bottom"] for b in bxs]) - ht
}
louts = [l for l in self.page_layout[pn] if l["type"] == ltype]
- ii = self.__find_overlapped(b, louts, naive=True)
+ ii = Recognizer.find_overlapped(b, louts, naive=True)
if ii is not None:
b = louts[ii]
else:
@@ -1581,7 +723,7 @@ class HuParser:
if not bxs:
continue
res.append((cropout(bxs, "table"),
- self.__construct_table(bxs, html=return_html)))
+ self.tbl_det.construct_table(bxs, html=return_html, is_english=self.is_english)))
return res
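Net effect of the `pdf_parser.py` changes: layout labels, TSR labels, the remote-call fallback, the sorting helpers, the overlap math, and table construction all move out of `HuParser` into the new `deepdoc.vision` package, leaving the parser with thin delegating calls. A hedged sketch of the new wiring, assuming `ParserType.GENERAL.value` is `"general"` (models download on first instantiation):

```python
# Sketch of HuParser's new collaborators per this diff.
from deepdoc.vision import OCR, LayoutRecognizer, TableStructureRecognizer

ocr = OCR()
layouter = LayoutRecognizer("layout.general")  # "layout." + model_speciess
tbl_det = TableStructureRecognizer()

# _layouts_rec collapses to:
#   self.boxes, self.page_layout = layouter(self.page_images, self.boxes, ZM)
# and table building becomes:
#   tbl_det.construct_table(bxs, html=return_html, is_english=self.is_english)
```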
diff --git a/deepdoc/vision/__init__.py b/deepdoc/vision/__init__.py
new file mode 100644
index 000000000..93eea132c
--- /dev/null
+++ b/deepdoc/vision/__init__.py
@@ -0,0 +1,4 @@
+from .ocr import OCR
+from .recognizer import Recognizer
+from .layout_recognizer import LayoutRecognizer
+from .table_structure_recognizer import TableStructureRecognizer
diff --git a/deepdoc/vision/layout_recognizer.py b/deepdoc/vision/layout_recognizer.py
new file mode 100644
index 000000000..1a5795a2e
--- /dev/null
+++ b/deepdoc/vision/layout_recognizer.py
@@ -0,0 +1,119 @@
+import os
+import re
+from collections import Counter
+from copy import deepcopy
+
+import numpy as np
+
+from api.utils.file_utils import get_project_base_directory
+from .recognizer import Recognizer
+
+
+class LayoutRecognizer(Recognizer):
+ def __init__(self, domain):
+ self.layout_labels = [
+ "_background_",
+ "Text",
+ "Title",
+ "Figure",
+ "Figure caption",
+ "Table",
+ "Table caption",
+ "Header",
+ "Footer",
+ "Reference",
+ "Equation",
+ ]
+ super().__init__(self.layout_labels, domain,
+ os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
+
+ def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
+ def __is_garbage(b):
+ patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
+ r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
+ "(资料|数据)来源[::]", "[0-9a-z._-]+@[a-z0-9-]+\\.[a-z]{2,3}",
+ "\\(cid *: *[0-9]+ *\\)"
+ ]
+ return any([re.search(p, b["text"]) for p in patt])
+
+ layouts = super().__call__(image_list, thr, batch_size)
+ # save_results(image_list, layouts, self.layout_labels, output_dir='output/', threshold=0.7)
+ assert len(image_list) == len(ocr_res)
+ # Tag layout type
+ boxes = []
+ assert len(image_list) == len(layouts)
+ garbages = {}
+ page_layout = []
+ for pn, lts in enumerate(layouts):
+ bxs = ocr_res[pn]
+ lts = [{"type": b["type"],
+ "score": float(b["score"]),
+ "x0": b["bbox"][0] / scale_factor, "x1": b["bbox"][2] / scale_factor,
+ "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
+ "page_number": pn,
+ } for b in lts]
+ lts = self.sort_Y_firstly(lts, np.mean([l["bottom"]-l["top"] for l in lts]) / 2)
+ lts = self.layouts_cleanup(bxs, lts)
+ page_layout.append(lts)
+
+ # Tag layout type, layouts are ready
+ def findLayout(ty):
+ nonlocal bxs, lts, self
+ lts_ = [lt for lt in lts if lt["type"] == ty]
+ i = 0
+ while i < len(bxs):
+ if bxs[i].get("layout_type"):
+ i += 1
+ continue
+ if __is_garbage(bxs[i]):
+ bxs.pop(i)
+ continue
+
+ ii = self.find_overlapped_with_threashold(bxs[i], lts_,
+ thr=0.4)
+ if ii is None: # belong to nothing
+ bxs[i]["layout_type"] = ""
+ i += 1
+ continue
+ lts_[ii]["visited"] = True
+ if lts_[ii]["type"] in ["footer", "header", "reference"]:
+ if lts_[ii]["type"] not in garbages:
+ garbages[lts_[ii]["type"]] = []
+ garbages[lts_[ii]["type"]].append(bxs[i]["text"])
+ bxs.pop(i)
+ continue
+
+ bxs[i]["layoutno"] = f"{ty}-{ii}"
+ bxs[i]["layout_type"] = lts_[ii]["type"]
+ i += 1
+
+ for lt in ["footer", "header", "reference", "figure caption",
+ "table caption", "title", "text", "table", "figure", "equation"]:
+ findLayout(lt)
+
+ # add box to figure layouts which have no text box
+ for i, lt in enumerate(
+ [lt for lt in lts if lt["type"] == "figure"]):
+ if lt.get("visited"):
+ continue
+ lt = deepcopy(lt)
+ del lt["type"]
+ lt["text"] = ""
+ lt["layout_type"] = "figure"
+ lt["layoutno"] = f"figure-{i}"
+ bxs.append(lt)
+
+ boxes.extend(bxs)
+
+ ocr_res = boxes
+
+ garbag_set = set()
+ for k in garbages.keys():
+ garbages[k] = Counter(garbages[k])
+ for g, c in garbages[k].items():
+ if c > 1:
+ garbag_set.add(g)
+
+ ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
+ return ocr_res, page_layout
+
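`LayoutRecognizer.__call__` pairs each page image with its OCR boxes, tags every box with a `layout_type`/`layoutno`, drops per-page garbage (dot leaders, page counters, mail addresses) plus any header/footer/reference text repeated across pages, and returns the cleaned boxes with the per-page layouts. A hedged usage sketch; `page_images` and `ocr_boxes` are assumed inputs carrying the fields the code reads:

```python
# Hedged usage sketch: page_images is a list of PIL pages, ocr_boxes a
# parallel list of box dicts with "text", "x0", "x1", "top", "bottom".
from deepdoc.vision import LayoutRecognizer

layouter = LayoutRecognizer("layout.paper")  # domain suffix is an assumption
boxes, page_layout = layouter(page_images, ocr_boxes, scale_factor=3)
for b in boxes[:5]:
    print(b.get("layout_type"), b.get("layoutno"), b["text"][:40])
```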
diff --git a/deepdoc/visual/ocr.py b/deepdoc/vision/ocr.py
similarity index 99%
rename from deepdoc/visual/ocr.py
rename to deepdoc/vision/ocr.py
index 65b2c2ddf..09c8a7a91 100644
--- a/deepdoc/visual/ocr.py
+++ b/deepdoc/vision/ocr.py
@@ -74,7 +74,7 @@ class TextRecognizer(object):
self.rec_batch_num = 16
postprocess_params = {
'name': 'CTCLabelDecode',
- "character_dict_path": os.path.join(get_project_base_directory(), "rag/res", "ocr.res"),
+ "character_dict_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "ocr.res"),
"use_space_char": True
}
self.postprocess_op = build_post_process(postprocess_params)
@@ -450,7 +450,7 @@ class OCR(object):
"""
if not model_dir:
- model_dir = snapshot_download(repo_id="InfiniFlow/ocr")
+ model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
self.text_detector = TextDetector(model_dir)
self.text_recognizer = TextRecognizer(model_dir)
diff --git a/deepdoc/visual/ocr.res b/deepdoc/vision/ocr.res
similarity index 100%
rename from deepdoc/visual/ocr.res
rename to deepdoc/vision/ocr.res
diff --git a/deepdoc/visual/operators.py b/deepdoc/vision/operators.py
similarity index 100%
rename from deepdoc/visual/operators.py
rename to deepdoc/vision/operators.py
diff --git a/deepdoc/visual/postprocess.py b/deepdoc/vision/postprocess.py
similarity index 100%
rename from deepdoc/visual/postprocess.py
rename to deepdoc/vision/postprocess.py
diff --git a/deepdoc/vision/recognizer.py b/deepdoc/vision/recognizer.py
new file mode 100644
index 000000000..932923402
--- /dev/null
+++ b/deepdoc/vision/recognizer.py
@@ -0,0 +1,327 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from copy import deepcopy
+
+import onnxruntime as ort
+from huggingface_hub import snapshot_download
+
+from . import seeit
+from .operators import *
+from rag.settings import cron_logger
+
+
+class Recognizer(object):
+ def __init__(self, label_list, task_name, model_dir=None):
+ """
+ If you have trouble downloading HuggingFace models, -_^ this might help!!
+
+ For Linux:
+ export HF_ENDPOINT=https://hf-mirror.com
+
+ For Windows:
+ Good luck
+ ^_-
+
+ """
+ if not model_dir:
+ model_dir = snapshot_download(repo_id="InfiniFlow/ocr")
+
+ model_file_path = os.path.join(model_dir, task_name + ".onnx")
+ if not os.path.exists(model_file_path):
+ raise ValueError("not find model file path {}".format(
+ model_file_path))
+ if ort.get_device() == "GPU":
+ self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
+ else:
+ self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
+ self.label_list = label_list
+
+ @staticmethod
+ def sort_Y_firstly(arr, threashold):
+ # sort by top (y) first and then by x0
+ arr = sorted(arr, key=lambda r: (r["top"], r["x0"]))
+ for i in range(len(arr) - 1):
+ for j in range(i, -1, -1):
+ # tops within the threshold count as one line; restore x0 order there
+ if abs(arr[j + 1]["top"] - arr[j]["top"]) < threashold \
+ and arr[j + 1]["x0"] < arr[j]["x0"]:
+ tmp = deepcopy(arr[j])
+ arr[j] = deepcopy(arr[j + 1])
+ arr[j + 1] = deepcopy(tmp)
+ return arr
+
+ @staticmethod
+ def sort_X_firstly(arr, threashold, copy=True):
+ # sort by x0 first and then by top
+ arr = sorted(arr, key=lambda r: (r["x0"], r["top"]))
+ for i in range(len(arr) - 1):
+ for j in range(i, -1, -1):
+ # x0 within the threshold counts as one column; restore top order there
+ if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threashold \
+ and arr[j + 1]["top"] < arr[j]["top"]:
+ tmp = deepcopy(arr[j]) if copy else arr[j]
+ arr[j] = deepcopy(arr[j + 1]) if copy else arr[j + 1]
+ arr[j + 1] = deepcopy(tmp) if copy else tmp
+ return arr
+
+ @staticmethod
+ def sort_C_firstly(arr, thr=0):
+ # sort by column tag (C) first and then by top
+ # sorted(arr, key=lambda r: (r["x0"], r["top"]))
+ arr = Recognizer.sort_X_firstly(arr, thr)
+ for i in range(len(arr) - 1):
+ for j in range(i, -1, -1):
+ # bubble into (C, top) order for boxes that carry a column tag
+ if "C" not in arr[j] or "C" not in arr[j + 1]:
+ continue
+ if arr[j + 1]["C"] < arr[j]["C"] \
+ or (
+ arr[j + 1]["C"] == arr[j]["C"]
+ and arr[j + 1]["top"] < arr[j]["top"]
+ ):
+ tmp = arr[j]
+ arr[j] = arr[j + 1]
+ arr[j + 1] = tmp
+ return arr
+
+ # unreachable alternative: sorted(arr, key=lambda r: (r.get("C", r["x0"]), r["top"]))
+
+ @staticmethod
+ def sort_R_firstly(arr, thr=0):
+ # sort by row tag (R) first and then by x0
+ # sorted(arr, key=lambda r: (r["top"], r["x0"]))
+ arr = Recognizer.sort_Y_firstly(arr, thr)
+ for i in range(len(arr) - 1):
+ for j in range(i, -1, -1):
+ if "R" not in arr[j] or "R" not in arr[j + 1]:
+ continue
+ if arr[j + 1]["R"] < arr[j]["R"] \
+ or (
+ arr[j + 1]["R"] == arr[j]["R"]
+ and arr[j + 1]["x0"] < arr[j]["x0"]
+ ):
+ tmp = arr[j]
+ arr[j] = arr[j + 1]
+ arr[j + 1] = tmp
+ return arr
+
+ @staticmethod
+ def overlapped_area(a, b, ratio=True):
+ tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"]
+ if b["x0"] > x1 or b["x1"] < x0:
+ return 0
+ if b["bottom"] < tp or b["top"] > btm:
+ return 0
+ x0_ = max(b["x0"], x0)
+ x1_ = min(b["x1"], x1)
+ assert x0_ <= x1_, "Fuckedup! T:{},B:{},X0:{},X1:{} ==> {}".format(
+ tp, btm, x0, x1, b)
+ tp_ = max(b["top"], tp)
+ btm_ = min(b["bottom"], btm)
+ assert tp_ <= btm_, "Fuckedup! T:{},B:{},X0:{},X1:{} => {}".format(
+ tp, btm, x0, x1, b)
+ ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
+ x0 != 0 and btm - tp != 0 else 0
+ if ov > 0 and ratio:
+ ov /= (x1 - x0) * (btm - tp)
+ return ov
+
+ @staticmethod
+ def layouts_cleanup(boxes, layouts, far=2, thr=0.7):
+ def notOverlapped(a, b):
+ return any([a["x1"] < b["x0"],
+ a["x0"] > b["x1"],
+ a["bottom"] < b["top"],
+ a["top"] > b["bottom"]])
+
+ i = 0
+ while i + 1 < len(layouts):
+ j = i + 1
+ while j < min(i + far, len(layouts)) \
+ and (layouts[i].get("type", "") != layouts[j].get("type", "")
+ or notOverlapped(layouts[i], layouts[j])):
+ j += 1
+ if j >= min(i + far, len(layouts)):
+ i += 1
+ continue
+ if Recognizer.overlapped_area(layouts[i], layouts[j]) < thr \
+ and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr:
+ i += 1
+ continue
+
+ if layouts[i].get("score") and layouts[j].get("score"):
+ if layouts[i]["score"] > layouts[j]["score"]:
+ layouts.pop(j)
+ else:
+ layouts.pop(i)
+ continue
+
+ area_i, area_i_1 = 0, 0
+ for b in boxes:
+ if not notOverlapped(b, layouts[i]):
+ area_i += Recognizer.overlapped_area(b, layouts[i], False)
+ if not notOverlapped(b, layouts[j]):
+ area_i_1 += Recognizer.overlapped_area(b, layouts[j], False)
+
+ if area_i > area_i_1:
+ layouts.pop(j)
+ else:
+ layouts.pop(i)
+
+ return layouts
+
+ def create_inputs(self, imgs, im_info):
+ """generate input for different model type
+ Args:
+ imgs (list(numpy)): list of images (np.ndarray)
+ im_info (list(dict)): list of image info
+ Returns:
+ inputs (dict): input of model
+ """
+ inputs = {}
+
+ im_shape = []
+ scale_factor = []
+ if len(imgs) == 1:
+ inputs['image'] = np.array((imgs[0],)).astype('float32')
+ inputs['im_shape'] = np.array(
+ (im_info[0]['im_shape'],)).astype('float32')
+ inputs['scale_factor'] = np.array(
+ (im_info[0]['scale_factor'],)).astype('float32')
+ return inputs
+
+ for e in im_info:
+ im_shape.append(np.array((e['im_shape'],)).astype('float32'))
+ scale_factor.append(np.array((e['scale_factor'],)).astype('float32'))
+
+ inputs['im_shape'] = np.concatenate(im_shape, axis=0)
+ inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)
+
+ imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
+ max_shape_h = max([e[0] for e in imgs_shape])
+ max_shape_w = max([e[1] for e in imgs_shape])
+ padding_imgs = []
+ for img in imgs:
+ im_c, im_h, im_w = img.shape[:]
+ padding_im = np.zeros(
+ (im_c, max_shape_h, max_shape_w), dtype=np.float32)
+ padding_im[:, :im_h, :im_w] = img
+ padding_imgs.append(padding_im)
+ inputs['image'] = np.stack(padding_imgs, axis=0)
+ return inputs
+
+ @staticmethod
+ def find_overlapped(box, boxes_sorted_by_y, naive=False):
+ if not boxes_sorted_by_y:
+ return
+ bxs = boxes_sorted_by_y
+ s, e, ii = 0, len(bxs), 0
+ while s < e and not naive:
+ ii = (e + s) // 2
+ pv = bxs[ii]
+ if box["bottom"] < pv["top"]:
+ e = ii
+ continue
+ if box["top"] > pv["bottom"]:
+ s = ii + 1
+ continue
+ break
+ while s < ii:
+ if box["top"] > bxs[s]["bottom"]:
+ s += 1
+ break
+ while e - 1 > ii:
+ if box["bottom"] < bxs[e - 1]["top"]:
+ e -= 1
+ break
+
+ max_overlaped_i, max_overlaped = None, 0
+ for i in range(s, e):
+ ov = Recognizer.overlapped_area(bxs[i], box)
+ if ov <= max_overlaped:
+ continue
+ max_overlaped_i = i
+ max_overlaped = ov
+
+ return max_overlaped_i
+
+ @staticmethod
+ def find_overlapped_with_threashold(box, boxes, thr=0.3):
+ if not boxes:
+ return
+ max_overlaped_i, max_overlaped, _max_overlaped = None, thr, 0
+ s, e = 0, len(boxes)
+ for i in range(s, e):
+ ov = Recognizer.overlapped_area(box, boxes[i])
+ _ov = Recognizer.overlapped_area(boxes[i], box)
+ if (ov, _ov) < (max_overlaped, _max_overlaped):
+ continue
+ max_overlaped_i = i
+ max_overlaped = ov
+ _max_overlaped = _ov
+
+ return max_overlaped_i
+
+ def preprocess(self, image_list):
+ preprocess_ops = []
+ for op_info in [
+ {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
+ {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
+ {'type': 'Permute'},
+ {'stride': 32, 'type': 'PadStride'}
+ ]:
+ new_op_info = op_info.copy()
+ op_type = new_op_info.pop('type')
+ preprocess_ops.append(eval(op_type)(**new_op_info))
+
+ inputs = []
+ for im_path in image_list:
+ im, im_info = preprocess(im_path, preprocess_ops)
+ inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
+ return inputs
+
+ def __call__(self, image_list, thr=0.7, batch_size=16):
+ res = []
+ imgs = []
+ for i in range(len(image_list)):
+ if not isinstance(image_list[i], np.ndarray):
+ imgs.append(np.array(image_list[i]))
+ else: imgs.append(image_list[i])
+
+ batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
+ for i in range(batch_loop_cnt):
+ start_index = i * batch_size
+ end_index = min((i + 1) * batch_size, len(imgs))
+ batch_image_list = imgs[start_index:end_index]
+ inputs = self.preprocess(batch_image_list)
+ for ins in inputs:
+ bb = []
+ for b in self.ort_sess.run(None, ins)[0]:
+ clsid, bbox, score = int(b[0]), b[2:], b[1]
+ if score < thr:
+ continue
+ if clsid >= len(self.label_list):
+ cron_logger.warning(f"bad category id: {clsid}")
+ continue
+ bb.append({
+ "type": self.label_list[clsid].lower(),
+ "bbox": [float(t) for t in bbox.tolist()],
+ "score": float(score)
+ })
+ res.append(bb)
+
+ #seeit.save_results(image_list, res, self.label_list, threshold=thr)
+
+ return res
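The sort helpers above are stable bubble passes: `sort_Y_firstly` treats boxes whose tops differ by less than the threshold as one visual line and restores left-to-right order within it, which is why callers derive the threshold from the mean line height. A toy demonstration with invented coordinates:

```python
# Toy demonstration of Recognizer.sort_Y_firstly; coordinates are invented.
from deepdoc.vision import Recognizer

boxes = [
    {"text": "right word", "top": 10.0, "x0": 50.0, "x1": 60.0, "bottom": 20.0},
    {"text": "left word",  "top": 10.4, "x0": 10.0, "x1": 20.0, "bottom": 20.0},
]
# A plain (top, x0) sort yields ["right word", "left word"]; the threshold
# pass sees the tops differ by only 0.4 (< 5) and restores reading order by x0.
print([b["text"] for b in Recognizer.sort_Y_firstly(boxes, 5)])
# -> ['left word', 'right word']
```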
diff --git a/deepdoc/visual/seeit.py b/deepdoc/vision/seeit.py
similarity index 100%
rename from deepdoc/visual/seeit.py
rename to deepdoc/vision/seeit.py
diff --git a/deepdoc/vision/table_structure_recognizer.py b/deepdoc/vision/table_structure_recognizer.py
new file mode 100644
index 000000000..40366b1d4
--- /dev/null
+++ b/deepdoc/vision/table_structure_recognizer.py
@@ -0,0 +1,556 @@
+import logging
+import os
+import re
+from collections import Counter
+from copy import deepcopy
+
+import numpy as np
+
+from api.utils.file_utils import get_project_base_directory
+from rag.nlp import huqie
+from .recognizer import Recognizer
+
+
+class TableStructureRecognizer(Recognizer):
+ def __init__(self):
+ self.labels = [
+ "table",
+ "table column",
+ "table row",
+ "table column header",
+ "table projected row header",
+ "table spanning cell",
+ ]
+ super().__init__(self.labels, "tsr",
+ os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
+
+ def __call__(self, images, thr=0.5):
+ tbls = super().__call__(images, thr)
+ res = []
+ # align left&right for rows, align top&bottom for columns
+ for tbl in tbls:
+ lts = [{"label": b["type"],
+ "score": b["score"],
+ "x0": b["bbox"][0], "x1": b["bbox"][2],
+ "top": b["bbox"][1], "bottom": b["bbox"][-1]
+ } for b in tbl]
+ if not lts:
+ continue
+
+ left = [b["x0"] for b in lts if b["label"].find(
+ "row") > 0 or b["label"].find("header") > 0]
+ right = [b["x1"] for b in lts if b["label"].find(
+ "row") > 0 or b["label"].find("header") > 0]
+ if not left:
+ continue
+ left = np.median(left) if len(left) > 4 else np.min(left)
+ right = np.median(right) if len(right) > 4 else np.max(right)
+ for b in lts:
+ if b["label"].find("row") > 0 or b["label"].find("header") > 0:
+ if b["x0"] > left:
+ b["x0"] = left
+ if b["x1"] < right:
+ b["x1"] = right
+
+ top = [b["top"] for b in lts if b["label"] == "table column"]
+ bottom = [b["bottom"] for b in lts if b["label"] == "table column"]
+ if not top:
+ res.append(lts)
+ continue
+ top = np.median(top) if len(top) > 4 else np.min(top)
+ bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom)
+ for b in lts:
+ if b["label"] == "table column":
+ if b["top"] > top:
+ b["top"] = top
+ if b["bottom"] < bottom:
+ b["bottom"] = bottom
+
+ res.append(lts)
+ return res
+
+ @staticmethod
+ def is_caption(bx):
+ patt = [
+ r"[图表]+[ 0-9::]{2,}"
+ ]
+ if any([re.match(p, bx["text"].strip()) for p in patt]) \
+ or bx["layout_type"].find("caption") >= 0:
+ return True
+ return False
+
+ def __blockType(self, b):
+ patt = [
+ ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
+ (r"^(20|19)[0-9]{2}年$", "Dt"),
+ (r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
+ ("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
+ (r"^第*[一二三四1-4]季度$", "Dt"),
+ (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
+ (r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
+ ("^[0-9.,+%/ -]+$", "Nu"),
+ (r"^[0-9A-Z/\._~-]+$", "Ca"),
+ (r"^[A-Z]*[a-z' -]+$", "En"),
+ (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
+ (r"^.{1}$", "Sg")
+ ]
+ for p, n in patt:
+ if re.search(p, b["text"].strip()):
+ return n
+ tks = [t for t in huqie.qie(b["text"]).split(" ") if len(t) > 1]
+ if len(tks) > 3:
+ if len(tks) < 12:
+ return "Tx"
+ else:
+ return "Lx"
+
+ if len(tks) == 1 and huqie.tag(tks[0]) == "nr":
+ return "Nr"
+
+ return "Ot"
+
+ def construct_table(self, boxes, is_english=False, html=False):
+ cap = ""
+ i = 0
+ while i < len(boxes):
+ if self.is_caption(boxes[i]):
+ cap += boxes[i]["text"]
+ boxes.pop(i)
+ i -= 1
+ i += 1
+
+ if not boxes:
+ return []
+ for b in boxes:
+ b["btype"] = self.__blockType(b)
+ max_type = Counter([b["btype"] for b in boxes]).items()
+ max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
+ logging.debug("MAXTYPE: " + max_type)
+
+ rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
+ rowh = np.min(rowh) if rowh else 0
+ boxes = self.sort_R_firstly(boxes, rowh / 2)
+ boxes[0]["rn"] = 0
+ rows = [[boxes[0]]]
+ btm = boxes[0]["bottom"]
+ for b in boxes[1:]:
+ b["rn"] = len(rows) - 1
+ lst_r = rows[-1]
+ if lst_r[-1].get("R", "") != b.get("R", "") \
+ or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
+ ): # new row
+ btm = b["bottom"]
+ b["rn"] += 1
+ rows.append([b])
+ continue
+ btm = (btm + b["bottom"]) / 2.
+ rows[-1].append(b)
+
+ colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
+ colwm = np.min(colwm) if colwm else 0
+ crosspage = len(set([b["page_number"] for b in boxes])) > 1
+ if crosspage:
+ boxes = self.sort_X_firstly(boxes, colwm / 2, False)
+ else:
+ boxes = self.sort_C_firstly(boxes, colwm / 2)
+ boxes[0]["cn"] = 0
+ cols = [[boxes[0]]]
+ right = boxes[0]["x1"]
+ for b in boxes[1:]:
+ b["cn"] = len(cols) - 1
+ lst_c = cols[-1]
+ if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][
+ "page_number"]) \
+ or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")): # new col
+ right = b["x1"]
+ b["cn"] += 1
+ cols.append([b])
+ continue
+ right = (right + b["x1"]) / 2.
+ cols[-1].append(b)
+
+ tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
+ for b in boxes:
+ tbl[b["rn"]][b["cn"]].append(b)
+
+ if len(rows) >= 4:
+ # remove single in column
+ j = 0
+ while j < len(tbl[0]):
+ e, ii = 0, 0
+ for i in range(len(tbl)):
+ if tbl[i][j]:
+ e += 1
+ ii = i
+ if e > 1:
+ break
+ if e > 1:
+ j += 1
+ continue
+ f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
+ [j - 1][0].get("text")) or j == 0
+ ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
+ [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
+ if f and ff:
+ j += 1
+ continue
+ bx = tbl[ii][j][0]
+ logging.debug("Relocate column single: " + bx["text"])
+ # j column only has one value
+ left, right = 100000, 100000
+ if j > 0 and not f:
+ for i in range(len(tbl)):
+ if tbl[i][j - 1]:
+ left = min(left, np.min(
+ [bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
+ if j + 1 < len(tbl[0]) and not ff:
+ for i in range(len(tbl)):
+ if tbl[i][j + 1]:
+ right = min(right, np.min(
+ [a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
+ assert left < 100000 or right < 100000
+ if left < right:
+ for jj in range(j, len(tbl[0])):
+ for i in range(len(tbl)):
+ for a in tbl[i][jj]:
+ a["cn"] -= 1
+ if tbl[ii][j - 1]:
+ tbl[ii][j - 1].extend(tbl[ii][j])
+ else:
+ tbl[ii][j - 1] = tbl[ii][j]
+ for i in range(len(tbl)):
+ tbl[i].pop(j)
+
+ else:
+ for jj in range(j + 1, len(tbl[0])):
+ for i in range(len(tbl)):
+ for a in tbl[i][jj]:
+ a["cn"] -= 1
+ if tbl[ii][j + 1]:
+ tbl[ii][j + 1].extend(tbl[ii][j])
+ else:
+ tbl[ii][j + 1] = tbl[ii][j]
+ for i in range(len(tbl)):
+ tbl[i].pop(j)
+ cols.pop(j)
+ assert len(cols) == len(tbl[0]), "Column number mismatched: %d vs %d" % (
+ len(cols), len(tbl[0]))
+
+ if len(cols) >= 4:
+ # remove single in row
+ i = 0
+ while i < len(tbl):
+ e, jj = 0, 0
+ for j in range(len(tbl[i])):
+ if tbl[i][j]:
+ e += 1
+ jj = j
+ if e > 1:
+ break
+ if e > 1:
+ i += 1
+ continue
+ f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
+ [jj][0].get("text")) or i == 0
+ ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
+ [jj][0].get("text")) or i + 1 >= len(tbl)
+ if f and ff:
+ i += 1
+ continue
+
+ bx = tbl[i][jj][0]
+ logging.debug("Relocate row single: " + bx["text"])
+ # i row only has one value
+ up, down = 100000, 100000
+ if i > 0 and not f:
+ for j in range(len(tbl[i - 1])):
+ if tbl[i - 1][j]:
+ up = min(up, np.min(
+ [bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
+ if i + 1 < len(tbl) and not ff:
+ for j in range(len(tbl[i + 1])):
+ if tbl[i + 1][j]:
+ down = min(down, np.min(
+ [a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
+ assert up < 100000 or down < 100000
+ if up < down:
+ for ii in range(i, len(tbl)):
+ for j in range(len(tbl[ii])):
+ for a in tbl[ii][j]:
+ a["rn"] -= 1
+ if tbl[i - 1][jj]:
+ tbl[i - 1][jj].extend(tbl[i][jj])
+ else:
+ tbl[i - 1][jj] = tbl[i][jj]
+ tbl.pop(i)
+
+ else:
+ for ii in range(i + 1, len(tbl)):
+ for j in range(len(tbl[ii])):
+ for a in tbl[ii][j]:
+ a["rn"] -= 1
+ if tbl[i + 1][jj]:
+ tbl[i + 1][jj].extend(tbl[i][jj])
+ else:
+ tbl[i + 1][jj] = tbl[i][jj]
+ tbl.pop(i)
+ rows.pop(i)
+
+ # which rows are headers
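+ # A row counts as a header when over half of its non-empty cells carry an "H" flag, or hold non-numeric text while the table's dominant cell type is numeric.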
+ hdset = set([])
+ for i in range(len(tbl)):
+ cnt, h = 0, 0
+ for j, arr in enumerate(tbl[i]):
+ if not arr:
+ continue
+ cnt += 1
+ if max_type == "Nu" and arr[0]["btype"] == "Nu":
+ continue
+ if any([a.get("H") for a in arr]) \
+ or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
+ h += 1
+ if h / cnt > 0.5:
+ hdset.add(i)
+
+ if html:
+ return [self.__html_table(cap, hdset,
+ self.__cal_spans(boxes, rows,
+ cols, tbl, True)
+ )]
+
+ return self.__desc_table(cap, hdset,
+ self.__cal_spans(boxes, rows, cols, tbl, False),
+ is_english)
+
+ def __html_table(self, cap, hdset, tbl):
+ # construct HTML
+ html = "<table>"
+ if cap:
+ html += f"<caption>{cap}</caption>"
+ for i in range(len(tbl)):
+ row = "<tr>"
+ txts = []
+ for j, arr in enumerate(tbl[i]):
+ if arr is None:
+ continue
+ if not arr:
+ row += "<td></td>" if i not in hdset else "<th></th>"
+ continue
+ txt = ""
+ if arr:
+ h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
+ txt = "".join([c["text"]
+ for c in self.sort_Y_firstly(arr, h)])
+ txts.append(txt)
+ sp = ""
+ if arr[0].get("colspan"):
+ sp = "colspan={}".format(arr[0]["colspan"])
+ if arr[0].get("rowspan"):
+ sp += " rowspan={}".format(arr[0]["rowspan"])
+ if i in hdset:
+ row += f"<th {sp}>" + txt + "</th>"
+ else:
+ row += f"<td {sp}>" + txt + "</td>"
+
+ if i in hdset:
+ if all([t in hdset for t in txts]):
+ continue
+ for t in txts:
+ hdset.add(t)
+
+ if row != "<tr>":
+ row += "</tr>"
+ else:
+ row = ""
+ html += "\n" + row
+ html += "\n</table>"
+ return html
+
+ def __desc_table(self, cap, hdr_rowno, tbl, is_english):
+ # gather the text of every column in the header rows to serve as header text
+ clmno = len(tbl[0])
+ rowno = len(tbl)
+ headers = {}
+ hdrset = set()
+ lst_hdr = []
+ de = "的" if not is_english else " for "
+ for r in sorted(list(hdr_rowno)):
+ headers[r] = ["" for _ in range(clmno)]
+ for i in range(clmno):
+ if not tbl[r][i]:
+ continue
+ txt = "".join([a["text"].strip() for a in tbl[r][i]])
+ headers[r][i] = txt
+ hdrset.add(txt)
+ if all([not t for t in headers[r]]):
+ del headers[r]
+ hdr_rowno.remove(r)
+ continue
+ for j in range(clmno):
+ if headers[r][j]:
+ continue
+ if j >= len(lst_hdr):
+ break
+ headers[r][j] = lst_hdr[j]
+ lst_hdr = headers[r]
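+ # Merge stacked header rows: concatenate each column's header cells with "的" (Chinese) or " for " (English) so every column ends up with one combined header string.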
+ for i in range(rowno):
+ if i not in hdr_rowno:
+ continue
+ for j in range(i + 1, rowno):
+ if j not in hdr_rowno:
+ break
+ for k in range(clmno):
+ if not headers[j - 1][k]:
+ continue
+ if headers[j][k].find(headers[j - 1][k]) >= 0:
+ continue
+ if len(headers[j][k]) > len(headers[j - 1][k]):
+ headers[j][k] += (de if headers[j][k]
+ else "") + headers[j - 1][k]
+ else:
+ headers[j][k] = headers[j - 1][k] \
+ + (de if headers[j - 1][k] else "") \
+ + headers[j][k]
+
+ logging.debug(
+ f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
+ row_txt = []
+ for i in range(rowno):
+ if i in hdr_rowno:
+ continue
+ rtxt = []
+
+ def append(delimiter):
+ nonlocal rtxt, row_txt
+ rtxt = delimiter.join(rtxt)
+ if row_txt and len(row_txt[-1]) + len(rtxt) < 64:
+ row_txt[-1] += "\n" + rtxt
+ else:
+ row_txt.append(rtxt)
+
+ r = 0
+ if headers:
+ _arr = [(i - r, r) for r, _ in headers.items() if r < i]
+ if _arr:
+ _, r = min(_arr, key=lambda x: x[0])
+
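+ # Narrow tables (<= 2 columns) with no applicable header degrade to plain "cell:cell" pairs.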
+ if r not in headers and clmno <= 2:
+ for j in range(clmno):
+ if not tbl[i][j]:
+ continue
+ txt = "".join([a["text"].strip() for a in tbl[i][j]])
+ if txt:
+ rtxt.append(txt)
+ if rtxt:
+ append(":")
+ continue
+
+ for j in range(clmno):
+ if not tbl[i][j]:
+ continue
+ txt = "".join([a["text"].strip() for a in tbl[i][j]])
+ if not txt:
+ continue
+ ctt = headers[r][j] if r in headers else ""
+ if ctt:
+ ctt += ":"
+ ctt += txt
+ if ctt:
+ rtxt.append(ctt)
+
+ if rtxt:
+ row_txt.append("; ".join(rtxt))
+
+ if cap:
+ if is_english:
+ from_ = " in "
+ else:
+ from_ = "来自"
+ row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
+ return row_txt
+
+ def __cal_spans(self, boxes, rows, cols, tbl, html=True):
+ # calculate row/column spans
+ clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
+ for cln in cols]
+ crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
+ for cln in cols]
+ rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
+ for row in rows]
+ rbtm = [np.mean([c.get("R_bott", c["bottom"])
+ for c in row]) for row in rows]
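+ # A box tagged "SP" spans every row/column whose center line falls inside the box's header extent (H_left..H_right, H_top..H_bott).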
+ for b in boxes:
+ if "SP" not in b:
+ continue
+ b["colspan"] = [b["cn"]]
+ b["rowspan"] = [b["rn"]]
+ # col span
+ for j in range(0, len(clft)):
+ if j == b["cn"]:
+ continue
+ if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
+ continue
+ if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
+ continue
+ b["colspan"].append(j)
+ # row span
+ for j in range(0, len(rtop)):
+ if j == b["rn"]:
+ continue
+ if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
+ continue
+ if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
+ continue
+ b["rowspan"].append(j)
+
+ def join(arr):
+ if not arr:
+ return ""
+ return "".join([t["text"] for t in arr])
+
+ # remove the spanned-over cells: merge each span's boxes into its anchor cell
+ for i in range(len(tbl)):
+ for j, arr in enumerate(tbl[i]):
+ if not arr:
+ continue
+ if all(["rowspan" not in a and "colspan" not in a for a in arr]):
+ continue
+ rowspan, colspan = [], []
+ for a in arr:
+ if isinstance(a.get("rowspan", 0), list):
+ rowspan.extend(a["rowspan"])
+ if isinstance(a.get("colspan", 0), list):
+ colspan.extend(a["colspan"])
+ rowspan, colspan = set(rowspan), set(colspan)
+ if len(rowspan) < 2 and len(colspan) < 2:
+ for a in arr:
+ if "rowspan" in a:
+ del a["rowspan"]
+ if "colspan" in a:
+ del a["colspan"]
+ continue
+ rowspan, colspan = sorted(rowspan), sorted(colspan)
+ rowspan = list(range(rowspan[0], rowspan[-1] + 1))
+ colspan = list(range(colspan[0], colspan[-1] + 1))
+ assert i in rowspan, rowspan
+ assert j in colspan, colspan
+ arr = []
+ for r in rowspan:
+ for c in colspan:
+ arr_txt = join(arr)
+ if tbl[r][c] and join(tbl[r][c]) != arr_txt:
+ arr.extend(tbl[r][c])
+ tbl[r][c] = None if html else arr
+ for a in arr:
+ if len(rowspan) > 1:
+ a["rowspan"] = len(rowspan)
+ elif "rowspan" in a:
+ del a["rowspan"]
+ if len(colspan) > 1:
+ a["colspan"] = len(colspan)
+ elif "colspan" in a:
+ del a["colspan"]
+ tbl[rowspan[0]][colspan[0]] = arr
+
+ return tbl
+
diff --git a/deepdoc/visual/__init__.py b/deepdoc/visual/__init__.py
deleted file mode 100644
index e53762a90..000000000
--- a/deepdoc/visual/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .ocr import OCR
-from .recognizer import Recognizer
\ No newline at end of file
diff --git a/deepdoc/visual/recognizer.py b/deepdoc/visual/recognizer.py
deleted file mode 100644
index 09ccbb34a..000000000
--- a/deepdoc/visual/recognizer.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import os
-import onnxruntime as ort
-from huggingface_hub import snapshot_download
-
-from .operators import *
-from rag.settings import cron_logger
-
-
-class Recognizer(object):
- def __init__(self, label_list, task_name, model_dir=None):
- """
- If you have trouble downloading HuggingFace models, -_^ this might help!!
-
- For Linux:
- export HF_ENDPOINT=https://hf-mirror.com
-
- For Windows:
- Good luck
- ^_-
-
- """
- if not model_dir:
- model_dir = snapshot_download(repo_id="InfiniFlow/ocr")
-
- model_file_path = os.path.join(model_dir, task_name + ".onnx")
- if not os.path.exists(model_file_path):
- raise ValueError("not find model file path {}".format(
- model_file_path))
- if ort.get_device() == "GPU":
- self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
- else:
- self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
- self.label_list = label_list
-
- def create_inputs(self, imgs, im_info):
- """generate input for different model type
- Args:
- imgs (list(numpy)): list of images (np.ndarray)
- im_info (list(dict)): list of image info
- Returns:
- inputs (dict): input of model
- """
- inputs = {}
-
- im_shape = []
- scale_factor = []
- if len(imgs) == 1:
- inputs['image'] = np.array((imgs[0],)).astype('float32')
- inputs['im_shape'] = np.array(
- (im_info[0]['im_shape'],)).astype('float32')
- inputs['scale_factor'] = np.array(
- (im_info[0]['scale_factor'],)).astype('float32')
- return inputs
-
- for e in im_info:
- im_shape.append(np.array((e['im_shape'],)).astype('float32'))
- scale_factor.append(np.array((e['scale_factor'],)).astype('float32'))
-
- inputs['im_shape'] = np.concatenate(im_shape, axis=0)
- inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)
-
- imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
- max_shape_h = max([e[0] for e in imgs_shape])
- max_shape_w = max([e[1] for e in imgs_shape])
- padding_imgs = []
- for img in imgs:
- im_c, im_h, im_w = img.shape[:]
- padding_im = np.zeros(
- (im_c, max_shape_h, max_shape_w), dtype=np.float32)
- padding_im[:, :im_h, :im_w] = img
- padding_imgs.append(padding_im)
- inputs['image'] = np.stack(padding_imgs, axis=0)
- return inputs
-
- def preprocess(self, image_list):
- preprocess_ops = []
- for op_info in [
- {'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
- {'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
- {'type': 'Permute'},
- {'stride': 32, 'type': 'PadStride'}
- ]:
- new_op_info = op_info.copy()
- op_type = new_op_info.pop('type')
- preprocess_ops.append(eval(op_type)(**new_op_info))
-
- inputs = []
- for im_path in image_list:
- im, im_info = preprocess(im_path, preprocess_ops)
- inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
- return inputs
-
-
- def __call__(self, image_list, thr=0.7, batch_size=16):
- res = []
- imgs = []
- for i in range(len(image_list)):
- if not isinstance(image_list[i], np.ndarray):
- imgs.append(np.array(image_list[i]))
- else: imgs.append(image_list[i])
-
- batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
- for i in range(batch_loop_cnt):
- start_index = i * batch_size
- end_index = min((i + 1) * batch_size, len(imgs))
- batch_image_list = imgs[start_index:end_index]
- inputs = self.preprocess(batch_image_list)
- for ins in inputs:
- bb = []
- for b in self.ort_sess.run(None, ins)[0]:
- clsid, bbox, score = int(b[0]), b[2:], b[1]
- if score < thr:
- continue
- if clsid >= len(self.label_list):
- cron_logger.warning(f"bad category id")
- continue
- bb.append({
- "type": self.label_list[clsid].lower(),
- "bbox": [float(t) for t in bbox.tolist()],
- "score": float(score)
- })
- res.append(bb)
-
- #seeit.save_results(image_list, res, self.label_list, threshold=thr)
-
- return res
diff --git a/rag/svr/task_broker.py b/rag/svr/task_broker.py
index 1204713d6..cd08b9f45 100644
--- a/rag/svr/task_broker.py
+++ b/rag/svr/task_broker.py
@@ -21,7 +21,7 @@ from datetime import datetime
from api.db.db_models import Task
from api.db.db_utils import bulk_insert_into_db
from api.db.services.task_service import TaskService
-from deepdoc.parser import HuParser
+from deepdoc.parser import PdfParser
from rag.settings import cron_logger
from rag.utils import MINIO
from rag.utils import findMaxTm
@@ -80,7 +80,7 @@ def dispatch():
tsks = []
if r["type"] == FileType.PDF.value:
- pages = HuParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
+ pages = PdfParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
for s,e in r["parser_config"].get("pages", [(0,100000)]):
e = min(e, pages)
for p in range(s, e, 10):