Optimize ocr (#5297)

### What problem does this PR solve?

Introduced OCR.recognize_batch

### Type of change

- [x] Performance Improvement
This commit is contained in:
Zhichang Yu 2025-02-24 16:21:55 +08:00 committed by GitHub
parent df3d0f61bd
commit db42d0e0ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 20 deletions

View File

@ -17,6 +17,7 @@
import logging import logging
import os import os
import random import random
from timeit import default_timer as timer
import xgboost as xgb import xgboost as xgb
from io import BytesIO from io import BytesIO
@ -277,7 +278,11 @@ class RAGFlowPdfParser:
b["SP"] = ii b["SP"] = ii
def __ocr(self, pagenum, img, chars, ZM=3): def __ocr(self, pagenum, img, chars, ZM=3):
start = timer()
bxs = self.ocr.detect(np.array(img)) bxs = self.ocr.detect(np.array(img))
logging.info(f"__ocr detecting boxes of a image cost ({timer() - start}s)")
start = timer()
if not bxs: if not bxs:
self.boxes.append([]) self.boxes.append([])
return return
@ -308,14 +313,22 @@ class RAGFlowPdfParser:
else: else:
bxs[ii]["text"] += c["text"] bxs[ii]["text"] += c["text"]
logging.info(f"__ocr sorting {len(chars)} chars cost {timer() - start}s")
start = timer()
boxes_to_reg = []
img_np = np.array(img)
for b in bxs: for b in bxs:
if not b["text"]: if not b["text"]:
left, right, top, bott = b["x0"] * ZM, b["x1"] * \ left, right, top, bott = b["x0"] * ZM, b["x1"] * \
ZM, b["top"] * ZM, b["bottom"] * ZM ZM, b["top"] * ZM, b["bottom"] * ZM
b["text"] = self.ocr.recognize(np.array(img), b["box_image"] = self.ocr.get_rotate_crop_image(img_np, np.array([[left, top], [right, top], [right, bott], [left, bott]], dtype=np.float32))
np.array([[left, top], [right, top], [right, bott], [left, bott]], boxes_to_reg.append(b)
dtype=np.float32))
del b["txt"] del b["txt"]
texts = self.ocr.recognize_batch([b["box_image"] for b in boxes_to_reg])
for i in range(len(boxes_to_reg)):
boxes_to_reg[i]["text"] = texts[i]
del boxes_to_reg[i]["box_image"]
logging.info(f"__ocr recognize {len(bxs)} boxes cost {timer() - start}s")
bxs = [b for b in bxs if b["text"]] bxs = [b for b in bxs if b["text"]]
if self.mean_height[-1] == 0: if self.mean_height[-1] == 0:
self.mean_height[-1] = np.median([b["bottom"] - b["top"] self.mean_height[-1] = np.median([b["bottom"] - b["top"]
@ -951,6 +964,7 @@ class RAGFlowPdfParser:
self.page_cum_height = [0] self.page_cum_height = [0]
self.page_layout = [] self.page_layout = []
self.page_from = page_from self.page_from = page_from
start = timer()
try: try:
self.pdf = pdfplumber.open(fnm) if isinstance( self.pdf = pdfplumber.open(fnm) if isinstance(
fnm, str) else pdfplumber.open(BytesIO(fnm)) fnm, str) else pdfplumber.open(BytesIO(fnm))
@ -965,6 +979,7 @@ class RAGFlowPdfParser:
self.total_page = len(self.pdf.pages) self.total_page = len(self.pdf.pages)
except Exception: except Exception:
logging.exception("RAGFlowPdfParser __images__") logging.exception("RAGFlowPdfParser __images__")
logging.info(f"__images__ dedupe_chars cost {timer() - start}s")
self.outlines = [] self.outlines = []
try: try:
@ -994,7 +1009,7 @@ class RAGFlowPdfParser:
else: else:
self.is_english = False self.is_english = False
# st = timer() start = timer()
for i, img in enumerate(self.page_images): for i, img in enumerate(self.page_images):
chars = self.page_chars[i] if not self.is_english else [] chars = self.page_chars[i] if not self.is_english else []
self.mean_height.append( self.mean_height.append(
@ -1016,7 +1031,7 @@ class RAGFlowPdfParser:
self.__ocr(i + 1, img, chars, zoomin) self.__ocr(i + 1, img, chars, zoomin)
if callback and i % 6 == 5: if callback and i % 6 == 5:
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="") callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
# print("OCR:", timer()-st) logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")
if not self.is_english and not any( if not self.is_english and not any(
[c for c in self.page_chars]) and self.boxes: [c for c in self.page_chars]) and self.boxes:

View File

@ -620,6 +620,16 @@ class OCR(object):
return "" return ""
return text return text
def recognize_batch(self, img_list):
rec_res, elapse = self.text_recognizer(img_list)
texts = []
for i in range(len(rec_res)):
text, score = rec_res[i]
if score < self.drop_score:
text = ""
texts.append(text)
return texts
def __call__(self, img, cls=True): def __call__(self, img, cls=True):
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0} time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}

View File

@ -18,7 +18,7 @@
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
import random import random
import sys import sys
from api.utils.log_utils import initRootLogger from api.utils.log_utils import initRootLogger, get_project_base_directory
from graphrag.general.index import WithCommunity, WithResolution, Dealer from graphrag.general.index import WithCommunity, WithResolution, Dealer
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
@ -42,6 +42,7 @@ from io import BytesIO
from multiprocessing.context import TimeoutError from multiprocessing.context import TimeoutError
from timeit import default_timer as timer from timeit import default_timer as timer
import tracemalloc import tracemalloc
import signal
import numpy as np import numpy as np
from peewee import DoesNotExist from peewee import DoesNotExist
@ -96,6 +97,35 @@ DONE_TASKS = 0
FAILED_TASKS = 0 FAILED_TASKS = 0
CURRENT_TASK = None CURRENT_TASK = None
tracemalloc_started = False
# SIGUSR1 handler: start tracemalloc and take snapshot
def start_tracemalloc_and_snapshot(signum, frame):
global tracemalloc_started
if not tracemalloc_started:
logging.info("got SIGUSR1, start tracemalloc")
tracemalloc.start()
tracemalloc_started = True
else:
logging.info("got SIGUSR1, tracemalloc is already running")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
snapshot_file = f"snapshot_{timestamp}.trace"
snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
snapshot = tracemalloc.take_snapshot()
snapshot.dump(snapshot_file)
logging.info(f"taken snapshot {snapshot_file}")
# SIGUSR2 handler: stop tracemalloc
def stop_tracemalloc(signum, frame):
global tracemalloc_started
if tracemalloc_started:
logging.info("go SIGUSR2, stop tracemalloc")
tracemalloc.stop()
tracemalloc_started = False
else:
logging.info("got SIGUSR2, tracemalloc not running")
class TaskCanceledException(Exception): class TaskCanceledException(Exception):
def __init__(self, msg): def __init__(self, msg):
@ -712,26 +742,18 @@ def main():
logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}') logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}')
settings.init_settings() settings.init_settings()
print_rag_settings() print_rag_settings()
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
signal.signal(signal.SIGUSR2, stop_tracemalloc)
TRACE_MALLOC_ENABLED = int(os.environ.get('TRACE_MALLOC_ENABLED', "0"))
if TRACE_MALLOC_ENABLED:
start_tracemalloc_and_snapshot(None, None)
background_thread = threading.Thread(target=report_status) background_thread = threading.Thread(target=report_status)
background_thread.daemon = True background_thread.daemon = True
background_thread.start() background_thread.start()
TRACE_MALLOC_DELTA = int(os.environ.get('TRACE_MALLOC_DELTA', "0"))
TRACE_MALLOC_FULL = int(os.environ.get('TRACE_MALLOC_FULL', "0"))
if TRACE_MALLOC_DELTA > 0:
if TRACE_MALLOC_FULL < TRACE_MALLOC_DELTA:
TRACE_MALLOC_FULL = TRACE_MALLOC_DELTA
tracemalloc.start()
snapshot1 = tracemalloc.take_snapshot()
while True: while True:
handle_task() handle_task()
num_tasks = DONE_TASKS + FAILED_TASKS
if TRACE_MALLOC_DELTA > 0 and num_tasks > 0 and num_tasks % TRACE_MALLOC_DELTA == 0:
snapshot2 = tracemalloc.take_snapshot()
analyze_heap(snapshot1, snapshot2, int(num_tasks / TRACE_MALLOC_DELTA), num_tasks % TRACE_MALLOC_FULL == 0)
snapshot1 = snapshot2
snapshot2 = None
if __name__ == "__main__": if __name__ == "__main__":
main() main()