mirror of
https://git.mirrors.martin98.com/https://github.com/infiniflow/ragflow.git
synced 2025-08-20 08:09:26 +08:00
Refa: PARALLEL_DEVICES is a static parameter. (#6168)
### What problem does this PR solve? ### Type of change - [x] Refactoring
This commit is contained in:
parent
45fe02c8b3
commit
3a99c2b5f4
@ -37,13 +37,15 @@ from rag.nlp import rag_tokenizer
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
from rag.settings import PARALLEL_DEVICES
|
||||||
|
|
||||||
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
|
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
|
||||||
if LOCK_KEY_pdfplumber not in sys.modules:
|
if LOCK_KEY_pdfplumber not in sys.modules:
|
||||||
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
class RAGFlowPdfParser:
|
class RAGFlowPdfParser:
|
||||||
def __init__(self, parallel_devices: int | None = None):
|
def __init__(self):
|
||||||
"""
|
"""
|
||||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||||
|
|
||||||
@ -56,11 +58,10 @@ class RAGFlowPdfParser:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.ocr = OCR(parallel_devices = parallel_devices)
|
self.ocr = OCR()
|
||||||
self.parallel_devices = parallel_devices
|
|
||||||
self.parallel_limiter = None
|
self.parallel_limiter = None
|
||||||
if parallel_devices is not None and parallel_devices > 1:
|
if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 1:
|
||||||
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(parallel_devices)]
|
self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]
|
||||||
|
|
||||||
if hasattr(self, "model_speciess"):
|
if hasattr(self, "model_speciess"):
|
||||||
self.layouter = LayoutRecognizer("layout." + self.model_speciess)
|
self.layouter = LayoutRecognizer("layout." + self.model_speciess)
|
||||||
@ -1019,7 +1020,6 @@ class RAGFlowPdfParser:
|
|||||||
if not self.outlines:
|
if not self.outlines:
|
||||||
logging.warning("Miss outlines")
|
logging.warning("Miss outlines")
|
||||||
|
|
||||||
|
|
||||||
logging.debug("Images converted.")
|
logging.debug("Images converted.")
|
||||||
self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
|
self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
|
||||||
random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in
|
random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in
|
||||||
@ -1066,8 +1066,8 @@ class RAGFlowPdfParser:
|
|||||||
for i, img in enumerate(self.page_images):
|
for i, img in enumerate(self.page_images):
|
||||||
chars = __ocr_preprocess()
|
chars = __ocr_preprocess()
|
||||||
|
|
||||||
nursery.start_soon(__img_ocr, i, i % self.parallel_devices, img, chars,
|
nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars,
|
||||||
self.parallel_limiter[i % self.parallel_devices])
|
self.parallel_limiter[i % PARALLEL_DEVICES])
|
||||||
await trio.sleep(0.1)
|
await trio.sleep(0.1)
|
||||||
else:
|
else:
|
||||||
for i, img in enumerate(self.page_images):
|
for i, img in enumerate(self.page_images):
|
||||||
|
@ -22,6 +22,7 @@ import os
|
|||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils.file_utils import get_project_base_directory
|
||||||
|
from rag.settings import PARALLEL_DEVICES
|
||||||
from .operators import * # noqa: F403
|
from .operators import * # noqa: F403
|
||||||
from . import operators
|
from . import operators
|
||||||
import math
|
import math
|
||||||
@ -509,7 +510,7 @@ class TextDetector:
|
|||||||
|
|
||||||
|
|
||||||
class OCR:
|
class OCR:
|
||||||
def __init__(self, model_dir=None, parallel_devices: int | None = None):
|
def __init__(self, model_dir=None):
|
||||||
"""
|
"""
|
||||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||||
|
|
||||||
@ -528,10 +529,10 @@ class OCR:
|
|||||||
"rag/res/deepdoc")
|
"rag/res/deepdoc")
|
||||||
|
|
||||||
# Append muti-gpus task to the list
|
# Append muti-gpus task to the list
|
||||||
if parallel_devices is not None and parallel_devices > 0:
|
if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 0:
|
||||||
self.text_detector = []
|
self.text_detector = []
|
||||||
self.text_recognizer = []
|
self.text_recognizer = []
|
||||||
for device_id in range(parallel_devices):
|
for device_id in range(PARALLEL_DEVICES):
|
||||||
self.text_detector.append(TextDetector(model_dir, device_id))
|
self.text_detector.append(TextDetector(model_dir, device_id))
|
||||||
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
||||||
else:
|
else:
|
||||||
@ -543,11 +544,11 @@ class OCR:
|
|||||||
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||||
local_dir_use_symlinks=False)
|
local_dir_use_symlinks=False)
|
||||||
|
|
||||||
if parallel_devices is not None:
|
if PARALLEL_DEVICES is not None:
|
||||||
assert parallel_devices > 0 , "Number of devices must be >= 1"
|
assert PARALLEL_DEVICES > 0, "Number of devices must be >= 1"
|
||||||
self.text_detector = []
|
self.text_detector = []
|
||||||
self.text_recognizer = []
|
self.text_recognizer = []
|
||||||
for device_id in range(parallel_devices):
|
for device_id in range(PARALLEL_DEVICES):
|
||||||
self.text_detector.append(TextDetector(model_dir, device_id))
|
self.text_detector.append(TextDetector(model_dir, device_id))
|
||||||
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
self.text_recognizer.append(TextRecognizer(model_dir, device_id))
|
||||||
else:
|
else:
|
||||||
|
@ -34,15 +34,15 @@ import trio
|
|||||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu
|
||||||
# os.environ['CUDA_VISIBLE_DEVICES'] = '' #cpu
|
# os.environ['CUDA_VISIBLE_DEVICES'] = '' #cpu
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
|
|
||||||
cuda_devices = torch.cuda.device_count()
|
cuda_devices = torch.cuda.device_count()
|
||||||
limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
|
limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
|
||||||
ocr = OCR(parallel_devices = cuda_devices)
|
ocr = OCR()
|
||||||
images, outputs = init_in_out(args)
|
images, outputs = init_in_out(args)
|
||||||
|
|
||||||
|
|
||||||
def __ocr(i, id, img):
|
def __ocr(i, id, img):
|
||||||
print("Task {} start".format(i))
|
print("Task {} start".format(i))
|
||||||
bxs = ocr(np.array(img), id)
|
bxs = ocr(np.array(img), id)
|
||||||
|
@ -128,8 +128,8 @@ class Docx(DocxParser):
|
|||||||
|
|
||||||
|
|
||||||
class Pdf(PdfParser):
|
class Pdf(PdfParser):
|
||||||
def __init__(self, parallel_devices = None):
|
def __init__(self):
|
||||||
super().__init__(parallel_devices)
|
super().__init__()
|
||||||
|
|
||||||
def __call__(self, filename, binary=None, from_page=0,
|
def __call__(self, filename, binary=None, from_page=0,
|
||||||
to_page=100000, zoomin=3, callback=None):
|
to_page=100000, zoomin=3, callback=None):
|
||||||
@ -197,7 +197,7 @@ class Markdown(MarkdownParser):
|
|||||||
|
|
||||||
|
|
||||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||||
lang="Chinese", callback=None, parallel_devices=None, **kwargs):
|
lang="Chinese", callback=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Supported file formats are docx, pdf, excel, txt.
|
Supported file formats are docx, pdf, excel, txt.
|
||||||
This method apply the naive ways to chunk files.
|
This method apply the naive ways to chunk files.
|
||||||
@ -237,7 +237,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||||
pdf_parser = Pdf(parallel_devices)
|
pdf_parser = Pdf()
|
||||||
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
|
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
|
||||||
pdf_parser = PlainParser()
|
pdf_parser = PlainParser()
|
||||||
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
|
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
|
||||||
|
@ -39,6 +39,13 @@ SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
|
|||||||
PAGERANK_FLD = "pagerank_fea"
|
PAGERANK_FLD = "pagerank_fea"
|
||||||
TAG_FLD = "tag_feas"
|
TAG_FLD = "tag_feas"
|
||||||
|
|
||||||
|
PARALLEL_DEVICES = None
|
||||||
|
try:
|
||||||
|
import torch.cuda
|
||||||
|
PARALLEL_DEVICES = torch.cuda.device_count()
|
||||||
|
logging.info(f"found {PARALLEL_DEVICES} gpus")
|
||||||
|
except Exception:
|
||||||
|
logging.info("can't import package 'torch'")
|
||||||
|
|
||||||
def print_rag_settings():
|
def print_rag_settings():
|
||||||
logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
|
logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
|
||||||
|
@ -100,13 +100,6 @@ MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDER
|
|||||||
task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
|
task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
|
||||||
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
|
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
|
||||||
|
|
||||||
PARALLEL_DEVICES = None
|
|
||||||
try:
|
|
||||||
import torch.cuda
|
|
||||||
PARALLEL_DEVICES = torch.cuda.device_count()
|
|
||||||
logging.info(f"found {PARALLEL_DEVICES} gpus")
|
|
||||||
except Exception:
|
|
||||||
logging.info("can't import package 'torch'")
|
|
||||||
|
|
||||||
# SIGUSR1 handler: start tracemalloc and take snapshot
|
# SIGUSR1 handler: start tracemalloc and take snapshot
|
||||||
def start_tracemalloc_and_snapshot(signum, frame):
|
def start_tracemalloc_and_snapshot(signum, frame):
|
||||||
@ -249,7 +242,7 @@ async def build_chunks(task, progress_callback):
|
|||||||
try:
|
try:
|
||||||
async with chunk_limiter:
|
async with chunk_limiter:
|
||||||
cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
|
cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
|
||||||
to_page=task["to_page"], lang=task["language"], parallel_devices = PARALLEL_DEVICES, callback=progress_callback,
|
to_page=task["to_page"], lang=task["language"], callback=progress_callback,
|
||||||
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
|
kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
|
||||||
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
|
logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
|
||||||
except TaskCanceledException:
|
except TaskCanceledException:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user