mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-04-21 13:40:00 +08:00
Upgrades Document Layout Analysis model. (#4054)
### What problem does this PR solve? #4052 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
b5e4a5563c
commit
ce1e855328
@ -247,8 +247,8 @@ def queue_tasks(doc: dict, bucket: str, name: str):
|
||||
task["progress"] = 0.0
|
||||
|
||||
prev_tasks = TaskService.get_tasks(doc["id"])
|
||||
ck_num = 0
|
||||
if prev_tasks:
|
||||
ck_num = 0
|
||||
for task in tsks:
|
||||
ck_num += reuse_prev_task_chunks(task, prev_tasks, chunking_config)
|
||||
TaskService.filter_delete([Task.doc_id == doc["id"]])
|
||||
@ -258,7 +258,7 @@ def queue_tasks(doc: dict, bucket: str, name: str):
|
||||
chunk_ids.extend(task["chunk_ids"].split())
|
||||
if chunk_ids:
|
||||
settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(chunking_config["tenant_id"]), chunking_config["kb_id"])
|
||||
DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})
|
||||
DocumentService.update_by_id(doc["id"], {"chunk_num": ck_num})
|
||||
|
||||
bulk_insert_into_db(Task, tsks, True)
|
||||
DocumentService.begin2parse(doc["id"])
|
||||
|
@ -16,6 +16,8 @@
|
||||
"content_with_weight": {"type": "varchar", "default": ""},
|
||||
"content_ltks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
||||
"content_sm_ltks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
||||
"authors_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
||||
"authors_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
||||
"page_num_int": {"type": "varchar", "default": ""},
|
||||
"top_int": {"type": "varchar", "default": ""},
|
||||
"position_int": {"type": "varchar", "default": ""},
|
||||
|
@ -15,9 +15,10 @@ import pdfplumber
|
||||
|
||||
from .ocr import OCR
|
||||
from .recognizer import Recognizer
|
||||
from .layout_recognizer import LayoutRecognizer
|
||||
from .layout_recognizer import LayoutRecognizer4YOLOv10 as LayoutRecognizer
|
||||
from .table_structure_recognizer import TableStructureRecognizer
|
||||
|
||||
|
||||
def init_in_out(args):
|
||||
from PIL import Image
|
||||
import os
|
||||
|
@ -14,11 +14,14 @@ import os
|
||||
import re
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from deepdoc.vision import Recognizer
|
||||
from deepdoc.vision.operators import nms
|
||||
|
||||
|
||||
class LayoutRecognizer(Recognizer):
|
||||
@ -149,3 +152,88 @@ class LayoutRecognizer(Recognizer):
|
||||
|
||||
ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
|
||||
return ocr_res, page_layout
|
||||
|
||||
|
||||
class LayoutRecognizer4YOLOv10(LayoutRecognizer):
|
||||
labels = [
|
||||
"title",
|
||||
"Text",
|
||||
"Reference",
|
||||
"Figure",
|
||||
"Figure caption",
|
||||
"Table",
|
||||
"Table caption",
|
||||
"Table caption",
|
||||
"Equation",
|
||||
"Figure caption",
|
||||
]
|
||||
|
||||
def __init__(self, domain):
|
||||
domain = "layout"
|
||||
super().__init__(domain)
|
||||
self.auto = False
|
||||
self.scaleFill = False
|
||||
self.scaleup = True
|
||||
self.stride = 32
|
||||
self.center = True
|
||||
|
||||
def preprocess(self, image_list):
|
||||
inputs = []
|
||||
new_shape = self.input_shape # height, width
|
||||
for img in image_list:
|
||||
shape = img.shape[:2]# current shape [height, width]
|
||||
# Scale ratio (new / old)
|
||||
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
||||
# Compute padding
|
||||
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
||||
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
|
||||
dw /= 2 # divide padding into 2 sides
|
||||
dh /= 2
|
||||
ww, hh = new_unpad
|
||||
img = np.array(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).astype(np.float32)
|
||||
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
|
||||
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
|
||||
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
|
||||
img = cv2.copyMakeBorder(
|
||||
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
|
||||
) # add border
|
||||
img /= 255.0
|
||||
img = img.transpose(2, 0, 1)
|
||||
img = img[np.newaxis, :, :, :].astype(np.float32)
|
||||
inputs.append({self.input_names[0]: img, "scale_factor": [shape[1]/ww, shape[0]/hh, dw, dh]})
|
||||
|
||||
return inputs
|
||||
|
||||
def postprocess(self, boxes, inputs, thr):
|
||||
thr = 0.08
|
||||
boxes = np.squeeze(boxes)
|
||||
scores = boxes[:, 4]
|
||||
boxes = boxes[scores > thr, :]
|
||||
scores = scores[scores > thr]
|
||||
if len(boxes) == 0:
|
||||
return []
|
||||
class_ids = boxes[:, -1].astype(int)
|
||||
boxes = boxes[:, :4]
|
||||
boxes[:, 0] -= inputs["scale_factor"][2]
|
||||
boxes[:, 2] -= inputs["scale_factor"][2]
|
||||
boxes[:, 1] -= inputs["scale_factor"][3]
|
||||
boxes[:, 3] -= inputs["scale_factor"][3]
|
||||
input_shape = np.array([inputs["scale_factor"][0], inputs["scale_factor"][1], inputs["scale_factor"][0],
|
||||
inputs["scale_factor"][1]])
|
||||
boxes = np.multiply(boxes, input_shape, dtype=np.float32)
|
||||
|
||||
unique_class_ids = np.unique(class_ids)
|
||||
indices = []
|
||||
for class_id in unique_class_ids:
|
||||
class_indices = np.where(class_ids == class_id)[0]
|
||||
class_boxes = boxes[class_indices, :]
|
||||
class_scores = scores[class_indices]
|
||||
class_keep_boxes = nms(class_boxes, class_scores, 0.45)
|
||||
indices.extend(class_indices[class_keep_boxes])
|
||||
|
||||
return [{
|
||||
"type": self.label_list[class_ids[i]].lower(),
|
||||
"bbox": [float(t) for t in boxes[i].tolist()],
|
||||
"score": float(scores[i])
|
||||
} for i in indices]
|
||||
|
||||
|
@ -709,3 +709,29 @@ def preprocess(im, preprocess_ops):
|
||||
for operator in preprocess_ops:
|
||||
im, im_info = operator(im, im_info)
|
||||
return im, im_info
|
||||
|
||||
|
||||
def nms(bboxes, scores, iou_thresh):
|
||||
import numpy as np
|
||||
x1 = bboxes[:, 0]
|
||||
y1 = bboxes[:, 1]
|
||||
x2 = bboxes[:, 2]
|
||||
y2 = bboxes[:, 3]
|
||||
areas = (y2 - y1) * (x2 - x1)
|
||||
|
||||
indices = []
|
||||
index = scores.argsort()[::-1]
|
||||
while index.size > 0:
|
||||
i = index[0]
|
||||
indices.append(i)
|
||||
x11 = np.maximum(x1[i], x1[index[1:]])
|
||||
y11 = np.maximum(y1[i], y1[index[1:]])
|
||||
x22 = np.minimum(x2[i], x2[index[1:]])
|
||||
y22 = np.minimum(y2[i], y2[index[1:]])
|
||||
w = np.maximum(0, x22 - x11 + 1)
|
||||
h = np.maximum(0, y22 - y11 + 1)
|
||||
overlaps = w * h
|
||||
ious = overlaps / (areas[i] + areas[index[1:]] - overlaps)
|
||||
idx = np.where(ious <= iou_thresh)[0]
|
||||
index = index[idx + 1]
|
||||
return indices
|
||||
|
Loading…
x
Reference in New Issue
Block a user