#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import os
import re
from collections import Counter
from copy import deepcopy
import numpy as np
from huggingface_hub import snapshot_download

from api.db import ParserType
from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import Recognizer


class LayoutRecognizer(Recognizer):
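    # Layout analysis built on the generic Recognizer: detects layout regions
    # (text, title, figure, table, header, footer, reference, equation) on page
    # images, tags the corresponding OCR boxes, and filters out page furniture.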
    labels = [
        "_background_",
        "Text",
        "Title",
        "Figure",
        "Figure caption",
        "Table",
        "Table caption",
        "Header",
        "Footer",
        "Reference",
        "Equation",
    ]

    def __init__(self, domain):
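        # Prefer a locally cached copy of the InfiniFlow/deepdoc model weights;
        # fall back to downloading them from the Hugging Face Hub if missing.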
        try:
            model_dir = snapshot_download(
                repo_id="InfiniFlow/deepdoc",
                local_dir=os.path.join(
                    get_project_base_directory(),
                    "rag/res/deepdoc"),
                local_files_only=True)
        except Exception as e:
            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")

        super().__init__(self.labels, domain, model_dir)  # os.path.join(get_project_base_directory(), "rag/res/deepdoc/")
        self.garbage_layouts = ["footer", "header", "reference"]

    def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16, drop=True):
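        # For each page image, attach a layout type to every OCR box and,
        # when `drop` is set, discard headers, footers and references.
        # `scale_factor` maps detector coordinates back onto OCR coordinates;
        # `thr` is the detection score threshold.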
        def __is_garbage(b):
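            # Patterns for page furniture: bullet-only lines, copyright/disclaimer/
            # address notes, dot leaders, page numbers ("3 / 12", "3 of 12"),
            # long URLs, data-source notes, e-mail addresses, and (cid:N) artifacts.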
            patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
                    r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
                    "(资料|数据)来源[::]", "[0-9a-z._-]+@[a-z0-9-]+\\.[a-z]{2,3}",
                    "\\(cid *: *[0-9]+ *\\)"
                    ]
            return any([re.search(p, b["text"]) for p in patt])

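        # Run the underlying layout detector over all page images.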
        layouts = super().__call__(image_list, thr, batch_size)
        # save_results(image_list, layouts, self.labels, output_dir='output/', threshold=0.7)
        assert len(image_list) == len(ocr_res)
        # Tag layout type
        boxes = []
        assert len(image_list) == len(layouts)
        garbages = {}
        page_layout = []
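        # Per page: rescale layout boxes into OCR coordinate space, sort them
        # top-to-bottom, clean them up, then tag every OCR box with a layout.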
        for pn, lts in enumerate(layouts):
            bxs = ocr_res[pn]
            lts = [{"type": b["type"],
                    "score": float(b["score"]),
                    "x0": b["bbox"][0] / scale_factor, "x1": b["bbox"][2] / scale_factor,
                    "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
                    "page_number": pn,
                    } for b in lts]
            lts = self.sort_Y_firstly(lts, np.mean(
                [l["bottom"] - l["top"] for l in lts]) / 2)
            lts = self.layouts_cleanup(bxs, lts)
            page_layout.append(lts)

            # Layouts are ready; tag each OCR box with a layout type
            def findLayout(ty):
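                # Assign layout type `ty` to every still-untagged OCR box that
                # overlaps a detected region of that type (overlap threshold 0.4).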
                nonlocal bxs, lts, self
                lts_ = [lt for lt in lts if lt["type"] == ty]
                i = 0
                while i < len(bxs):
                    if bxs[i].get("layout_type"):
                        i += 1
                        continue
                    if __is_garbage(bxs[i]):
                        bxs.pop(i)
                        continue

                    ii = self.find_overlapped_with_threashold(bxs[i], lts_,
                                                              thr=0.4)
                    if ii is None:  # belong to nothing
                        bxs[i]["layout_type"] = ""
                        i += 1
                        continue
                    lts_[ii]["visited"] = True
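                    # Keep a box tagged footer/header if it does not actually sit
                    # in the bottom/top 10% of the page; otherwise it is garbage.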
                    keep_feats = [
                        lts_[ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1] * 0.9 / scale_factor,
                        lts_[ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1] * 0.1 / scale_factor,
                    ]
                    if drop and lts_[ii]["type"] in self.garbage_layouts and not any(keep_feats):
                        if lts_[ii]["type"] not in garbages:
                            garbages[lts_[ii]["type"]] = []
                        garbages[lts_[ii]["type"]].append(bxs[i]["text"])
                        bxs.pop(i)
                        continue

bxs[i]["layoutno"] = f"{ty}-{ii}"
|
||
bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[ii]["type"]!="equation" else "figure"
|
||
i += 1
|
||
|
||
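            # Earlier types in this list claim boxes first, since findLayout
            # skips boxes that already carry a layout_type.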
for lt in ["footer", "header", "reference", "figure caption",
|
||
"table caption", "title", "table", "text", "figure", "equation"]:
|
||
findLayout(lt)
|
||
|
||
            # Add a placeholder box for figure/equation layouts that have no text box
            for i, lt in enumerate(
                    [lt for lt in lts if lt["type"] in ["figure", "equation"]]):
                if lt.get("visited"):
                    continue
                lt = deepcopy(lt)
                del lt["type"]
                lt["text"] = ""
                lt["layout_type"] = "figure"
                lt["layoutno"] = f"figure-{i}"
                bxs.append(lt)

            boxes.extend(bxs)

        ocr_res = boxes

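        # Header/footer/reference texts that occur more than once are treated
        # as boilerplate and removed from the final OCR result.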
        garbag_set = set()
        for k in garbages.keys():
            garbages[k] = Counter(garbages[k])
            for g, c in garbages[k].items():
                if c > 1:
                    garbag_set.add(g)

        ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
        return ocr_res, page_layout