From ac0fa9abb3fc49e5425d50b772ab715446a0907f Mon Sep 17 00:00:00 2001 From: Yahweasel Date: Tue, 13 May 2025 08:15:52 -0400 Subject: [PATCH] Fix (and document) support for ROCM backend --- README.md | 15 ++++++++++++++- rembg/sessions/base.py | 5 +++++ setup.py | 1 + 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 76bc374..59e1e09 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ pip install rembg[cpu] # for library pip install "rembg[cpu,cli]" # for library + cli ``` -### GPU support: +### GPU support (NVidia/Cuda): First of all, you need to check if your system supports the `onnxruntime-gpu`. @@ -126,6 +126,19 @@ pip install "rembg[gpu,cli]" # for library + cli Nvidia GPU may require onnxruntime-gpu, cuda, and cudnn-devel. [#668](https://github.com/danielgatis/rembg/issues/668#issuecomment-2689830314) . If rembg[gpu] doesn't work and you can't install cuda or cudnn-devel, use rembg[cpu] and onnxruntime instead. +### GPU support (AMD/ROCM): + +ROCM support requires the `onnxruntime-rocm` package. Install it following +[AMD's documentation](https://rocm.docs.amd.com/projects/radeon/en/latest/docs/install/native_linux/install-onnx.html). + +If `onnxruntime-rocm` is installed and working, install the `rembg[rocm]` +version of rembg: + +```bash +pip install "rembg[rocm]" # for library +pip install "rembg[rocm,cli]" # for library + cli +``` + ## Usage as a cli After the installation step you can use rembg just typing `rembg` in your terminal window. diff --git a/rembg/sessions/base.py b/rembg/sessions/base.py index c932ff7..2e2bcd5 100644 --- a/rembg/sessions/base.py +++ b/rembg/sessions/base.py @@ -20,6 +20,11 @@ class BaseSession: and "CUDAExecutionProvider" in ort.get_available_providers() ): providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + elif ( + device_type[0:3] == "GPU" + and "ROCMExecutionProvider" in ort.get_available_providers() + ): + providers = ["ROCMExecutionProvider", "CPUExecutionProvider"] else: providers = ["CPUExecutionProvider"] diff --git a/setup.py b/setup.py index 0b0b9af..bb34afb 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ extras_require = { ], "cpu": ["onnxruntime"], "gpu": ["onnxruntime-gpu"], + "rocm": ["onnxruntime-rocm"], "cli": [ "aiohttp", "asyncer",