From aa6fb76e0f25461d5ec3c9ff6f5f3a0751aeee11 Mon Sep 17 00:00:00 2001 From: catscarlet Date: Fri, 28 Feb 2025 17:06:42 +0800 Subject: [PATCH] ensure CUDAExecutionProvider is called in session when using nvidia gpu. --- rembg/sessions/base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/rembg/sessions/base.py b/rembg/sessions/base.py index bfcb115..0d97d44 100644 --- a/rembg/sessions/base.py +++ b/rembg/sessions/base.py @@ -13,9 +13,17 @@ class BaseSession: def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs): """Initialize an instance of the BaseSession class.""" self.model_name = model_name + + device_type = ort.get_device() + if device_type == 'GPU' and 'CUDAExecutionProvider' in ort.get_available_providers(): + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + self.inner_session = ort.InferenceSession( str(self.__class__.download_models(*args, **kwargs)), sess_options=sess_opts, + providers=providers, ) def normalize(