mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-16 13:05:53 +08:00
Only download u2net model in docker image
With the addition of the extra models, the docker image is huge (> 17 GB) and too big to build and push in GitHub actions, especially for `arm64` (which is slower to build than `amd64`). This change makes it so the `d` command can take a list of models to download, and the Dockerfile only downloads the default `u2net` model. This shrinks the docker image to about 3 GB. More discussion at: https://github.com/danielgatis/rembg/discussions/735
This commit is contained in:
parent
b144cbc7cf
commit
4d24dded16
4
.github/workflows/publish_docker.yml
vendored
4
.github/workflows/publish_docker.yml
vendored
@ -8,7 +8,7 @@ on:
|
||||
jobs:
|
||||
publish_docker:
|
||||
name: Push Docker image to Docker Hub
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@ -43,7 +43,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
|
@ -9,7 +9,7 @@ RUN apt-get update && apt-get install -y curl && apt-get clean && rm -rf /var/li
|
||||
COPY . .
|
||||
|
||||
RUN python -m pip install ".[cpu,cli]"
|
||||
RUN rembg d
|
||||
RUN rembg d u2net
|
||||
|
||||
EXPOSE 7000
|
||||
ENTRYPOINT ["rembg"]
|
||||
|
19
rembg/bg.py
19
rembg/bg.py
@ -1,4 +1,5 @@
|
||||
import io
|
||||
import sys
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Tuple, Union, cast
|
||||
|
||||
@ -20,7 +21,7 @@ from pymatting.util.util import stack_images
|
||||
from scipy.ndimage import binary_erosion
|
||||
|
||||
from .session_factory import new_session
|
||||
from .sessions import sessions_class
|
||||
from .sessions import sessions, sessions_names
|
||||
from .sessions.base import BaseSession
|
||||
|
||||
ort.set_default_logger_severity(3)
|
||||
@ -194,11 +195,21 @@ def fix_image_orientation(img: PILImage) -> PILImage:
|
||||
return cast(PILImage, ImageOps.exif_transpose(img))
|
||||
|
||||
|
||||
def download_models() -> None:
|
||||
def download_models(models: tuple[str, ...]) -> None:
|
||||
"""
|
||||
Download models for image processing.
|
||||
"""
|
||||
for session in sessions_class:
|
||||
if len(models) == 0:
|
||||
print("No models specified, downloading all models")
|
||||
models = tuple(sessions_names)
|
||||
|
||||
for model in models:
|
||||
session = sessions.get(model)
|
||||
if session is None:
|
||||
print(f"Error: no model found: {model}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f"Downloading model: {model}")
|
||||
session.download_models()
|
||||
|
||||
|
||||
@ -214,7 +225,7 @@ def remove(
|
||||
bgcolor: Optional[Tuple[int, int, int, int]] = None,
|
||||
force_return_bytes: bool = False,
|
||||
*args: Optional[Any],
|
||||
**kwargs: Optional[Any]
|
||||
**kwargs: Optional[Any],
|
||||
) -> Union[bytes, PILImage, np.ndarray]:
|
||||
"""
|
||||
Remove the background from an input image.
|
||||
|
@ -5,10 +5,11 @@ from ..bg import download_models
|
||||
|
||||
@click.command( # type: ignore
|
||||
name="d",
|
||||
help="download all models",
|
||||
help="download models",
|
||||
)
|
||||
def d_command(*args, **kwargs) -> None:
|
||||
@click.argument("models", nargs=-1)
|
||||
def d_command(models: tuple[str, ...]) -> None:
|
||||
"""
|
||||
Download all models
|
||||
Download models
|
||||
"""
|
||||
download_models()
|
||||
download_models(models)
|
||||
|
@ -1,93 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
|
||||
from .base import BaseSession
|
||||
|
||||
sessions_class: List[type[BaseSession]] = []
|
||||
sessions_names: List[str] = []
|
||||
sessions: Dict[str, type[BaseSession]] = {}
|
||||
|
||||
from .birefnet_general import BiRefNetSessionGeneral
|
||||
|
||||
sessions_class.append(BiRefNetSessionGeneral)
|
||||
sessions_names.append(BiRefNetSessionGeneral.name())
|
||||
sessions[BiRefNetSessionGeneral.name()] = BiRefNetSessionGeneral
|
||||
|
||||
from .birefnet_general_lite import BiRefNetSessionGeneralLite
|
||||
|
||||
sessions_class.append(BiRefNetSessionGeneralLite)
|
||||
sessions_names.append(BiRefNetSessionGeneralLite.name())
|
||||
sessions[BiRefNetSessionGeneralLite.name()] = BiRefNetSessionGeneralLite
|
||||
|
||||
from .birefnet_portrait import BiRefNetSessionPortrait
|
||||
|
||||
sessions_class.append(BiRefNetSessionPortrait)
|
||||
sessions_names.append(BiRefNetSessionPortrait.name())
|
||||
sessions[BiRefNetSessionPortrait.name()] = BiRefNetSessionPortrait
|
||||
|
||||
from .birefnet_dis import BiRefNetSessionDIS
|
||||
|
||||
sessions_class.append(BiRefNetSessionDIS)
|
||||
sessions_names.append(BiRefNetSessionDIS.name())
|
||||
sessions[BiRefNetSessionDIS.name()] = BiRefNetSessionDIS
|
||||
|
||||
from .birefnet_hrsod import BiRefNetSessionHRSOD
|
||||
|
||||
sessions_class.append(BiRefNetSessionHRSOD)
|
||||
sessions_names.append(BiRefNetSessionHRSOD.name())
|
||||
sessions[BiRefNetSessionHRSOD.name()] = BiRefNetSessionHRSOD
|
||||
|
||||
from .birefnet_cod import BiRefNetSessionCOD
|
||||
|
||||
sessions_class.append(BiRefNetSessionCOD)
|
||||
sessions_names.append(BiRefNetSessionCOD.name())
|
||||
sessions[BiRefNetSessionCOD.name()] = BiRefNetSessionCOD
|
||||
|
||||
from .birefnet_massive import BiRefNetSessionMassive
|
||||
|
||||
sessions_class.append(BiRefNetSessionMassive)
|
||||
sessions_names.append(BiRefNetSessionMassive.name())
|
||||
sessions[BiRefNetSessionMassive.name()] = BiRefNetSessionMassive
|
||||
|
||||
from .dis_anime import DisSession
|
||||
|
||||
sessions_class.append(DisSession)
|
||||
sessions_names.append(DisSession.name())
|
||||
sessions[DisSession.name()] = DisSession
|
||||
|
||||
from .dis_general_use import DisSession as DisSessionGeneralUse
|
||||
|
||||
sessions_class.append(DisSessionGeneralUse)
|
||||
sessions_names.append(DisSessionGeneralUse.name())
|
||||
sessions[DisSessionGeneralUse.name()] = DisSessionGeneralUse
|
||||
|
||||
from .sam import SamSession
|
||||
|
||||
sessions_class.append(SamSession)
|
||||
sessions_names.append(SamSession.name())
|
||||
sessions[SamSession.name()] = SamSession
|
||||
|
||||
from .silueta import SiluetaSession
|
||||
|
||||
sessions_class.append(SiluetaSession)
|
||||
sessions_names.append(SiluetaSession.name())
|
||||
sessions[SiluetaSession.name()] = SiluetaSession
|
||||
|
||||
from .u2net_cloth_seg import Unet2ClothSession
|
||||
|
||||
sessions_class.append(Unet2ClothSession)
|
||||
sessions_names.append(Unet2ClothSession.name())
|
||||
sessions[Unet2ClothSession.name()] = Unet2ClothSession
|
||||
|
||||
from .u2net_custom import U2netCustomSession
|
||||
|
||||
sessions_class.append(U2netCustomSession)
|
||||
sessions_names.append(U2netCustomSession.name())
|
||||
sessions[U2netCustomSession.name()] = U2netCustomSession
|
||||
|
||||
from .u2net_human_seg import U2netHumanSegSession
|
||||
|
||||
sessions_class.append(U2netHumanSegSession)
|
||||
sessions_names.append(U2netHumanSegSession.name())
|
||||
sessions[U2netHumanSegSession.name()] = U2netHumanSegSession
|
||||
|
||||
from .u2net import U2netSession
|
||||
|
||||
sessions_class.append(U2netSession)
|
||||
sessions_names.append(U2netSession.name())
|
||||
sessions[U2netSession.name()] = U2netSession
|
||||
|
||||
from .u2netp import U2netpSession
|
||||
|
||||
sessions_class.append(U2netpSession)
|
||||
sessions_names.append(U2netpSession.name())
|
||||
sessions[U2netpSession.name()] = U2netpSession
|
||||
|
||||
from .bria_rmbg import BriaRmBgSession
|
||||
|
||||
sessions_class.append(BriaRmBgSession)
|
||||
sessions_names.append(BriaRmBgSession.name())
|
||||
sessions[BriaRmBgSession.name()] = BriaRmBgSession
|
||||
|
||||
sessions_names = list(sessions.keys())
|
||||
sessions_class = list(sessions.values())
|
||||
|
Loading…
x
Reference in New Issue
Block a user