Only download u2net model in docker image

With the addition of the extra models, the docker image is huge (> 17 GB)
and too big to build and push in GitHub actions, especially for `arm64`
(which is slower to build than `amd64`). This change makes it so the `d`
command can take a list of models to download, and the Dockerfile only
downloads the default `u2net` model. This shrinks the docker image to
about 3 GB.

More discussion at: https://github.com/danielgatis/rembg/discussions/735
This commit is contained in:
Paul Gross 2025-03-13 09:31:23 -07:00
parent b144cbc7cf
commit 4d24dded16
5 changed files with 46 additions and 49 deletions

View File

@ -8,7 +8,7 @@ on:
jobs:
publish_docker:
name: Push Docker image to Docker Hub
runs-on: ubuntu-latest
runs-on: ubuntu-24.04
steps:
- name: Checkout
uses: actions/checkout@v4
@ -43,7 +43,7 @@ jobs:
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
- name: Build and push
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/amd64,linux/arm64

View File

@ -9,7 +9,7 @@ RUN apt-get update && apt-get install -y curl && apt-get clean && rm -rf /var/li
COPY . .
RUN python -m pip install ".[cpu,cli]"
RUN rembg d
RUN rembg d u2net
EXPOSE 7000
ENTRYPOINT ["rembg"]

View File

@ -1,4 +1,5 @@
import io
import sys
from enum import Enum
from typing import Any, List, Optional, Tuple, Union, cast
@ -20,7 +21,7 @@ from pymatting.util.util import stack_images
from scipy.ndimage import binary_erosion
from .session_factory import new_session
from .sessions import sessions_class
from .sessions import sessions, sessions_names
from .sessions.base import BaseSession
ort.set_default_logger_severity(3)
@ -194,11 +195,21 @@ def fix_image_orientation(img: PILImage) -> PILImage:
return cast(PILImage, ImageOps.exif_transpose(img))
def download_models() -> None:
def download_models(models: tuple[str, ...]) -> None:
"""
Download models for image processing.
"""
for session in sessions_class:
if len(models) == 0:
print("No models specified, downloading all models")
models = tuple(sessions_names)
for model in models:
session = sessions.get(model)
if session is None:
print(f"Error: no model found: {model}")
sys.exit(1)
else:
print(f"Downloading model: {model}")
session.download_models()
@ -214,7 +225,7 @@ def remove(
bgcolor: Optional[Tuple[int, int, int, int]] = None,
force_return_bytes: bool = False,
*args: Optional[Any],
**kwargs: Optional[Any]
**kwargs: Optional[Any],
) -> Union[bytes, PILImage, np.ndarray]:
"""
Remove the background from an input image.

View File

@ -5,10 +5,11 @@ from ..bg import download_models
@click.command( # type: ignore
name="d",
help="download all models",
help="download models",
)
def d_command(*args, **kwargs) -> None:
@click.argument("models", nargs=-1)
def d_command(models: tuple[str, ...]) -> None:
"""
Download all models
Download models
"""
download_models()
download_models(models)

View File

@ -1,93 +1,78 @@
from __future__ import annotations
from typing import List
from typing import Dict, List
from .base import BaseSession
sessions_class: List[type[BaseSession]] = []
sessions_names: List[str] = []
sessions: Dict[str, type[BaseSession]] = {}
from .birefnet_general import BiRefNetSessionGeneral
sessions_class.append(BiRefNetSessionGeneral)
sessions_names.append(BiRefNetSessionGeneral.name())
sessions[BiRefNetSessionGeneral.name()] = BiRefNetSessionGeneral
from .birefnet_general_lite import BiRefNetSessionGeneralLite
sessions_class.append(BiRefNetSessionGeneralLite)
sessions_names.append(BiRefNetSessionGeneralLite.name())
sessions[BiRefNetSessionGeneralLite.name()] = BiRefNetSessionGeneralLite
from .birefnet_portrait import BiRefNetSessionPortrait
sessions_class.append(BiRefNetSessionPortrait)
sessions_names.append(BiRefNetSessionPortrait.name())
sessions[BiRefNetSessionPortrait.name()] = BiRefNetSessionPortrait
from .birefnet_dis import BiRefNetSessionDIS
sessions_class.append(BiRefNetSessionDIS)
sessions_names.append(BiRefNetSessionDIS.name())
sessions[BiRefNetSessionDIS.name()] = BiRefNetSessionDIS
from .birefnet_hrsod import BiRefNetSessionHRSOD
sessions_class.append(BiRefNetSessionHRSOD)
sessions_names.append(BiRefNetSessionHRSOD.name())
sessions[BiRefNetSessionHRSOD.name()] = BiRefNetSessionHRSOD
from .birefnet_cod import BiRefNetSessionCOD
sessions_class.append(BiRefNetSessionCOD)
sessions_names.append(BiRefNetSessionCOD.name())
sessions[BiRefNetSessionCOD.name()] = BiRefNetSessionCOD
from .birefnet_massive import BiRefNetSessionMassive
sessions_class.append(BiRefNetSessionMassive)
sessions_names.append(BiRefNetSessionMassive.name())
sessions[BiRefNetSessionMassive.name()] = BiRefNetSessionMassive
from .dis_anime import DisSession
sessions_class.append(DisSession)
sessions_names.append(DisSession.name())
sessions[DisSession.name()] = DisSession
from .dis_general_use import DisSession as DisSessionGeneralUse
sessions_class.append(DisSessionGeneralUse)
sessions_names.append(DisSessionGeneralUse.name())
sessions[DisSessionGeneralUse.name()] = DisSessionGeneralUse
from .sam import SamSession
sessions_class.append(SamSession)
sessions_names.append(SamSession.name())
sessions[SamSession.name()] = SamSession
from .silueta import SiluetaSession
sessions_class.append(SiluetaSession)
sessions_names.append(SiluetaSession.name())
sessions[SiluetaSession.name()] = SiluetaSession
from .u2net_cloth_seg import Unet2ClothSession
sessions_class.append(Unet2ClothSession)
sessions_names.append(Unet2ClothSession.name())
sessions[Unet2ClothSession.name()] = Unet2ClothSession
from .u2net_custom import U2netCustomSession
sessions_class.append(U2netCustomSession)
sessions_names.append(U2netCustomSession.name())
sessions[U2netCustomSession.name()] = U2netCustomSession
from .u2net_human_seg import U2netHumanSegSession
sessions_class.append(U2netHumanSegSession)
sessions_names.append(U2netHumanSegSession.name())
sessions[U2netHumanSegSession.name()] = U2netHumanSegSession
from .u2net import U2netSession
sessions_class.append(U2netSession)
sessions_names.append(U2netSession.name())
sessions[U2netSession.name()] = U2netSession
from .u2netp import U2netpSession
sessions_class.append(U2netpSession)
sessions_names.append(U2netpSession.name())
sessions[U2netpSession.name()] = U2netpSession
from .bria_rmbg import BriaRmBgSession
sessions_class.append(BriaRmBgSession)
sessions_names.append(BriaRmBgSession.name())
sessions[BriaRmBgSession.name()] = BriaRmBgSession
sessions_names = list(sessions.keys())
sessions_class = list(sessions.values())