Fixed GPU detection on CPU-only environments (#4711)

### What problem does this PR solve?

Fixed GPU detection on CPU-only environments. Closes #4692

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
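
The root cause: `ort.get_device()` only reports whether the installed onnxruntime build has GPU support, so with the `onnxruntime-gpu` wheel it can return "GPU" even on a machine with no usable CUDA device, which breaks CPU-only deployments. `torch.cuda.is_available()` probes for an actually usable device at runtime. A minimal sketch of the selection logic this PR switches to (the helper name is illustrative, not the exact RAGFlow code):

```python
import onnxruntime as ort
import torch


def pick_providers():
    """Choose ONNX Runtime execution providers from actual GPU availability."""
    # ort.get_device() reflects the installed onnxruntime build, so it can
    # claim "GPU" on a CPU-only host; torch.cuda.is_available() checks for a
    # CUDA device that can really be used.
    if torch.cuda.is_available():
        cuda_provider_options = {
            "device_id": 0,                      # use a specific GPU
            "gpu_mem_limit": 512 * 1024 * 1024,  # limit GPU memory
        }
        return [("CUDAExecutionProvider", cuda_provider_options),
                "CPUExecutionProvider"]
    return ["CPUExecutionProvider"]
```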
Zhichang Yu, 2025-02-05 12:02:43 +08:00, committed by GitHub
parent 7a7f98b1a9
commit e1526846da
4 changed files with 7 additions and 6 deletions


@@ -27,6 +27,7 @@ from . import operators
 import math
 import numpy as np
 import cv2
+import torch
 import onnxruntime as ort
 from .postprocess import build_post_process
@@ -80,7 +81,7 @@ def load_model(model_dir, nm):
     # https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580
     # Shrink GPU memory after execution
     run_options = ort.RunOptions()
-    if ort.get_device() == "GPU":
+    if torch.cuda.is_available():
         cuda_provider_options = {
             "device_id": 0,  # Use specific GPU
             "gpu_mem_limit": 512 * 1024 * 1024,  # Limit gpu memory


@@ -21,7 +21,7 @@ import numpy as np
 import cv2
 from copy import deepcopy
+import torch
 import onnxruntime as ort
 from huggingface_hub import snapshot_download
@@ -64,7 +64,7 @@ class Recognizer(object):
         # Shrink GPU memory after execution
         self.run_options = ort.RunOptions()
-        if ort.get_device() == "GPU":
+        if torch.cuda.is_available():
            options = ort.SessionOptions()
            options.enable_cpu_mem_arena = False
            cuda_provider_options = {
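
Putting this hunk's pieces together, a hedged sketch of how a session could be built around the new check (the model path and the arena-shrinkage config entry are illustrative assumptions following the onnxruntime issue linked in the first hunk, not the exact RAGFlow code):

```python
import onnxruntime as ort
import torch

model_path = "model.onnx"  # illustrative path, not RAGFlow's actual model file

# RunOptions used to shrink the memory arena after each run,
# per the onnxruntime issue referenced above.
run_options = ort.RunOptions()

if torch.cuda.is_available():
    options = ort.SessionOptions()
    options.enable_cpu_mem_arena = False  # keep the CPU arena from growing
    cuda_provider_options = {
        "device_id": 0,                      # use a specific GPU
        "gpu_mem_limit": 512 * 1024 * 1024,  # limit GPU memory
    }
    sess = ort.InferenceSession(
        model_path,
        sess_options=options,
        providers=[("CUDAExecutionProvider", cuda_provider_options)],
    )
    run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:0")
else:
    sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu:0")

# run_options would then be passed along on each call: sess.run(..., run_options=run_options)
```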


@@ -100,6 +100,7 @@ dependencies = [
     "tencentcloud-sdk-python==3.0.1215",
     "tika==2.6.0",
     "tiktoken==0.7.0",
+    "torch>=2.5.0,<3.0.0",
     "umap_learn==0.5.6",
     "vertexai==1.64.0",
     "volcengine==1.0.146",
@@ -131,6 +132,5 @@ full = [
     "fastembed>=0.3.6,<0.4.0; sys_platform == 'darwin' or platform_machine != 'x86_64'",
     "fastembed-gpu>=0.3.6,<0.4.0; sys_platform != 'darwin' and platform_machine == 'x86_64'",
     "flagembedding==1.2.10",
-    "torch>=2.5.0,<3.0.0",
     "transformers>=4.35.0,<5.0.0"
 ]

uv.lock (generated)

@@ -4814,6 +4814,7 @@ dependencies = [
     { name = "tencentcloud-sdk-python" },
     { name = "tika" },
     { name = "tiktoken" },
+    { name = "torch" },
     { name = "umap-learn" },
     { name = "valkey" },
     { name = "vertexai" },
@@ -4836,7 +4837,6 @@ full = [
     { name = "fastembed", marker = "platform_machine != 'x86_64' or sys_platform == 'darwin'" },
     { name = "fastembed-gpu", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" },
     { name = "flagembedding" },
-    { name = "torch" },
     { name = "transformers" },
 ]
@@ -4946,7 +4946,7 @@ requires-dist = [
     { name = "tencentcloud-sdk-python", specifier = "==3.0.1215" },
     { name = "tika", specifier = "==2.6.0" },
     { name = "tiktoken", specifier = "==0.7.0" },
-    { name = "torch", marker = "extra == 'full'", specifier = ">=2.5.0,<3.0.0" },
+    { name = "torch", specifier = ">=2.5.0,<3.0.0" },
     { name = "transformers", marker = "extra == 'full'", specifier = ">=4.35.0,<5.0.0" },
     { name = "umap-learn", specifier = "==0.5.6" },
     { name = "valkey", specifier = "==6.0.2" },