diff --git a/deepdoc/vision/ocr.py b/deepdoc/vision/ocr.py index 4150adad8..92e343586 100644 --- a/deepdoc/vision/ocr.py +++ b/deepdoc/vision/ocr.py @@ -27,6 +27,7 @@ from . import operators import math import numpy as np import cv2 +import torch import onnxruntime as ort from .postprocess import build_post_process @@ -80,7 +81,7 @@ def load_model(model_dir, nm): # https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580 # Shrink GPU memory after execution run_options = ort.RunOptions() - if ort.get_device() == "GPU": + if torch.cuda.is_available(): cuda_provider_options = { "device_id": 0, # Use specific GPU "gpu_mem_limit": 512 * 1024 * 1024, # Limit gpu memory diff --git a/deepdoc/vision/recognizer.py b/deepdoc/vision/recognizer.py index 54ba68d65..e146e742d 100644 --- a/deepdoc/vision/recognizer.py +++ b/deepdoc/vision/recognizer.py @@ -21,7 +21,7 @@ import numpy as np import cv2 from copy import deepcopy - +import torch import onnxruntime as ort from huggingface_hub import snapshot_download @@ -64,7 +64,7 @@ class Recognizer(object): # Shrink GPU memory after execution self.run_options = ort.RunOptions() - if ort.get_device() == "GPU": + if torch.cuda.is_available(): options = ort.SessionOptions() options.enable_cpu_mem_arena = False cuda_provider_options = { diff --git a/pyproject.toml b/pyproject.toml index 11bf19043..4e5c4a2af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ dependencies = [ "tencentcloud-sdk-python==3.0.1215", "tika==2.6.0", "tiktoken==0.7.0", + "torch>=2.5.0,<3.0.0", "umap_learn==0.5.6", "vertexai==1.64.0", "volcengine==1.0.146", @@ -131,6 +132,5 @@ full = [ "fastembed>=0.3.6,<0.4.0; sys_platform == 'darwin' or platform_machine != 'x86_64'", "fastembed-gpu>=0.3.6,<0.4.0; sys_platform != 'darwin' and platform_machine == 'x86_64'", "flagembedding==1.2.10", - "torch>=2.5.0,<3.0.0", "transformers>=4.35.0,<5.0.0" ] \ No newline at end of file diff --git a/uv.lock b/uv.lock index 0da921554..0a09f2b34 100644 --- a/uv.lock +++ b/uv.lock @@ -4814,6 +4814,7 @@ dependencies = [ { name = "tencentcloud-sdk-python" }, { name = "tika" }, { name = "tiktoken" }, + { name = "torch" }, { name = "umap-learn" }, { name = "valkey" }, { name = "vertexai" }, @@ -4836,7 +4837,6 @@ full = [ { name = "fastembed", marker = "platform_machine != 'x86_64' or sys_platform == 'darwin'" }, { name = "fastembed-gpu", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, { name = "flagembedding" }, - { name = "torch" }, { name = "transformers" }, ] @@ -4946,7 +4946,7 @@ requires-dist = [ { name = "tencentcloud-sdk-python", specifier = "==3.0.1215" }, { name = "tika", specifier = "==2.6.0" }, { name = "tiktoken", specifier = "==0.7.0" }, - { name = "torch", marker = "extra == 'full'", specifier = ">=2.5.0,<3.0.0" }, + { name = "torch", specifier = ">=2.5.0,<3.0.0" }, { name = "transformers", marker = "extra == 'full'", specifier = ">=4.35.0,<5.0.0" }, { name = "umap-learn", specifier = "==0.5.6" }, { name = "valkey", specifier = "==6.0.2" },