Added cuda_is_available (#4725)

### What problem does this PR solve?

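Added a local `cuda_is_available()` helper that imports `torch` lazily and returns `False` if the import or the CUDA check fails. This replaces the direct module-level `import torch` / `torch.cuda.is_available()` calls, so `torch` is no longer a core dependency and moves to the `full` optional extra.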

### Type of change

- [x] Refactoring
This commit is contained in:
Zhichang Yu 2025-02-05 18:01:23 +08:00 committed by GitHub
parent 283d036cba
commit 3411d0a2ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 24 additions and 7 deletions

View File

@ -27,7 +27,6 @@ from . import operators
import math import math
import numpy as np import numpy as np
import cv2 import cv2
import torch
import onnxruntime as ort import onnxruntime as ort
from .postprocess import build_post_process from .postprocess import build_post_process
@ -72,6 +71,15 @@ def load_model(model_dir, nm):
raise ValueError("not find model file path {}".format( raise ValueError("not find model file path {}".format(
model_file_path)) model_file_path))
def cuda_is_available():
    """Return True when PyTorch reports that CUDA is available, else False.

    torch is imported inside the function rather than at module level, and
    any exception raised by the import or by the availability check is
    treated as "no CUDA", so the caller always receives a plain bool.
    """
    try:
        import torch
        has_cuda = bool(torch.cuda.is_available())
    except Exception:
        # Best-effort probe: a missing or broken torch install means no GPU.
        has_cuda = False
    return has_cuda
options = ort.SessionOptions() options = ort.SessionOptions()
options.enable_cpu_mem_arena = False options.enable_cpu_mem_arena = False
options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
@ -81,7 +89,7 @@ def load_model(model_dir, nm):
# https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580 # https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580
# Shrink GPU memory after execution # Shrink GPU memory after execution
run_options = ort.RunOptions() run_options = ort.RunOptions()
if torch.cuda.is_available(): if cuda_is_available():
cuda_provider_options = { cuda_provider_options = {
"device_id": 0, # Use specific GPU "device_id": 0, # Use specific GPU
"gpu_mem_limit": 512 * 1024 * 1024, # Limit gpu memory "gpu_mem_limit": 512 * 1024 * 1024, # Limit gpu memory

View File

@ -21,7 +21,6 @@ import numpy as np
import cv2 import cv2
from copy import deepcopy from copy import deepcopy
import torch
import onnxruntime as ort import onnxruntime as ort
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
@ -60,11 +59,21 @@ class Recognizer(object):
if not os.path.exists(model_file_path): if not os.path.exists(model_file_path):
raise ValueError("not find model file path {}".format( raise ValueError("not find model file path {}".format(
model_file_path)) model_file_path))
def cuda_is_available():
    """Return True when PyTorch reports that CUDA is available, else False.

    torch is imported inside the function rather than at module level, and
    any exception raised by the import or by the availability check is
    treated as "no CUDA", so the caller always receives a plain bool.
    """
    try:
        import torch
        has_cuda = bool(torch.cuda.is_available())
    except Exception:
        # Best-effort probe: a missing or broken torch install means no GPU.
        has_cuda = False
    return has_cuda
# https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580 # https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580
# Shrink GPU memory after execution # Shrink GPU memory after execution
self.run_options = ort.RunOptions() self.run_options = ort.RunOptions()
if torch.cuda.is_available(): if cuda_is_available():
options = ort.SessionOptions() options = ort.SessionOptions()
options.enable_cpu_mem_arena = False options.enable_cpu_mem_arena = False
cuda_provider_options = { cuda_provider_options = {

View File

@ -100,7 +100,6 @@ dependencies = [
"tencentcloud-sdk-python==3.0.1215", "tencentcloud-sdk-python==3.0.1215",
"tika==2.6.0", "tika==2.6.0",
"tiktoken==0.7.0", "tiktoken==0.7.0",
"torch>=2.5.0,<3.0.0",
"umap_learn==0.5.6", "umap_learn==0.5.6",
"vertexai==1.64.0", "vertexai==1.64.0",
"volcengine==1.0.146", "volcengine==1.0.146",
@ -132,5 +131,6 @@ full = [
"fastembed>=0.3.6,<0.4.0; sys_platform == 'darwin' or platform_machine != 'x86_64'", "fastembed>=0.3.6,<0.4.0; sys_platform == 'darwin' or platform_machine != 'x86_64'",
"fastembed-gpu>=0.3.6,<0.4.0; sys_platform != 'darwin' and platform_machine == 'x86_64'", "fastembed-gpu>=0.3.6,<0.4.0; sys_platform != 'darwin' and platform_machine == 'x86_64'",
"flagembedding==1.2.10", "flagembedding==1.2.10",
"torch>=2.5.0,<3.0.0",
"transformers>=4.35.0,<5.0.0" "transformers>=4.35.0,<5.0.0"
] ]

4
uv.lock generated
View File

@ -4814,7 +4814,6 @@ dependencies = [
{ name = "tencentcloud-sdk-python" }, { name = "tencentcloud-sdk-python" },
{ name = "tika" }, { name = "tika" },
{ name = "tiktoken" }, { name = "tiktoken" },
{ name = "torch" },
{ name = "umap-learn" }, { name = "umap-learn" },
{ name = "valkey" }, { name = "valkey" },
{ name = "vertexai" }, { name = "vertexai" },
@ -4837,6 +4836,7 @@ full = [
{ name = "fastembed", marker = "platform_machine != 'x86_64' or sys_platform == 'darwin'" }, { name = "fastembed", marker = "platform_machine != 'x86_64' or sys_platform == 'darwin'" },
{ name = "fastembed-gpu", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, { name = "fastembed-gpu", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" },
{ name = "flagembedding" }, { name = "flagembedding" },
{ name = "torch" },
{ name = "transformers" }, { name = "transformers" },
] ]
@ -4946,7 +4946,7 @@ requires-dist = [
{ name = "tencentcloud-sdk-python", specifier = "==3.0.1215" }, { name = "tencentcloud-sdk-python", specifier = "==3.0.1215" },
{ name = "tika", specifier = "==2.6.0" }, { name = "tika", specifier = "==2.6.0" },
{ name = "tiktoken", specifier = "==0.7.0" }, { name = "tiktoken", specifier = "==0.7.0" },
{ name = "torch", specifier = ">=2.5.0,<3.0.0" }, { name = "torch", marker = "extra == 'full'", specifier = ">=2.5.0,<3.0.0" },
{ name = "transformers", marker = "extra == 'full'", specifier = ">=4.35.0,<5.0.0" }, { name = "transformers", marker = "extra == 'full'", specifier = ">=4.35.0,<5.0.0" },
{ name = "umap-learn", specifier = "==0.5.6" }, { name = "umap-learn", specifier = "==0.5.6" },
{ name = "valkey", specifier = "==6.0.2" }, { name = "valkey", specifier = "==6.0.2" },