mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-11 20:39:01 +08:00
Feat: add OCR's muti-gpus and parallel processing support (#5972)
### What problem does this PR solve? Add OCR's muti-gpus and parallel processing support ### Type of change - [x] New Feature (non-breaking change which adds functionality) @yuzhichang I've tried to resolve the comments in #5697. OCR jobs can now be done on both CPU and GPU. ( By the way, I've encountered a “Generate embedding error” issue #5954 that might be due to my outdated GPUs? idk. ) Please review it and give me suggestions. GPU:   CPU: 
This commit is contained in:
parent
8495036ff9
commit
3e19044dee
@ -20,6 +20,7 @@ import random
|
|||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
|
import trio
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -41,7 +42,7 @@ if LOCK_KEY_pdfplumber not in sys.modules:
|
|||||||
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
||||||
|
|
||||||
class RAGFlowPdfParser:
|
class RAGFlowPdfParser:
|
||||||
def __init__(self):
|
def __init__(self, parallel_devices: int | None = None):
|
||||||
"""
|
"""
|
||||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||||
|
|
||||||
@ -53,7 +54,13 @@ class RAGFlowPdfParser:
|
|||||||
^_-
|
^_-
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.ocr = OCR()
|
|
||||||
|
self.ocr = OCR(parallel_devices = parallel_devices)
|
||||||
|
self.parallel_devices = parallel_devices
|
||||||
|
self.parallel_limiter = None
|
||||||
|
if parallel_devices is not None and parallel_devices > 1:
|
||||||
|
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(parallel_devices)]
|
||||||
|
|
||||||
if hasattr(self, "model_speciess"):
|
if hasattr(self, "model_speciess"):
|
||||||
self.layouter = LayoutRecognizer("layout." + self.model_speciess)
|
self.layouter = LayoutRecognizer("layout." + self.model_speciess)
|
||||||
else:
|
else:
|
||||||
@ -63,7 +70,7 @@ class RAGFlowPdfParser:
|
|||||||
self.updown_cnt_mdl = xgb.Booster()
|
self.updown_cnt_mdl = xgb.Booster()
|
||||||
if not settings.LIGHTEN:
|
if not settings.LIGHTEN:
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch.cuda
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
self.updown_cnt_mdl.set_param({"device": "cuda"})
|
self.updown_cnt_mdl.set_param({"device": "cuda"})
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -283,9 +290,9 @@ class RAGFlowPdfParser:
|
|||||||
b["H_right"] = spans[ii]["x1"]
|
b["H_right"] = spans[ii]["x1"]
|
||||||
b["SP"] = ii
|
b["SP"] = ii
|
||||||
|
|
||||||
def __ocr(self, pagenum, img, chars, ZM=3):
|
def __ocr(self, pagenum, img, chars, ZM=3, device_id: int | None = None):
|
||||||
start = timer()
|
start = timer()
|
||||||
bxs = self.ocr.detect(np.array(img))
|
bxs = self.ocr.detect(np.array(img), device_id)
|
||||||
logging.info(f"__ocr detecting boxes of a image cost ({timer() - start}s)")
|
logging.info(f"__ocr detecting boxes of a image cost ({timer() - start}s)")
|
||||||
|
|
||||||
start = timer()
|
start = timer()
|
||||||
@ -330,7 +337,7 @@ class RAGFlowPdfParser:
|
|||||||
b["box_image"] = self.ocr.get_rotate_crop_image(img_np, np.array([[left, top], [right, top], [right, bott], [left, bott]], dtype=np.float32))
|
b["box_image"] = self.ocr.get_rotate_crop_image(img_np, np.array([[left, top], [right, top], [right, bott], [left, bott]], dtype=np.float32))
|
||||||
boxes_to_reg.append(b)
|
boxes_to_reg.append(b)
|
||||||
del b["txt"]
|
del b["txt"]
|
||||||
texts = self.ocr.recognize_batch([b["box_image"] for b in boxes_to_reg])
|
texts = self.ocr.recognize_batch([b["box_image"] for b in boxes_to_reg], device_id)
|
||||||
for i in range(len(boxes_to_reg)):
|
for i in range(len(boxes_to_reg)):
|
||||||
boxes_to_reg[i]["text"] = texts[i]
|
boxes_to_reg[i]["text"] = texts[i]
|
||||||
del boxes_to_reg[i]["box_image"]
|
del boxes_to_reg[i]["box_image"]
|
||||||
@ -1022,28 +1029,54 @@ class RAGFlowPdfParser:
|
|||||||
else:
|
else:
|
||||||
self.is_english = False
|
self.is_english = False
|
||||||
|
|
||||||
start = timer()
|
async def __img_ocr(i, id, img, chars, limiter):
|
||||||
for i, img in enumerate(self.page_images):
|
|
||||||
chars = self.page_chars[i] if not self.is_english else []
|
|
||||||
self.mean_height.append(
|
|
||||||
np.median(sorted([c["height"] for c in chars])) if chars else 0
|
|
||||||
)
|
|
||||||
self.mean_width.append(
|
|
||||||
np.median(sorted([c["width"] for c in chars])) if chars else 8
|
|
||||||
)
|
|
||||||
self.page_cum_height.append(img.size[1] / zoomin)
|
|
||||||
j = 0
|
j = 0
|
||||||
while j + 1 < len(chars):
|
while j + 1 < len(chars):
|
||||||
if chars[j]["text"] and chars[j + 1]["text"] \
|
if chars[j]["text"] and chars[j + 1]["text"] \
|
||||||
and re.match(r"[0-9a-zA-Z,.:;!%]+", chars[j]["text"] + chars[j + 1]["text"]) \
|
and re.match(r"[0-9a-zA-Z,.:;!%]+", chars[j]["text"] + chars[j + 1]["text"]) \
|
||||||
and chars[j + 1]["x0"] - chars[j]["x1"] >= min(chars[j + 1]["width"],
|
and chars[j + 1]["x0"] - chars[j]["x1"] >= min(chars[j + 1]["width"],
|
||||||
chars[j]["width"]) / 2:
|
chars[j]["width"]) / 2:
|
||||||
chars[j]["text"] += " "
|
chars[j]["text"] += " "
|
||||||
j += 1
|
j += 1
|
||||||
|
|
||||||
self.__ocr(i + 1, img, chars, zoomin)
|
if limiter:
|
||||||
|
async with limiter:
|
||||||
|
await trio.to_thread.run_sync(lambda: self.__ocr(i + 1, img, chars, zoomin, id))
|
||||||
|
else:
|
||||||
|
self.__ocr(i + 1, img, chars, zoomin, id)
|
||||||
|
|
||||||
if callback and i % 6 == 5:
|
if callback and i % 6 == 5:
|
||||||
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
|
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
|
||||||
|
|
||||||
|
async def __img_ocr_launcher():
|
||||||
|
def __ocr_preprocess():
|
||||||
|
chars = self.page_chars[i] if not self.is_english else []
|
||||||
|
self.mean_height.append(
|
||||||
|
np.median(sorted([c["height"] for c in chars])) if chars else 0
|
||||||
|
)
|
||||||
|
self.mean_width.append(
|
||||||
|
np.median(sorted([c["width"] for c in chars])) if chars else 8
|
||||||
|
)
|
||||||
|
self.page_cum_height.append(img.size[1] / zoomin)
|
||||||
|
return chars
|
||||||
|
|
||||||
|
if self.parallel_limiter:
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
for i, img in enumerate(self.page_images):
|
||||||
|
chars = __ocr_preprocess()
|
||||||
|
|
||||||
|
nursery.start_soon(__img_ocr, i, i % self.parallel_devices, img, chars,
|
||||||
|
self.parallel_limiter[i % self.parallel_devices])
|
||||||
|
await trio.sleep(0.1)
|
||||||
|
else:
|
||||||
|
for i, img in enumerate(self.page_images):
|
||||||
|
chars = __ocr_preprocess()
|
||||||
|
await __img_ocr(i, 0, img, chars, None)
|
||||||
|
|
||||||
|
start = timer()
|
||||||
|
|
||||||
|
trio.run(__img_ocr_launcher)
|
||||||
|
|
||||||
logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")
|
logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")
|
||||||
|
|
||||||
if not self.is_english and not any(
|
if not self.is_english and not any(
|
||||||
|
@ -66,10 +66,12 @@ def create_operators(op_param_list, global_config=None):
|
|||||||
return ops
|
return ops
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_dir, nm):
|
def load_model(model_dir, nm, device_id: int | None = None):
|
||||||
model_file_path = os.path.join(model_dir, nm + ".onnx")
|
model_file_path = os.path.join(model_dir, nm + ".onnx")
|
||||||
|
model_cached_tag = model_file_path + str(device_id) if device_id is not None else model_file_path
|
||||||
|
|
||||||
global loaded_models
|
global loaded_models
|
||||||
loaded_model = loaded_models.get(model_file_path)
|
loaded_model = loaded_models.get(model_cached_tag)
|
||||||
if loaded_model:
|
if loaded_model:
|
||||||
logging.info(f"load_model {model_file_path} reuses cached model")
|
logging.info(f"load_model {model_file_path} reuses cached model")
|
||||||
return loaded_model
|
return loaded_model
|
||||||
@ -81,7 +83,7 @@ def load_model(model_dir, nm):
|
|||||||
def cuda_is_available():
|
def cuda_is_available():
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available() and torch.cuda.device_count() > device_id:
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
@ -98,7 +100,7 @@ def load_model(model_dir, nm):
|
|||||||
run_options = ort.RunOptions()
|
run_options = ort.RunOptions()
|
||||||
if cuda_is_available():
|
if cuda_is_available():
|
||||||
cuda_provider_options = {
|
cuda_provider_options = {
|
||||||
"device_id": 0, # Use specific GPU
|
"device_id": device_id, # Use specific GPU
|
||||||
"gpu_mem_limit": 512 * 1024 * 1024, # Limit gpu memory
|
"gpu_mem_limit": 512 * 1024 * 1024, # Limit gpu memory
|
||||||
"arena_extend_strategy": "kNextPowerOfTwo", # gpu memory allocation strategy
|
"arena_extend_strategy": "kNextPowerOfTwo", # gpu memory allocation strategy
|
||||||
}
|
}
|
||||||
@ -108,7 +110,7 @@ def load_model(model_dir, nm):
|
|||||||
providers=['CUDAExecutionProvider'],
|
providers=['CUDAExecutionProvider'],
|
||||||
provider_options=[cuda_provider_options]
|
provider_options=[cuda_provider_options]
|
||||||
)
|
)
|
||||||
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:0")
|
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(device_id))
|
||||||
logging.info(f"load_model {model_file_path} uses GPU")
|
logging.info(f"load_model {model_file_path} uses GPU")
|
||||||
else:
|
else:
|
||||||
sess = ort.InferenceSession(
|
sess = ort.InferenceSession(
|
||||||
@ -118,12 +120,12 @@ def load_model(model_dir, nm):
|
|||||||
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu")
|
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu")
|
||||||
logging.info(f"load_model {model_file_path} uses CPU")
|
logging.info(f"load_model {model_file_path} uses CPU")
|
||||||
loaded_model = (sess, run_options)
|
loaded_model = (sess, run_options)
|
||||||
loaded_models[model_file_path] = loaded_model
|
loaded_models[model_cached_tag] = loaded_model
|
||||||
return loaded_model
|
return loaded_model
|
||||||
|
|
||||||
|
|
||||||
class TextRecognizer:
|
class TextRecognizer:
|
||||||
def __init__(self, model_dir):
|
def __init__(self, model_dir, device_id: int | None = None):
|
||||||
self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
|
self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
|
||||||
self.rec_batch_num = 16
|
self.rec_batch_num = 16
|
||||||
postprocess_params = {
|
postprocess_params = {
|
||||||
@ -132,7 +134,7 @@ class TextRecognizer:
|
|||||||
"use_space_char": True
|
"use_space_char": True
|
||||||
}
|
}
|
||||||
self.postprocess_op = build_post_process(postprocess_params)
|
self.postprocess_op = build_post_process(postprocess_params)
|
||||||
self.predictor, self.run_options = load_model(model_dir, 'rec')
|
self.predictor, self.run_options = load_model(model_dir, 'rec', device_id)
|
||||||
self.input_tensor = self.predictor.get_inputs()[0]
|
self.input_tensor = self.predictor.get_inputs()[0]
|
||||||
|
|
||||||
def resize_norm_img(self, img, max_wh_ratio):
|
def resize_norm_img(self, img, max_wh_ratio):
|
||||||
@ -394,7 +396,7 @@ class TextRecognizer:
|
|||||||
|
|
||||||
|
|
||||||
class TextDetector:
|
class TextDetector:
|
||||||
def __init__(self, model_dir):
|
def __init__(self, model_dir, device_id: int | None = None):
|
||||||
pre_process_list = [{
|
pre_process_list = [{
|
||||||
'DetResizeForTest': {
|
'DetResizeForTest': {
|
||||||
'limit_side_len': 960,
|
'limit_side_len': 960,
|
||||||
@ -418,7 +420,7 @@ class TextDetector:
|
|||||||
"unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}
|
"unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}
|
||||||
|
|
||||||
self.postprocess_op = build_post_process(postprocess_params)
|
self.postprocess_op = build_post_process(postprocess_params)
|
||||||
self.predictor, self.run_options = load_model(model_dir, 'det')
|
self.predictor, self.run_options = load_model(model_dir, 'det', device_id)
|
||||||
self.input_tensor = self.predictor.get_inputs()[0]
|
self.input_tensor = self.predictor.get_inputs()[0]
|
||||||
|
|
||||||
img_h, img_w = self.input_tensor.shape[2:]
|
img_h, img_w = self.input_tensor.shape[2:]
|
||||||
@ -507,7 +509,7 @@ class TextDetector:
|
|||||||
|
|
||||||
|
|
||||||
class OCR:
|
class OCR:
|
||||||
def __init__(self, model_dir=None):
|
def __init__(self, model_dir=None, parallel_devices: int | None = None):
|
||||||
"""
|
"""
|
||||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||||
|
|
||||||
@ -524,14 +526,33 @@ class OCR:
|
|||||||
model_dir = os.path.join(
|
model_dir = os.path.join(
|
||||||
get_project_base_directory(),
|
get_project_base_directory(),
|
||||||
"rag/res/deepdoc")
|
"rag/res/deepdoc")
|
||||||
self.text_detector = TextDetector(model_dir)
|
|
||||||
self.text_recognizer = TextRecognizer(model_dir)
|
# Append muti-gpus task to the list
|
||||||
|
if parallel_devices is not None and parallel_devices > 0:
|
||||||
|
self.text_detector = []
|
||||||
|
self.text_recognizer = []
|
||||||
|
for device_id in range(parallel_devices):
|
||||||
|
self.text_detector.append(TextDetector(model_dir, device_id))
|
||||||
|
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
||||||
|
else:
|
||||||
|
self.text_detector = [TextDetector(model_dir, 0)]
|
||||||
|
self.text_recognizer = [TextRecognizer(model_dir, 0)]
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
|
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
|
||||||
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||||
local_dir_use_symlinks=False)
|
local_dir_use_symlinks=False)
|
||||||
self.text_detector = TextDetector(model_dir)
|
|
||||||
self.text_recognizer = TextRecognizer(model_dir)
|
if parallel_devices is not None:
|
||||||
|
assert parallel_devices > 0 , "Number of devices must be >= 1"
|
||||||
|
self.text_detector = []
|
||||||
|
self.text_recognizer = []
|
||||||
|
for device_id in range(parallel_devices):
|
||||||
|
self.text_detector.append(TextDetector(model_dir, device_id))
|
||||||
|
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
||||||
|
else:
|
||||||
|
self.text_detector = [TextDetector(model_dir, 0)]
|
||||||
|
self.text_recognizer = [TextRecognizer(model_dir, 0)]
|
||||||
|
|
||||||
self.drop_score = 0.5
|
self.drop_score = 0.5
|
||||||
self.crop_image_res_index = 0
|
self.crop_image_res_index = 0
|
||||||
@ -593,14 +614,17 @@ class OCR:
|
|||||||
break
|
break
|
||||||
return _boxes
|
return _boxes
|
||||||
|
|
||||||
def detect(self, img):
|
def detect(self, img, device_id: int | None = None):
|
||||||
|
if device_id is None:
|
||||||
|
device_id = 0
|
||||||
|
|
||||||
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
||||||
|
|
||||||
if img is None:
|
if img is None:
|
||||||
return None, None, time_dict
|
return None, None, time_dict
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
dt_boxes, elapse = self.text_detector(img)
|
dt_boxes, elapse = self.text_detector[device_id](img)
|
||||||
time_dict['det'] = elapse
|
time_dict['det'] = elapse
|
||||||
|
|
||||||
if dt_boxes is None:
|
if dt_boxes is None:
|
||||||
@ -611,17 +635,22 @@ class OCR:
|
|||||||
return zip(self.sorted_boxes(dt_boxes), [
|
return zip(self.sorted_boxes(dt_boxes), [
|
||||||
("", 0) for _ in range(len(dt_boxes))])
|
("", 0) for _ in range(len(dt_boxes))])
|
||||||
|
|
||||||
def recognize(self, ori_im, box):
|
def recognize(self, ori_im, box, device_id: int | None = None):
|
||||||
|
if device_id is None:
|
||||||
|
device_id = 0
|
||||||
|
|
||||||
img_crop = self.get_rotate_crop_image(ori_im, box)
|
img_crop = self.get_rotate_crop_image(ori_im, box)
|
||||||
|
|
||||||
rec_res, elapse = self.text_recognizer([img_crop])
|
rec_res, elapse = self.text_recognizer[device_id]([img_crop])
|
||||||
text, score = rec_res[0]
|
text, score = rec_res[0]
|
||||||
if score < self.drop_score:
|
if score < self.drop_score:
|
||||||
return ""
|
return ""
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def recognize_batch(self, img_list):
|
def recognize_batch(self, img_list, device_id: int | None = None):
|
||||||
rec_res, elapse = self.text_recognizer(img_list)
|
if device_id is None:
|
||||||
|
device_id = 0
|
||||||
|
rec_res, elapse = self.text_recognizer[device_id](img_list)
|
||||||
texts = []
|
texts = []
|
||||||
for i in range(len(rec_res)):
|
for i in range(len(rec_res)):
|
||||||
text, score = rec_res[i]
|
text, score = rec_res[i]
|
||||||
@ -630,15 +659,17 @@ class OCR:
|
|||||||
texts.append(text)
|
texts.append(text)
|
||||||
return texts
|
return texts
|
||||||
|
|
||||||
def __call__(self, img, cls=True):
|
def __call__(self, img, device_id = 0, cls=True):
|
||||||
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
||||||
|
if device_id is None:
|
||||||
|
device_id = 0
|
||||||
|
|
||||||
if img is None:
|
if img is None:
|
||||||
return None, None, time_dict
|
return None, None, time_dict
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
ori_im = img.copy()
|
ori_im = img.copy()
|
||||||
dt_boxes, elapse = self.text_detector(img)
|
dt_boxes, elapse = self.text_detector[device_id](img)
|
||||||
time_dict['det'] = elapse
|
time_dict['det'] = elapse
|
||||||
|
|
||||||
if dt_boxes is None:
|
if dt_boxes is None:
|
||||||
@ -655,7 +686,7 @@ class OCR:
|
|||||||
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
|
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
|
||||||
img_crop_list.append(img_crop)
|
img_crop_list.append(img_crop)
|
||||||
|
|
||||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
rec_res, elapse = self.text_recognizer[device_id](img_crop_list)
|
||||||
|
|
||||||
time_dict['rec'] = elapse
|
time_dict['rec'] = elapse
|
||||||
|
|
||||||
|
@ -28,14 +28,24 @@ from deepdoc.vision.seeit import draw_box
|
|||||||
from deepdoc.vision import OCR, init_in_out
|
from deepdoc.vision import OCR, init_in_out
|
||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import trio
|
||||||
|
|
||||||
|
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,2' #2 gpus, uncontinuous
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu
|
||||||
|
# os.environ['CUDA_VISIBLE_DEVICES'] = '' #cpu
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
ocr = OCR()
|
import torch.cuda
|
||||||
|
|
||||||
|
cuda_devices = torch.cuda.device_count()
|
||||||
|
limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
|
||||||
|
ocr = OCR(parallel_devices = cuda_devices)
|
||||||
images, outputs = init_in_out(args)
|
images, outputs = init_in_out(args)
|
||||||
|
|
||||||
for i, img in enumerate(images):
|
|
||||||
bxs = ocr(np.array(img))
|
def __ocr(i, id, img):
|
||||||
|
print("Task {} start".format(i))
|
||||||
|
bxs = ocr(np.array(img), id)
|
||||||
bxs = [(line[0], line[1][0]) for line in bxs]
|
bxs = [(line[0], line[1][0]) for line in bxs]
|
||||||
bxs = [{
|
bxs = [{
|
||||||
"text": t,
|
"text": t,
|
||||||
@ -47,6 +57,30 @@ def main(args):
|
|||||||
with open(outputs[i] + ".txt", "w+", encoding='utf-8') as f:
|
with open(outputs[i] + ".txt", "w+", encoding='utf-8') as f:
|
||||||
f.write("\n".join([o["text"] for o in bxs]))
|
f.write("\n".join([o["text"] for o in bxs]))
|
||||||
|
|
||||||
|
print("Task {} done".format(i))
|
||||||
|
|
||||||
|
async def __ocr_thread(i, id, img, limiter = None):
|
||||||
|
if limiter:
|
||||||
|
async with limiter:
|
||||||
|
print("Task {} use device {}".format(i, id))
|
||||||
|
await trio.to_thread.run_sync(lambda: __ocr(i, id, img))
|
||||||
|
else:
|
||||||
|
__ocr(i, id, img)
|
||||||
|
|
||||||
|
async def __ocr_launcher():
|
||||||
|
if cuda_devices > 1:
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
for i, img in enumerate(images):
|
||||||
|
nursery.start_soon(__ocr_thread, i, i % cuda_devices, img, limiter[i % cuda_devices])
|
||||||
|
await trio.sleep(0.1)
|
||||||
|
else:
|
||||||
|
for i, img in enumerate(images):
|
||||||
|
await __ocr_thread(i, 0, img)
|
||||||
|
|
||||||
|
trio.run(__ocr_launcher)
|
||||||
|
|
||||||
|
print("OCR tasks are all done")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
@ -128,6 +128,9 @@ class Docx(DocxParser):
|
|||||||
|
|
||||||
|
|
||||||
class Pdf(PdfParser):
|
class Pdf(PdfParser):
|
||||||
|
def __init__(self, parallel_devices = None):
|
||||||
|
super().__init__(parallel_devices)
|
||||||
|
|
||||||
def __call__(self, filename, binary=None, from_page=0,
|
def __call__(self, filename, binary=None, from_page=0,
|
||||||
to_page=100000, zoomin=3, callback=None):
|
to_page=100000, zoomin=3, callback=None):
|
||||||
start = timer()
|
start = timer()
|
||||||
@ -194,7 +197,7 @@ class Markdown(MarkdownParser):
|
|||||||
|
|
||||||
|
|
||||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||||
lang="Chinese", callback=None, **kwargs):
|
lang="Chinese", parallel_devices=None, callback=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Supported file formats are docx, pdf, excel, txt.
|
Supported file formats are docx, pdf, excel, txt.
|
||||||
This method apply the naive ways to chunk files.
|
This method apply the naive ways to chunk files.
|
||||||
@ -234,7 +237,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||||
pdf_parser = Pdf()
|
pdf_parser = Pdf(parallel_devices)
|
||||||
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
|
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
|
||||||
pdf_parser = PlainParser()
|
pdf_parser = PlainParser()
|
||||||
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
|
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
|
||||||
|
@ -100,6 +100,14 @@ MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDER
|
|||||||
task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
|
task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
|
||||||
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
|
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
|
||||||
|
|
||||||
|
PARALLEL_DEVICES = None
|
||||||
|
try:
|
||||||
|
import torch.cuda
|
||||||
|
PARALLEL_DEVICES = torch.cuda.device_count()
|
||||||
|
logging.info(f"found {PARALLEL_DEVICES} gpus")
|
||||||
|
except Exception:
|
||||||
|
logging.info("can't import package 'torch'")
|
||||||
|
|
||||||
# SIGUSR1 handler: start tracemalloc and take snapshot
|
# SIGUSR1 handler: start tracemalloc and take snapshot
|
||||||
def start_tracemalloc_and_snapshot(signum, frame):
|
def start_tracemalloc_and_snapshot(signum, frame):
|
||||||
if not tracemalloc.is_tracing():
|
if not tracemalloc.is_tracing():
|
||||||
@ -241,7 +249,7 @@ async def build_chunks(task, progress_callback):
|
|||||||
try:
|
try:
|
||||||
async with chunk_limiter:
|
async with chunk_limiter:
|
||||||
cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
|
cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
|
||||||
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
|
to_page=task["to_page"], lang=task["language"], parallel_devices = PARALLEL_DEVICES, callback=progress_callback,
|
||||||
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
|
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
|
||||||
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
|
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
|
||||||
except TaskCanceledException:
|
except TaskCanceledException:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user