Only download u2net model in docker image

With the addition of the extra models, the Docker image is huge (> 17 GB)
and too big to build and push in GitHub Actions, especially for `arm64`
(which is slower to build than `amd64`). This change lets the `d`
command take a list of models to download, and the Dockerfile now
downloads only the default `u2net` model. This shrinks the Docker image to
about 3 GB.

More discussion at: https://github.com/danielgatis/rembg/discussions/735
This commit is contained in:
Paul Gross 2025-03-13 09:31:23 -07:00
parent b144cbc7cf
commit 4d24dded16
5 changed files with 46 additions and 49 deletions

View File

@ -8,7 +8,7 @@ on:
jobs: jobs:
publish_docker: publish_docker:
name: Push Docker image to Docker Hub name: Push Docker image to Docker Hub
runs-on: ubuntu-latest runs-on: ubuntu-24.04
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
@ -43,7 +43,7 @@ jobs:
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }} password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
- name: Build and push - name: Build and push
uses: docker/build-push-action@v5 uses: docker/build-push-action@v6
with: with:
context: . context: .
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64

View File

@ -9,7 +9,7 @@ RUN apt-get update && apt-get install -y curl && apt-get clean && rm -rf /var/li
COPY . . COPY . .
RUN python -m pip install ".[cpu,cli]" RUN python -m pip install ".[cpu,cli]"
RUN rembg d RUN rembg d u2net
EXPOSE 7000 EXPOSE 7000
ENTRYPOINT ["rembg"] ENTRYPOINT ["rembg"]

View File

@ -1,4 +1,5 @@
import io import io
import sys
from enum import Enum from enum import Enum
from typing import Any, List, Optional, Tuple, Union, cast from typing import Any, List, Optional, Tuple, Union, cast
@ -20,7 +21,7 @@ from pymatting.util.util import stack_images
from scipy.ndimage import binary_erosion from scipy.ndimage import binary_erosion
from .session_factory import new_session from .session_factory import new_session
from .sessions import sessions_class from .sessions import sessions, sessions_names
from .sessions.base import BaseSession from .sessions.base import BaseSession
ort.set_default_logger_severity(3) ort.set_default_logger_severity(3)
@ -194,12 +195,22 @@ def fix_image_orientation(img: PILImage) -> PILImage:
return cast(PILImage, ImageOps.exif_transpose(img)) return cast(PILImage, ImageOps.exif_transpose(img))
def download_models() -> None: def download_models(models: tuple[str, ...]) -> None:
""" """
Download models for image processing. Download models for image processing.
""" """
for session in sessions_class: if len(models) == 0:
session.download_models() print("No models specified, downloading all models")
models = tuple(sessions_names)
for model in models:
session = sessions.get(model)
if session is None:
print(f"Error: no model found: {model}")
sys.exit(1)
else:
print(f"Downloading model: {model}")
session.download_models()
def remove( def remove(
@ -214,7 +225,7 @@ def remove(
bgcolor: Optional[Tuple[int, int, int, int]] = None, bgcolor: Optional[Tuple[int, int, int, int]] = None,
force_return_bytes: bool = False, force_return_bytes: bool = False,
*args: Optional[Any], *args: Optional[Any],
**kwargs: Optional[Any] **kwargs: Optional[Any],
) -> Union[bytes, PILImage, np.ndarray]: ) -> Union[bytes, PILImage, np.ndarray]:
""" """
Remove the background from an input image. Remove the background from an input image.

View File

@ -5,10 +5,11 @@ from ..bg import download_models
@click.command( # type: ignore @click.command( # type: ignore
name="d", name="d",
help="download all models", help="download models",
) )
def d_command(*args, **kwargs) -> None: @click.argument("models", nargs=-1)
def d_command(models: tuple[str, ...]) -> None:
""" """
Download all models Download models
""" """
download_models() download_models(models)

View File

