diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index 492c4dc54..dbc6cc320 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -61,7 +61,7 @@ class RAGFlowPdfParser: self.ocr = OCR() self.parallel_limiter = None - if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 1: + if PARALLEL_DEVICES > 1: self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)] if hasattr(self, "model_speciess"): diff --git a/deepdoc/vision/ocr.py b/deepdoc/vision/ocr.py index 4dedb7c67..90b11038f 100644 --- a/deepdoc/vision/ocr.py +++ b/deepdoc/vision/ocr.py @@ -529,31 +529,30 @@ class OCR: "rag/res/deepdoc") # Append muti-gpus task to the list - if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 0: + if PARALLEL_DEVICES > 0: self.text_detector = [] self.text_recognizer = [] for device_id in range(PARALLEL_DEVICES): self.text_detector.append(TextDetector(model_dir, device_id)) self.text_recognizer.append(TextRecognizer(model_dir, device_id)) else: - self.text_detector = [TextDetector(model_dir, 0)] - self.text_recognizer = [TextRecognizer(model_dir, 0)] + self.text_detector = [TextDetector(model_dir)] + self.text_recognizer = [TextRecognizer(model_dir)] except Exception: model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc", local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"), local_dir_use_symlinks=False) - if PARALLEL_DEVICES is not None: - assert PARALLEL_DEVICES > 0, "Number of devices must be >= 1" + if PARALLEL_DEVICES > 0: self.text_detector = [] self.text_recognizer = [] for device_id in range(PARALLEL_DEVICES): self.text_detector.append(TextDetector(model_dir, device_id)) self.text_recognizer.append(TextRecognizer(model_dir, device_id)) else: - self.text_detector = [TextDetector(model_dir, 0)] - self.text_recognizer = [TextRecognizer(model_dir, 0)] + self.text_detector = [TextDetector(model_dir)] + self.text_recognizer = [TextRecognizer(model_dir)] self.drop_score = 0.5 self.crop_image_res_index = 0 diff --git a/rag/settings.py b/rag/settings.py index 2dfaea627..601b4e3f4 100644 --- a/rag/settings.py +++ b/rag/settings.py @@ -62,7 +62,7 @@ SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker" PAGERANK_FLD = "pagerank_fea" TAG_FLD = "tag_feas" -PARALLEL_DEVICES = None +PARALLEL_DEVICES = 0 try: import torch.cuda PARALLEL_DEVICES = torch.cuda.device_count()