mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-06-04 11:24:00 +08:00
Optimize ocr (#5297)
### What problem does this PR solve? Introduced OCR.recognize_batch ### Type of change - [x] Performance Improvement
This commit is contained in:
parent
df3d0f61bd
commit
db42d0e0ae
@ -17,6 +17,7 @@
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from timeit import default_timer as timer
|
||||
|
||||
import xgboost as xgb
|
||||
from io import BytesIO
|
||||
@ -277,7 +278,11 @@ class RAGFlowPdfParser:
|
||||
b["SP"] = ii
|
||||
|
||||
def __ocr(self, pagenum, img, chars, ZM=3):
|
||||
start = timer()
|
||||
bxs = self.ocr.detect(np.array(img))
|
||||
logging.info(f"__ocr detecting boxes of a image cost ({timer() - start}s)")
|
||||
|
||||
start = timer()
|
||||
if not bxs:
|
||||
self.boxes.append([])
|
||||
return
|
||||
@ -308,14 +313,22 @@ class RAGFlowPdfParser:
|
||||
else:
|
||||
bxs[ii]["text"] += c["text"]
|
||||
|
||||
logging.info(f"__ocr sorting {len(chars)} chars cost {timer() - start}s")
|
||||
start = timer()
|
||||
boxes_to_reg = []
|
||||
img_np = np.array(img)
|
||||
for b in bxs:
|
||||
if not b["text"]:
|
||||
left, right, top, bott = b["x0"] * ZM, b["x1"] * \
|
||||
ZM, b["top"] * ZM, b["bottom"] * ZM
|
||||
b["text"] = self.ocr.recognize(np.array(img),
|
||||
np.array([[left, top], [right, top], [right, bott], [left, bott]],
|
||||
dtype=np.float32))
|
||||
b["box_image"] = self.ocr.get_rotate_crop_image(img_np, np.array([[left, top], [right, top], [right, bott], [left, bott]], dtype=np.float32))
|
||||
boxes_to_reg.append(b)
|
||||
del b["txt"]
|
||||
texts = self.ocr.recognize_batch([b["box_image"] for b in boxes_to_reg])
|
||||
for i in range(len(boxes_to_reg)):
|
||||
boxes_to_reg[i]["text"] = texts[i]
|
||||
del boxes_to_reg[i]["box_image"]
|
||||
logging.info(f"__ocr recognize {len(bxs)} boxes cost {timer() - start}s")
|
||||
bxs = [b for b in bxs if b["text"]]
|
||||
if self.mean_height[-1] == 0:
|
||||
self.mean_height[-1] = np.median([b["bottom"] - b["top"]
|
||||
@ -951,6 +964,7 @@ class RAGFlowPdfParser:
|
||||
self.page_cum_height = [0]
|
||||
self.page_layout = []
|
||||
self.page_from = page_from
|
||||
start = timer()
|
||||
try:
|
||||
self.pdf = pdfplumber.open(fnm) if isinstance(
|
||||
fnm, str) else pdfplumber.open(BytesIO(fnm))
|
||||
@ -965,6 +979,7 @@ class RAGFlowPdfParser:
|
||||
self.total_page = len(self.pdf.pages)
|
||||
except Exception:
|
||||
logging.exception("RAGFlowPdfParser __images__")
|
||||
logging.info(f"__images__ dedupe_chars cost {timer() - start}s")
|
||||
|
||||
self.outlines = []
|
||||
try:
|
||||
@ -994,7 +1009,7 @@ class RAGFlowPdfParser:
|
||||
else:
|
||||
self.is_english = False
|
||||
|
||||
# st = timer()
|
||||
start = timer()
|
||||
for i, img in enumerate(self.page_images):
|
||||
chars = self.page_chars[i] if not self.is_english else []
|
||||
self.mean_height.append(
|
||||
@ -1016,7 +1031,7 @@ class RAGFlowPdfParser:
|
||||
self.__ocr(i + 1, img, chars, zoomin)
|
||||
if callback and i % 6 == 5:
|
||||
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
|
||||
# print("OCR:", timer()-st)
|
||||
logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")
|
||||
|
||||
if not self.is_english and not any(
|
||||
[c for c in self.page_chars]) and self.boxes:
|
||||
|
@ -620,6 +620,16 @@ class OCR(object):
|
||||
return ""
|
||||
return text
|
||||
|
||||
def recognize_batch(self, img_list):
|
||||
rec_res, elapse = self.text_recognizer(img_list)
|
||||
texts = []
|
||||
for i in range(len(rec_res)):
|
||||
text, score = rec_res[i]
|
||||
if score < self.drop_score:
|
||||
text = ""
|
||||
texts.append(text)
|
||||
return texts
|
||||
|
||||
def __call__(self, img, cls=True):
|
||||
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
||||
|
||||
|
@ -18,7 +18,7 @@
|
||||
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
|
||||
import random
|
||||
import sys
|
||||
from api.utils.log_utils import initRootLogger
|
||||
from api.utils.log_utils import initRootLogger, get_project_base_directory
|
||||
from graphrag.general.index import WithCommunity, WithResolution, Dealer
|
||||
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
|
||||
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
|
||||
@ -42,6 +42,7 @@ from io import BytesIO
|
||||
from multiprocessing.context import TimeoutError
|
||||
from timeit import default_timer as timer
|
||||
import tracemalloc
|
||||
import signal
|
||||
|
||||
import numpy as np
|
||||
from peewee import DoesNotExist
|
||||
@ -96,6 +97,35 @@ DONE_TASKS = 0
|
||||
FAILED_TASKS = 0
|
||||
CURRENT_TASK = None
|
||||
|
||||
tracemalloc_started = False
|
||||
|
||||
# SIGUSR1 handler: start tracemalloc and take snapshot
|
||||
def start_tracemalloc_and_snapshot(signum, frame):
|
||||
global tracemalloc_started
|
||||
if not tracemalloc_started:
|
||||
logging.info("got SIGUSR1, start tracemalloc")
|
||||
tracemalloc.start()
|
||||
tracemalloc_started = True
|
||||
else:
|
||||
logging.info("got SIGUSR1, tracemalloc is already running")
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
snapshot_file = f"snapshot_{timestamp}.trace"
|
||||
snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
|
||||
|
||||
snapshot = tracemalloc.take_snapshot()
|
||||
snapshot.dump(snapshot_file)
|
||||
logging.info(f"taken snapshot {snapshot_file}")
|
||||
|
||||
# SIGUSR2 handler: stop tracemalloc
|
||||
def stop_tracemalloc(signum, frame):
|
||||
global tracemalloc_started
|
||||
if tracemalloc_started:
|
||||
logging.info("go SIGUSR2, stop tracemalloc")
|
||||
tracemalloc.stop()
|
||||
tracemalloc_started = False
|
||||
else:
|
||||
logging.info("got SIGUSR2, tracemalloc not running")
|
||||
|
||||
class TaskCanceledException(Exception):
|
||||
def __init__(self, msg):
|
||||
@ -712,26 +742,18 @@ def main():
|
||||
logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}')
|
||||
settings.init_settings()
|
||||
print_rag_settings()
|
||||
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
|
||||
signal.signal(signal.SIGUSR2, stop_tracemalloc)
|
||||
TRACE_MALLOC_ENABLED = int(os.environ.get('TRACE_MALLOC_ENABLED', "0"))
|
||||
if TRACE_MALLOC_ENABLED:
|
||||
start_tracemalloc_and_snapshot(None, None)
|
||||
|
||||
background_thread = threading.Thread(target=report_status)
|
||||
background_thread.daemon = True
|
||||
background_thread.start()
|
||||
|
||||
TRACE_MALLOC_DELTA = int(os.environ.get('TRACE_MALLOC_DELTA', "0"))
|
||||
TRACE_MALLOC_FULL = int(os.environ.get('TRACE_MALLOC_FULL', "0"))
|
||||
if TRACE_MALLOC_DELTA > 0:
|
||||
if TRACE_MALLOC_FULL < TRACE_MALLOC_DELTA:
|
||||
TRACE_MALLOC_FULL = TRACE_MALLOC_DELTA
|
||||
tracemalloc.start()
|
||||
snapshot1 = tracemalloc.take_snapshot()
|
||||
while True:
|
||||
handle_task()
|
||||
num_tasks = DONE_TASKS + FAILED_TASKS
|
||||
if TRACE_MALLOC_DELTA > 0 and num_tasks > 0 and num_tasks % TRACE_MALLOC_DELTA == 0:
|
||||
snapshot2 = tracemalloc.take_snapshot()
|
||||
analyze_heap(snapshot1, snapshot2, int(num_tasks / TRACE_MALLOC_DELTA), num_tasks % TRACE_MALLOC_FULL == 0)
|
||||
snapshot1 = snapshot2
|
||||
snapshot2 = None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user