mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-06-04 11:24:00 +08:00
Optimize ocr (#5297)
### What problem does this PR solve? Introduced OCR.recognize_batch ### Type of change - [x] Performance Improvement
This commit is contained in:
parent
df3d0f61bd
commit
db42d0e0ae
@ -17,6 +17,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
from timeit import default_timer as timer
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -277,7 +278,11 @@ class RAGFlowPdfParser:
|
|||||||
b["SP"] = ii
|
b["SP"] = ii
|
||||||
|
|
||||||
def __ocr(self, pagenum, img, chars, ZM=3):
|
def __ocr(self, pagenum, img, chars, ZM=3):
|
||||||
|
start = timer()
|
||||||
bxs = self.ocr.detect(np.array(img))
|
bxs = self.ocr.detect(np.array(img))
|
||||||
|
logging.info(f"__ocr detecting boxes of a image cost ({timer() - start}s)")
|
||||||
|
|
||||||
|
start = timer()
|
||||||
if not bxs:
|
if not bxs:
|
||||||
self.boxes.append([])
|
self.boxes.append([])
|
||||||
return
|
return
|
||||||
@ -308,14 +313,22 @@ class RAGFlowPdfParser:
|
|||||||
else:
|
else:
|
||||||
bxs[ii]["text"] += c["text"]
|
bxs[ii]["text"] += c["text"]
|
||||||
|
|
||||||
|
logging.info(f"__ocr sorting {len(chars)} chars cost {timer() - start}s")
|
||||||
|
start = timer()
|
||||||
|
boxes_to_reg = []
|
||||||
|
img_np = np.array(img)
|
||||||
for b in bxs:
|
for b in bxs:
|
||||||
if not b["text"]:
|
if not b["text"]:
|
||||||
left, right, top, bott = b["x0"] * ZM, b["x1"] * \
|
left, right, top, bott = b["x0"] * ZM, b["x1"] * \
|
||||||
ZM, b["top"] * ZM, b["bottom"] * ZM
|
ZM, b["top"] * ZM, b["bottom"] * ZM
|
||||||
b["text"] = self.ocr.recognize(np.array(img),
|
b["box_image"] = self.ocr.get_rotate_crop_image(img_np, np.array([[left, top], [right, top], [right, bott], [left, bott]], dtype=np.float32))
|
||||||
np.array([[left, top], [right, top], [right, bott], [left, bott]],
|
boxes_to_reg.append(b)
|
||||||
dtype=np.float32))
|
|
||||||
del b["txt"]
|
del b["txt"]
|
||||||
|
texts = self.ocr.recognize_batch([b["box_image"] for b in boxes_to_reg])
|
||||||
|
for i in range(len(boxes_to_reg)):
|
||||||
|
boxes_to_reg[i]["text"] = texts[i]
|
||||||
|
del boxes_to_reg[i]["box_image"]
|
||||||
|
logging.info(f"__ocr recognize {len(bxs)} boxes cost {timer() - start}s")
|
||||||
bxs = [b for b in bxs if b["text"]]
|
bxs = [b for b in bxs if b["text"]]
|
||||||
if self.mean_height[-1] == 0:
|
if self.mean_height[-1] == 0:
|
||||||
self.mean_height[-1] = np.median([b["bottom"] - b["top"]
|
self.mean_height[-1] = np.median([b["bottom"] - b["top"]
|
||||||
@ -951,6 +964,7 @@ class RAGFlowPdfParser:
|
|||||||
self.page_cum_height = [0]
|
self.page_cum_height = [0]
|
||||||
self.page_layout = []
|
self.page_layout = []
|
||||||
self.page_from = page_from
|
self.page_from = page_from
|
||||||
|
start = timer()
|
||||||
try:
|
try:
|
||||||
self.pdf = pdfplumber.open(fnm) if isinstance(
|
self.pdf = pdfplumber.open(fnm) if isinstance(
|
||||||
fnm, str) else pdfplumber.open(BytesIO(fnm))
|
fnm, str) else pdfplumber.open(BytesIO(fnm))
|
||||||
@ -965,6 +979,7 @@ class RAGFlowPdfParser:
|
|||||||
self.total_page = len(self.pdf.pages)
|
self.total_page = len(self.pdf.pages)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.exception("RAGFlowPdfParser __images__")
|
logging.exception("RAGFlowPdfParser __images__")
|
||||||
|
logging.info(f"__images__ dedupe_chars cost {timer() - start}s")
|
||||||
|
|
||||||
self.outlines = []
|
self.outlines = []
|
||||||
try:
|
try:
|
||||||
@ -994,7 +1009,7 @@ class RAGFlowPdfParser:
|
|||||||
else:
|
else:
|
||||||
self.is_english = False
|
self.is_english = False
|
||||||
|
|
||||||
# st = timer()
|
start = timer()
|
||||||
for i, img in enumerate(self.page_images):
|
for i, img in enumerate(self.page_images):
|
||||||
chars = self.page_chars[i] if not self.is_english else []
|
chars = self.page_chars[i] if not self.is_english else []
|
||||||
self.mean_height.append(
|
self.mean_height.append(
|
||||||
@ -1016,7 +1031,7 @@ class RAGFlowPdfParser:
|
|||||||
self.__ocr(i + 1, img, chars, zoomin)
|
self.__ocr(i + 1, img, chars, zoomin)
|
||||||
if callback and i % 6 == 5:
|
if callback and i % 6 == 5:
|
||||||
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
|
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
|
||||||
# print("OCR:", timer()-st)
|
logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")
|
||||||
|
|
||||||
if not self.is_english and not any(
|
if not self.is_english and not any(
|
||||||
[c for c in self.page_chars]) and self.boxes:
|
[c for c in self.page_chars]) and self.boxes:
|
||||||
|
@ -620,6 +620,16 @@ class OCR(object):
|
|||||||
return ""
|
return ""
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
def recognize_batch(self, img_list):
|
||||||
|
rec_res, elapse = self.text_recognizer(img_list)
|
||||||
|
texts = []
|
||||||
|
for i in range(len(rec_res)):
|
||||||
|
text, score = rec_res[i]
|
||||||
|
if score < self.drop_score:
|
||||||
|
text = ""
|
||||||
|
texts.append(text)
|
||||||
|
return texts
|
||||||
|
|
||||||
def __call__(self, img, cls=True):
|
def __call__(self, img, cls=True):
|
||||||
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
|
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
from api.utils.log_utils import initRootLogger
|
from api.utils.log_utils import initRootLogger, get_project_base_directory
|
||||||
from graphrag.general.index import WithCommunity, WithResolution, Dealer
|
from graphrag.general.index import WithCommunity, WithResolution, Dealer
|
||||||
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
|
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
|
||||||
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
|
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
|
||||||
@ -42,6 +42,7 @@ from io import BytesIO
|
|||||||
from multiprocessing.context import TimeoutError
|
from multiprocessing.context import TimeoutError
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
import tracemalloc
|
import tracemalloc
|
||||||
|
import signal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from peewee import DoesNotExist
|
from peewee import DoesNotExist
|
||||||
@ -96,6 +97,35 @@ DONE_TASKS = 0
|
|||||||
FAILED_TASKS = 0
|
FAILED_TASKS = 0
|
||||||
CURRENT_TASK = None
|
CURRENT_TASK = None
|
||||||
|
|
||||||
|
tracemalloc_started = False
|
||||||
|
|
||||||
|
# SIGUSR1 handler: start tracemalloc and take snapshot
|
||||||
|
def start_tracemalloc_and_snapshot(signum, frame):
|
||||||
|
global tracemalloc_started
|
||||||
|
if not tracemalloc_started:
|
||||||
|
logging.info("got SIGUSR1, start tracemalloc")
|
||||||
|
tracemalloc.start()
|
||||||
|
tracemalloc_started = True
|
||||||
|
else:
|
||||||
|
logging.info("got SIGUSR1, tracemalloc is already running")
|
||||||
|
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
snapshot_file = f"snapshot_{timestamp}.trace"
|
||||||
|
snapshot_file = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{os.getpid()}_snapshot_{timestamp}.trace"))
|
||||||
|
|
||||||
|
snapshot = tracemalloc.take_snapshot()
|
||||||
|
snapshot.dump(snapshot_file)
|
||||||
|
logging.info(f"taken snapshot {snapshot_file}")
|
||||||
|
|
||||||
|
# SIGUSR2 handler: stop tracemalloc
|
||||||
|
def stop_tracemalloc(signum, frame):
|
||||||
|
global tracemalloc_started
|
||||||
|
if tracemalloc_started:
|
||||||
|
logging.info("go SIGUSR2, stop tracemalloc")
|
||||||
|
tracemalloc.stop()
|
||||||
|
tracemalloc_started = False
|
||||||
|
else:
|
||||||
|
logging.info("got SIGUSR2, tracemalloc not running")
|
||||||
|
|
||||||
class TaskCanceledException(Exception):
|
class TaskCanceledException(Exception):
|
||||||
def __init__(self, msg):
|
def __init__(self, msg):
|
||||||
@ -712,26 +742,18 @@ def main():
|
|||||||
logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}')
|
logging.info(f'TaskExecutor: RAGFlow version: {get_ragflow_version()}')
|
||||||
settings.init_settings()
|
settings.init_settings()
|
||||||
print_rag_settings()
|
print_rag_settings()
|
||||||
|
signal.signal(signal.SIGUSR1, start_tracemalloc_and_snapshot)
|
||||||
|
signal.signal(signal.SIGUSR2, stop_tracemalloc)
|
||||||
|
TRACE_MALLOC_ENABLED = int(os.environ.get('TRACE_MALLOC_ENABLED', "0"))
|
||||||
|
if TRACE_MALLOC_ENABLED:
|
||||||
|
start_tracemalloc_and_snapshot(None, None)
|
||||||
|
|
||||||
background_thread = threading.Thread(target=report_status)
|
background_thread = threading.Thread(target=report_status)
|
||||||
background_thread.daemon = True
|
background_thread.daemon = True
|
||||||
background_thread.start()
|
background_thread.start()
|
||||||
|
|
||||||
TRACE_MALLOC_DELTA = int(os.environ.get('TRACE_MALLOC_DELTA', "0"))
|
|
||||||
TRACE_MALLOC_FULL = int(os.environ.get('TRACE_MALLOC_FULL', "0"))
|
|
||||||
if TRACE_MALLOC_DELTA > 0:
|
|
||||||
if TRACE_MALLOC_FULL < TRACE_MALLOC_DELTA:
|
|
||||||
TRACE_MALLOC_FULL = TRACE_MALLOC_DELTA
|
|
||||||
tracemalloc.start()
|
|
||||||
snapshot1 = tracemalloc.take_snapshot()
|
|
||||||
while True:
|
while True:
|
||||||
handle_task()
|
handle_task()
|
||||||
num_tasks = DONE_TASKS + FAILED_TASKS
|
|
||||||
if TRACE_MALLOC_DELTA > 0 and num_tasks > 0 and num_tasks % TRACE_MALLOC_DELTA == 0:
|
|
||||||
snapshot2 = tracemalloc.take_snapshot()
|
|
||||||
analyze_heap(snapshot1, snapshot2, int(num_tasks / TRACE_MALLOC_DELTA), num_tasks % TRACE_MALLOC_FULL == 0)
|
|
||||||
snapshot1 = snapshot2
|
|
||||||
snapshot2 = None
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user