From 4d24dded1636486f028d3fff6b61b37a79f878eb Mon Sep 17 00:00:00 2001 From: Paul Gross Date: Thu, 13 Mar 2025 09:31:23 -0700 Subject: [PATCH] Only download u2net model in docker image With the addition of the extra models, the docker image is huge (> 17 GB) and too big to build and push in GitHub actions, especially for `arm64` (which is slower to build than `amd64`). This change makes it so the `d` command can take a list of models to download, and the Dockerfile only downloads the default `u2net` model. This shrinks the docker image to about 3 GB. More discussion at: https://github.com/danielgatis/rembg/discussions/735 --- .github/workflows/publish_docker.yml | 4 +- Dockerfile | 2 +- rembg/bg.py | 21 +++++++--- rembg/commands/d_command.py | 9 +++-- rembg/sessions/__init__.py | 59 +++++++++++----------------- 5 files changed, 46 insertions(+), 49 deletions(-) diff --git a/.github/workflows/publish_docker.yml b/.github/workflows/publish_docker.yml index bf95da3..8c7a80e 100644 --- a/.github/workflows/publish_docker.yml +++ b/.github/workflows/publish_docker.yml @@ -8,7 +8,7 @@ on: jobs: publish_docker: name: Push Docker image to Docker Hub - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - name: Checkout uses: actions/checkout@v4 @@ -43,7 +43,7 @@ jobs: password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }} - name: Build and push - uses: docker/build-push-action@v5 + uses: docker/build-push-action@v6 with: context: . platforms: linux/amd64,linux/arm64 diff --git a/Dockerfile b/Dockerfile index 2c61274..ba09752 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,7 +9,7 @@ RUN apt-get update && apt-get install -y curl && apt-get clean && rm -rf /var/li COPY . . RUN python -m pip install ".[cpu,cli]" -RUN rembg d +RUN rembg d u2net EXPOSE 7000 ENTRYPOINT ["rembg"] diff --git a/rembg/bg.py b/rembg/bg.py index 2dc2af3..554096a 100644 --- a/rembg/bg.py +++ b/rembg/bg.py @@ -1,4 +1,5 @@ import io +import sys from enum import Enum from typing import Any, List, Optional, Tuple, Union, cast @@ -20,7 +21,7 @@ from pymatting.util.util import stack_images from scipy.ndimage import binary_erosion from .session_factory import new_session -from .sessions import sessions_class +from .sessions import sessions, sessions_names from .sessions.base import BaseSession ort.set_default_logger_severity(3) @@ -194,12 +195,22 @@ def fix_image_orientation(img: PILImage) -> PILImage: return cast(PILImage, ImageOps.exif_transpose(img)) -def download_models() -> None: +def download_models(models: tuple[str, ...]) -> None: """ Download models for image processing. """ - for session in sessions_class: - session.download_models() + if len(models) == 0: + print("No models specified, downloading all models") + models = tuple(sessions_names) + + for model in models: + session = sessions.get(model) + if session is None: + print(f"Error: no model found: {model}") + sys.exit(1) + else: + print(f"Downloading model: {model}") + session.download_models() def remove( @@ -214,7 +225,7 @@ def remove( bgcolor: Optional[Tuple[int, int, int, int]] = None, force_return_bytes: bool = False, *args: Optional[Any], - **kwargs: Optional[Any] + **kwargs: Optional[Any], ) -> Union[bytes, PILImage, np.ndarray]: """ Remove the background from an input image. diff --git a/rembg/commands/d_command.py b/rembg/commands/d_command.py index 45051cd..fe2f095 100644 --- a/rembg/commands/d_command.py +++ b/rembg/commands/d_command.py @@ -5,10 +5,11 @@ from ..bg import download_models @click.command( # type: ignore name="d", - help="download all models", + help="download models", ) -def d_command(*args, **kwargs) -> None: +@click.argument("models", nargs=-1) +def d_command(models: tuple[str, ...]) -> None: """ - Download all models + Download models """ - download_models() + download_models(models) diff --git a/rembg/sessions/__init__.py b/rembg/sessions/__init__.py index 7546f9b..dbf4fb1 100644 --- a/rembg/sessions/__init__.py +++ b/rembg/sessions/__init__.py @@ -1,93 +1,78 @@ from __future__ import annotations -from typing import List +from typing import Dict, List from .base import BaseSession -sessions_class: List[type[BaseSession]] = [] -sessions_names: List[str] = [] +sessions: Dict[str, type[BaseSession]] = {} from .birefnet_general import BiRefNetSessionGeneral -sessions_class.append(BiRefNetSessionGeneral) -sessions_names.append(BiRefNetSessionGeneral.name()) +sessions[BiRefNetSessionGeneral.name()] = BiRefNetSessionGeneral from .birefnet_general_lite import BiRefNetSessionGeneralLite -sessions_class.append(BiRefNetSessionGeneralLite) -sessions_names.append(BiRefNetSessionGeneralLite.name()) +sessions[BiRefNetSessionGeneralLite.name()] = BiRefNetSessionGeneralLite from .birefnet_portrait import BiRefNetSessionPortrait -sessions_class.append(BiRefNetSessionPortrait) -sessions_names.append(BiRefNetSessionPortrait.name()) +sessions[BiRefNetSessionPortrait.name()] = BiRefNetSessionPortrait from .birefnet_dis import BiRefNetSessionDIS -sessions_class.append(BiRefNetSessionDIS) -sessions_names.append(BiRefNetSessionDIS.name()) +sessions[BiRefNetSessionDIS.name()] = BiRefNetSessionDIS from .birefnet_hrsod import BiRefNetSessionHRSOD -sessions_class.append(BiRefNetSessionHRSOD) -sessions_names.append(BiRefNetSessionHRSOD.name()) +sessions[BiRefNetSessionHRSOD.name()] = BiRefNetSessionHRSOD from .birefnet_cod import BiRefNetSessionCOD -sessions_class.append(BiRefNetSessionCOD) -sessions_names.append(BiRefNetSessionCOD.name()) +sessions[BiRefNetSessionCOD.name()] = BiRefNetSessionCOD from .birefnet_massive import BiRefNetSessionMassive -sessions_class.append(BiRefNetSessionMassive) -sessions_names.append(BiRefNetSessionMassive.name()) +sessions[BiRefNetSessionMassive.name()] = BiRefNetSessionMassive from .dis_anime import DisSession -sessions_class.append(DisSession) -sessions_names.append(DisSession.name()) +sessions[DisSession.name()] = DisSession from .dis_general_use import DisSession as DisSessionGeneralUse -sessions_class.append(DisSessionGeneralUse) -sessions_names.append(DisSessionGeneralUse.name()) +sessions[DisSessionGeneralUse.name()] = DisSessionGeneralUse from .sam import SamSession -sessions_class.append(SamSession) -sessions_names.append(SamSession.name()) +sessions[SamSession.name()] = SamSession from .silueta import SiluetaSession -sessions_class.append(SiluetaSession) -sessions_names.append(SiluetaSession.name()) +sessions[SiluetaSession.name()] = SiluetaSession from .u2net_cloth_seg import Unet2ClothSession -sessions_class.append(Unet2ClothSession) -sessions_names.append(Unet2ClothSession.name()) +sessions[Unet2ClothSession.name()] = Unet2ClothSession from .u2net_custom import U2netCustomSession -sessions_class.append(U2netCustomSession) -sessions_names.append(U2netCustomSession.name()) +sessions[U2netCustomSession.name()] = U2netCustomSession from .u2net_human_seg import U2netHumanSegSession -sessions_class.append(U2netHumanSegSession) -sessions_names.append(U2netHumanSegSession.name()) +sessions[U2netHumanSegSession.name()] = U2netHumanSegSession from .u2net import U2netSession -sessions_class.append(U2netSession) -sessions_names.append(U2netSession.name()) +sessions[U2netSession.name()] = U2netSession from .u2netp import U2netpSession -sessions_class.append(U2netpSession) -sessions_names.append(U2netpSession.name()) +sessions[U2netpSession.name()] = U2netpSession from .bria_rmbg import BriaRmBgSession -sessions_class.append(BriaRmBgSession) -sessions_names.append(BriaRmBgSession.name()) +sessions[BriaRmBgSession.name()] = BriaRmBgSession + +sessions_names = list(sessions.keys()) +sessions_class = list(sessions.values())