Refa: PARALLEL_DEVICES is a static parameter. (#6168)

### What problem does this PR solve?

`PARALLEL_DEVICES` (the number of CUDA devices reported by `torch.cuda.device_count()`) becomes a static, module-level setting in `rag.settings` instead of a value detected at task-execution time and threaded through the call chain. `chunk()`, `Pdf`, `RAGFlowPdfParser`, and `OCR` lose their `parallel_devices` parameters and read `PARALLEL_DEVICES` directly.

### Type of change

- [x] Refactoring
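
For callers, the visible effect is that the device count no longer has to be detected and passed down by hand. A minimal before/after sketch (the import path is assumed from the repository layout; it is not part of this diff):

```python
# Before this PR: callers detected the GPU count and threaded it through, e.g.
#   cuda_devices = torch.cuda.device_count()
#   parser = RAGFlowPdfParser(parallel_devices=cuda_devices)
#
# After this PR: the constructors take no device argument and read the
# module-level rag.settings.PARALLEL_DEVICES themselves.
from deepdoc.parser.pdf_parser import RAGFlowPdfParser  # assumed import path

parser = RAGFlowPdfParser()  # device handling is now resolved internally
```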
Author: Kevin Hu, 2025-03-17 16:49:54 +08:00 (committed by GitHub)
Commit: 3a99c2b5f4 (parent 45fe02c8b3)
6 changed files with 29 additions and 28 deletions


```diff
@@ -37,13 +37,15 @@ from rag.nlp import rag_tokenizer
 from copy import deepcopy
 from huggingface_hub import snapshot_download
+from rag.settings import PARALLEL_DEVICES
 
 LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
 if LOCK_KEY_pdfplumber not in sys.modules:
     sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
 
 
 class RAGFlowPdfParser:
-    def __init__(self, parallel_devices: int | None = None):
+    def __init__(self):
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!
@@ -56,11 +58,10 @@ class RAGFlowPdfParser:
         """
-        self.ocr = OCR(parallel_devices = parallel_devices)
-        self.parallel_devices = parallel_devices
+        self.ocr = OCR()
         self.parallel_limiter = None
-        if parallel_devices is not None and parallel_devices > 1:
-            self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(parallel_devices)]
+        if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 1:
+            self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]
 
         if hasattr(self, "model_speciess"):
             self.layouter = LayoutRecognizer("layout." + self.model_speciess)
@@ -1019,7 +1020,6 @@ class RAGFlowPdfParser:
         if not self.outlines:
             logging.warning("Miss outlines")
         logging.debug("Images converted.")
         self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
             random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in
@@ -1066,8 +1066,8 @@ class RAGFlowPdfParser:
                 for i, img in enumerate(self.page_images):
                     chars = __ocr_preprocess()
-                    nursery.start_soon(__img_ocr, i, i % self.parallel_devices, img, chars,
-                                       self.parallel_limiter[i % self.parallel_devices])
+                    nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars,
+                                       self.parallel_limiter[i % PARALLEL_DEVICES])
                     await trio.sleep(0.1)
             else:
                 for i, img in enumerate(self.page_images):
```
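
The parser keeps its concurrency pattern: one `trio.CapacityLimiter(1)` per detected device, with page `i` pinned to device `i % PARALLEL_DEVICES`; only the source of the device count changes. A hedged, self-contained sketch of that pattern (all names other than `PARALLEL_DEVICES` are illustrative, not taken from the diff):

```python
import trio

PARALLEL_DEVICES = 2  # stand-in for rag.settings.PARALLEL_DEVICES

async def ocr_page(page_no: int, device_id: int, limiter: trio.CapacityLimiter):
    async with limiter:  # at most one in-flight OCR job per device
        await trio.sleep(0.01)  # placeholder for the real per-page OCR work
        print(f"page {page_no} done on device {device_id}")

async def main():
    # one limiter per device, mirroring self.parallel_limiter above
    limiters = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]
    async with trio.open_nursery() as nursery:
        for i in range(6):  # pretend there are 6 page images
            device_id = i % PARALLEL_DEVICES  # round-robin device assignment
            nursery.start_soon(ocr_page, i, device_id, limiters[device_id])

trio.run(main)
```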


```diff
@@ -22,6 +22,7 @@ import os
 from huggingface_hub import snapshot_download
 from api.utils.file_utils import get_project_base_directory
+from rag.settings import PARALLEL_DEVICES
 from .operators import *  # noqa: F403
 from . import operators
 import math
@@ -509,7 +510,7 @@ class TextDetector:
 
 class OCR:
-    def __init__(self, model_dir=None, parallel_devices: int | None = None):
+    def __init__(self, model_dir=None):
         """
         If you have trouble downloading HuggingFace models, -_^ this might help!!
@@ -528,10 +529,10 @@ class OCR:
                         "rag/res/deepdoc")
                 # Append muti-gpus task to the list
-                if parallel_devices is not None and parallel_devices > 0:
+                if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 0:
                     self.text_detector = []
                     self.text_recognizer = []
-                    for device_id in range(parallel_devices):
+                    for device_id in range(PARALLEL_DEVICES):
                         self.text_detector.append(TextDetector(model_dir, device_id))
                         self.text_recognizer.append(TextRecognizer(model_dir, device_id))
                 else:
@@ -543,11 +544,11 @@ class OCR:
                                               local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
                                               local_dir_use_symlinks=False)
-                if parallel_devices is not None:
-                    assert parallel_devices > 0 , "Number of devices must be >= 1"
+                if PARALLEL_DEVICES is not None:
+                    assert PARALLEL_DEVICES > 0, "Number of devices must be >= 1"
                     self.text_detector = []
                     self.text_recognizer = []
-                    for device_id in range(parallel_devices):
+                    for device_id in range(PARALLEL_DEVICES):
                         self.text_detector.append(TextDetector(model_dir, device_id))
                         self.text_recognizer.append(TextRecognizer(model_dir, device_id))
                 else:
```
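
With `PARALLEL_DEVICES` set, `OCR` builds one `TextDetector`/`TextRecognizer` pair per visible device, and callers select the pair by device id. A hedged sketch of that per-device model-pool pattern (`FakeModel` and `detect` are illustrative stand-ins, not names from the diff):

```python
class FakeModel:  # stand-in for TextDetector / TextRecognizer
    def __init__(self, device_id: int):
        self.device_id = device_id

    def __call__(self, image):
        return f"processed {image} on cuda:{self.device_id}"

PARALLEL_DEVICES = 2  # stand-in for rag.settings.PARALLEL_DEVICES
text_detector = [FakeModel(device_id) for device_id in range(PARALLEL_DEVICES)]

def detect(image, device_id: int = 0):
    # callers pass the device id they were assigned (cf. i % PARALLEL_DEVICES above)
    return text_detector[device_id](image)

print(detect("page-1.png", device_id=1))
```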


```diff
@@ -34,15 +34,15 @@ import trio
 os.environ['CUDA_VISIBLE_DEVICES'] = '0' #1 gpu
 # os.environ['CUDA_VISIBLE_DEVICES'] = '' #cpu
 
 
 def main(args):
     import torch.cuda
     cuda_devices = torch.cuda.device_count()
     limiter = [trio.CapacityLimiter(1) for _ in range(cuda_devices)] if cuda_devices > 1 else None
-    ocr = OCR(parallel_devices = cuda_devices)
+    ocr = OCR()
     images, outputs = init_in_out(args)
 
     def __ocr(i, id, img):
         print("Task {} start".format(i))
         bxs = ocr(np.array(img), id)
```


```diff
@@ -128,8 +128,8 @@ class Docx(DocxParser):
 
 class Pdf(PdfParser):
-    def __init__(self, parallel_devices = None):
-        super().__init__(parallel_devices)
+    def __init__(self):
+        super().__init__()
 
     def __call__(self, filename, binary=None, from_page=0,
                  to_page=100000, zoomin=3, callback=None):
@@ -197,7 +197,7 @@ class Markdown(MarkdownParser):
 def chunk(filename, binary=None, from_page=0, to_page=100000,
-          lang="Chinese", callback=None, parallel_devices=None, **kwargs):
+          lang="Chinese", callback=None, **kwargs):
     """
         Supported file formats are docx, pdf, excel, txt.
         This method apply the naive ways to chunk files.
@@ -237,7 +237,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
         return res
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf(parallel_devices)
+        pdf_parser = Pdf()
         if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
             pdf_parser = PlainParser()
         sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
```
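
For anyone calling this chunker directly, the call shape after the change is sketched below; the `parallel_devices` keyword is simply gone. This is a hedged example: the module path, the input file name, and the `progress` helper are assumptions of mine, while the keyword arguments mirror the task-executor call in the last diff of this PR.

```python
from rag.app import naive  # assumed module path for this chunker

def progress(prog=None, msg=""):
    # minimal progress callback; the task executor passes its own
    print(prog, msg)

with open("sample.pdf", "rb") as f:  # illustrative input file
    chunks = naive.chunk(
        "sample.pdf",
        binary=f.read(),
        from_page=0,
        to_page=100000,
        lang="English",
        callback=progress,
        parser_config={"layout_recognize": "DeepDOC"},
    )

print(f"got {len(chunks)} chunks")
```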


```diff
@@ -39,6 +39,13 @@ SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
 PAGERANK_FLD = "pagerank_fea"
 TAG_FLD = "tag_feas"
 
+PARALLEL_DEVICES = None
+try:
+    import torch.cuda
+    PARALLEL_DEVICES = torch.cuda.device_count()
+    logging.info(f"found {PARALLEL_DEVICES} gpus")
+except Exception:
+    logging.info("can't import package 'torch'")
 
 def print_rag_settings():
     logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
```


```diff
@@ -100,13 +100,6 @@ MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDER
 task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
 chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)
 
-PARALLEL_DEVICES = None
-try:
-    import torch.cuda
-    PARALLEL_DEVICES = torch.cuda.device_count()
-    logging.info(f"found {PARALLEL_DEVICES} gpus")
-except Exception:
-    logging.info("can't import package 'torch'")
 
 # SIGUSR1 handler: start tracemalloc and take snapshot
 def start_tracemalloc_and_snapshot(signum, frame):
@@ -249,7 +242,7 @@ async def build_chunks(task, progress_callback):
     try:
         async with chunk_limiter:
             cks = await trio.to_thread.run_sync(lambda: chunker.chunk(task["name"], binary=binary, from_page=task["from_page"],
-                                                to_page=task["to_page"], lang=task["language"], parallel_devices = PARALLEL_DEVICES, callback=progress_callback,
+                                                to_page=task["to_page"], lang=task["language"], callback=progress_callback,
                                                 kb_id=task["kb_id"], parser_config=task["parser_config"], tenant_id=task["tenant_id"]))
         logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"]))
     except TaskCanceledException:
```