From 3411d0a2ce591067dbf3a3cebd12e8b6fda51f7c Mon Sep 17 00:00:00 2001
From: Zhichang Yu
Date: Wed, 5 Feb 2025 18:01:23 +0800
Subject: [PATCH] Added cuda_is_available (#4725)

### What problem does this PR solve?

Added a local `cuda_is_available()` helper to `deepdoc/vision/ocr.py` and
`deepdoc/vision/recognizer.py`, so the CUDA check no longer needs a top-level
`torch` import. `torch` itself moves from the core dependencies to the
optional `full` extra in `pyproject.toml`, with `uv.lock` updated to match.

### Type of change

- [x] Refactoring
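The helper probes CUDA through a lazy import, so an install without the
`full` extra falls back to CPU execution instead of failing at import time.
A condensed sketch of the pattern (illustrative, not the hunks below
verbatim; the provider list is an assumed example of how the flag gets used):

```python
def cuda_is_available():
    # torch is optional now (it ships only with the "full" extra), so a
    # missing package and a failed CUDA probe both mean "run on CPU".
    try:
        import torch
        return torch.cuda.is_available()
    except Exception:
        return False


# The flag then selects the execution providers handed to
# onnxruntime.InferenceSession(..., providers=...).
providers = (["CUDAExecutionProvider", "CPUExecutionProvider"]
             if cuda_is_available()
             else ["CPUExecutionProvider"])
```

The broad `except Exception` (rather than `ImportError` alone) also covers a
`torch` build whose CUDA probe fails, which likewise falls back to CPU.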
"fastembed-gpu", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, { name = "flagembedding" }, + { name = "torch" }, { name = "transformers" }, ] @@ -4946,7 +4946,7 @@ requires-dist = [ { name = "tencentcloud-sdk-python", specifier = "==3.0.1215" }, { name = "tika", specifier = "==2.6.0" }, { name = "tiktoken", specifier = "==0.7.0" }, - { name = "torch", specifier = ">=2.5.0,<3.0.0" }, + { name = "torch", marker = "extra == 'full'", specifier = ">=2.5.0,<3.0.0" }, { name = "transformers", marker = "extra == 'full'", specifier = ">=4.35.0,<5.0.0" }, { name = "umap-learn", specifier = "==0.5.6" }, { name = "valkey", specifier = "==6.0.2" },