@ -1,93 +1,78 @@
from __future__ import annotations from __future__ import annotations
from typing import List from typing import Dict, List
from .base import BaseSession from .base import BaseSession
sessions_class: List[type[BaseSession]] = [] sessions: Dict[str, type[BaseSession]] = {}
sessions_names: List[str] = []
from .birefnet_general import BiRefNetSessionGeneral from .birefnet_general import BiRefNetSessionGeneral
sessions_class.append(BiRefNetSessionGeneral) sessions[BiRefNetSessionGeneral.name()] = BiRefNetSessionGeneral
sessions_names.append(BiRefNetSessionGeneral.name())
from .birefnet_general_lite import BiRefNetSessionGeneralLite from .birefnet_general_lite import BiRefNetSessionGeneralLite
sessions_class.append(BiRefNetSessionGeneralLite) sessions[BiRefNetSessionGeneralLite.name()] = BiRefNetSessionGeneralLite
sessions_names.append(BiRefNetSessionGeneralLite.name())
from .birefnet_portrait import BiRefNetSessionPortrait from .birefnet_portrait import BiRefNetSessionPortrait
sessions_class.append(BiRefNetSessionPortrait) sessions[BiRefNetSessionPortrait.name()] = BiRefNetSessionPortrait
sessions_names.append(BiRefNetSessionPortrait.name())
from .birefnet_dis import BiRefNetSessionDIS from .birefnet_dis import BiRefNetSessionDIS
sessions_class.append(BiRefNetSessionDIS) sessions[BiRefNetSessionDIS.name()] = BiRefNetSessionDIS
sessions_names.append(BiRefNetSessionDIS.name())
from .birefnet_hrsod import BiRefNetSessionHRSOD from .birefnet_hrsod import BiRefNetSessionHRSOD
sessions_class.append(BiRefNetSessionHRSOD) sessions[BiRefNetSessionHRSOD.name()] = BiRefNetSessionHRSOD
sessions_names.append(BiRefNetSessionHRSOD.name())
from .birefnet_cod import BiRefNetSessionCOD from .birefnet_cod import BiRefNetSessionCOD
sessions_class.append(BiRefNetSessionCOD) sessions[BiRefNetSessionCOD.name()] = BiRefNetSessionCOD
sessions_names.append(BiRefNetSessionCOD.name())
from .birefnet_massive import BiRefNetSessionMassive from .birefnet_massive import BiRefNetSessionMassive
sessions_class.append(BiRefNetSessionMassive) sessions[BiRefNetSessionMassive.name()] = BiRefNetSessionMassive
sessions_names.append(BiRefNetSessionMassive.name())
from .dis_anime import DisSession from .dis_anime import DisSession
sessions_class.append(DisSession) sessions[DisSession.name()] = DisSession
sessions_names.append(DisSession.name())
from .dis_general_use import DisSession as DisSessionGeneralUse from .dis_general_use import DisSession as DisSessionGeneralUse
sessions_class.append(DisSessionGeneralUse) sessions[DisSessionGeneralUse.name()] = DisSessionGeneralUse
sessions_names.append(DisSessionGeneralUse.name())
from .sam import SamSession from .sam import SamSession
sessions_class.append(SamSession) sessions[SamSession.name()] = SamSession
sessions_names.append(SamSession.name())
from .silueta import SiluetaSession from .silueta import SiluetaSession
sessions_class.append(SiluetaSession) sessions[SiluetaSession.name()] = SiluetaSession
sessions_names.append(SiluetaSession.name())
from .u2net_cloth_seg import Unet2ClothSession from .u2net_cloth_seg import Unet2ClothSession
sessions_class.append(Unet2ClothSession) sessions[Unet2ClothSession.name()] = Unet2ClothSession
sessions_names.append(Unet2ClothSession.name())
from .u2net_custom import U2netCustomSession from .u2net_custom import U2netCustomSession
sessions_class.append(U2netCustomSession) sessions[U2netCustomSession.name()] = U2netCustomSession
sessions_names.append(U2netCustomSession.name())
from .u2net_human_seg import U2netHumanSegSession from .u2net_human_seg import U2netHumanSegSession
sessions_class.append(U2netHumanSegSession) sessions[U2netHumanSegSession.name()] = U2netHumanSegSession
sessions_names.append(U2netHumanSegSession.name())
from .u2net import U2netSession from .u2net import U2netSession
sessions_class.append(U2netSession) sessions[U2netSession.name()] = U2netSession
sessions_names.append(U2netSession.name())
from .u2netp import U2netpSession from .u2netp import U2netpSession
sessions_class.append(U2netpSession) sessions[U2netpSession.name()] = U2netpSession
sessions_names.append(U2netpSession.name())
from .bria_rmbg import BriaRmBgSession from .bria_rmbg import BriaRmBgSession
sessions_class.append(BriaRmBgSession) sessions[BriaRmBgSession.name()] = BriaRmBgSession
sessions_names.append(BriaRmBgSession.name())
sessions_names = list(sessions.keys())
sessions_class = list(sessions.values())