mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-18 05:55:53 +08:00
Only download u2net model in docker image
With the addition of the extra models, the docker image is huge (> 17 GB) and too big to build and push in GitHub actions, especially for `arm64` (which is slower to build than `amd64`). This change makes it so the `d` command can take a list of models to download, and the Dockerfile only downloads the default `u2net` model. This shrinks the docker image to about 3 GB. More discussion at: https://github.com/danielgatis/rembg/discussions/735
This commit is contained in:
parent
b144cbc7cf
commit
4d24dded16
4
.github/workflows/publish_docker.yml
vendored
4
.github/workflows/publish_docker.yml
vendored
@ -8,7 +8,7 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
publish_docker:
|
publish_docker:
|
||||||
name: Push Docker image to Docker Hub
|
name: Push Docker image to Docker Hub
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-24.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@ -43,7 +43,7 @@ jobs:
|
|||||||
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }}
|
||||||
|
|
||||||
- name: Build and push
|
- name: Build and push
|
||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v6
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
platforms: linux/amd64,linux/arm64
|
platforms: linux/amd64,linux/arm64
|
||||||
|
@ -9,7 +9,7 @@ RUN apt-get update && apt-get install -y curl && apt-get clean && rm -rf /var/li
|
|||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
RUN python -m pip install ".[cpu,cli]"
|
RUN python -m pip install ".[cpu,cli]"
|
||||||
RUN rembg d
|
RUN rembg d u2net
|
||||||
|
|
||||||
EXPOSE 7000
|
EXPOSE 7000
|
||||||
ENTRYPOINT ["rembg"]
|
ENTRYPOINT ["rembg"]
|
||||||
|
19
rembg/bg.py
19
rembg/bg.py
@ -1,4 +1,5 @@
|
|||||||
import io
|
import io
|
||||||
|
import sys
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, List, Optional, Tuple, Union, cast
|
from typing import Any, List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
@ -20,7 +21,7 @@ from pymatting.util.util import stack_images
|
|||||||
from scipy.ndimage import binary_erosion
|
from scipy.ndimage import binary_erosion
|
||||||
|
|
||||||
from .session_factory import new_session
|
from .session_factory import new_session
|
||||||
from .sessions import sessions_class
|
from .sessions import sessions, sessions_names
|
||||||
from .sessions.base import BaseSession
|
from .sessions.base import BaseSession
|
||||||
|
|
||||||
ort.set_default_logger_severity(3)
|
ort.set_default_logger_severity(3)
|
||||||
@ -194,11 +195,21 @@ def fix_image_orientation(img: PILImage) -> PILImage:
|
|||||||
return cast(PILImage, ImageOps.exif_transpose(img))
|
return cast(PILImage, ImageOps.exif_transpose(img))
|
||||||
|
|
||||||
|
|
||||||
def download_models() -> None:
|
def download_models(models: tuple[str, ...]) -> None:
|
||||||
"""
|
"""
|
||||||
Download models for image processing.
|
Download models for image processing.
|
||||||
"""
|
"""
|
||||||
for session in sessions_class:
|
if len(models) == 0:
|
||||||
|
print("No models specified, downloading all models")
|
||||||
|
models = tuple(sessions_names)
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
session = sessions.get(model)
|
||||||
|
if session is None:
|
||||||
|
print(f"Error: no model found: {model}")
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
print(f"Downloading model: {model}")
|
||||||
session.download_models()
|
session.download_models()
|
||||||
|
|
||||||
|
|
||||||
@ -214,7 +225,7 @@ def remove(
|
|||||||
bgcolor: Optional[Tuple[int, int, int, int]] = None,
|
bgcolor: Optional[Tuple[int, int, int, int]] = None,
|
||||||
force_return_bytes: bool = False,
|
force_return_bytes: bool = False,
|
||||||
*args: Optional[Any],
|
*args: Optional[Any],
|
||||||
**kwargs: Optional[Any]
|
**kwargs: Optional[Any],
|
||||||
) -> Union[bytes, PILImage, np.ndarray]:
|
) -> Union[bytes, PILImage, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Remove the background from an input image.
|
Remove the background from an input image.
|
||||||
|
@ -5,10 +5,11 @@ from ..bg import download_models
|
|||||||
|
|
||||||
@click.command( # type: ignore
|
@click.command( # type: ignore
|
||||||
name="d",
|
name="d",
|
||||||
help="download all models",
|
help="download models",
|
||||||
)
|
)
|
||||||
def d_command(*args, **kwargs) -> None:
|
@click.argument("models", nargs=-1)
|
||||||
|
def d_command(models: tuple[str, ...]) -> None:
|
||||||
"""
|
"""
|
||||||
Download all models
|
Download models
|
||||||
"""
|
"""
|
||||||
download_models()
|
download_models(models)
|
||||||
|
@ -1,93 +1,78 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import List
|
from typing import Dict, List
|
||||||
|
|
||||||
from .base import BaseSession
|
from .base import BaseSession
|
||||||
|
|
||||||
sessions_class: List[type[BaseSession]] = []
|
sessions: Dict[str, type[BaseSession]] = {}
|
||||||
sessions_names: List[str] = []
|
|
||||||
|
|
||||||
from .birefnet_general import BiRefNetSessionGeneral
|
from .birefnet_general import BiRefNetSessionGeneral
|
||||||
|
|
||||||
sessions_class.append(BiRefNetSessionGeneral)
|
sessions[BiRefNetSessionGeneral.name()] = BiRefNetSessionGeneral
|
||||||
sessions_names.append(BiRefNetSessionGeneral.name())
|
|
||||||
|
|
||||||
from .birefnet_general_lite import BiRefNetSessionGeneralLite
|
from .birefnet_general_lite import BiRefNetSessionGeneralLite
|
||||||
|
|
||||||
sessions_class.append(BiRefNetSessionGeneralLite)
|
sessions[BiRefNetSessionGeneralLite.name()] = BiRefNetSessionGeneralLite
|
||||||
sessions_names.append(BiRefNetSessionGeneralLite.name())
|
|
||||||
|
|
||||||
from .birefnet_portrait import BiRefNetSessionPortrait
|
from .birefnet_portrait import BiRefNetSessionPortrait
|
||||||
|
|
||||||
sessions_class.append(BiRefNetSessionPortrait)
|
sessions[BiRefNetSessionPortrait.name()] = BiRefNetSessionPortrait
|
||||||
sessions_names.append(BiRefNetSessionPortrait.name())
|
|
||||||
|
|
||||||
from .birefnet_dis import BiRefNetSessionDIS
|
from .birefnet_dis import BiRefNetSessionDIS
|
||||||
|
|
||||||
sessions_class.append(BiRefNetSessionDIS)
|
sessions[BiRefNetSessionDIS.name()] = BiRefNetSessionDIS
|
||||||
sessions_names.append(BiRefNetSessionDIS.name())
|
|
||||||
|
|
||||||
from .birefnet_hrsod import BiRefNetSessionHRSOD
|
from .birefnet_hrsod import BiRefNetSessionHRSOD
|
||||||
|
|
||||||
sessions_class.append(BiRefNetSessionHRSOD)
|
sessions[BiRefNetSessionHRSOD.name()] = BiRefNetSessionHRSOD
|
||||||
sessions_names.append(BiRefNetSessionHRSOD.name())
|
|
||||||
|
|
||||||
from .birefnet_cod import BiRefNetSessionCOD
|
from .birefnet_cod import BiRefNetSessionCOD
|
||||||
|
|
||||||
sessions_class.append(BiRefNetSessionCOD)
|
sessions[BiRefNetSessionCOD.name()] = BiRefNetSessionCOD
|
||||||
sessions_names.append(BiRefNetSessionCOD.name())
|
|
||||||
|
|
||||||
from .birefnet_massive import BiRefNetSessionMassive
|
from .birefnet_massive import BiRefNetSessionMassive
|
||||||
|
|
||||||
sessions_class.append(BiRefNetSessionMassive)
|
sessions[BiRefNetSessionMassive.name()] = BiRefNetSessionMassive
|
||||||
sessions_names.append(BiRefNetSessionMassive.name())
|
|
||||||
|
|
||||||
from .dis_anime import DisSession
|
from .dis_anime import DisSession
|
||||||
|
|
||||||
sessions_class.append(DisSession)
|
sessions[DisSession.name()] = DisSession
|
||||||
sessions_names.append(DisSession.name())
|
|
||||||
|
|
||||||
from .dis_general_use import DisSession as DisSessionGeneralUse
|
from .dis_general_use import DisSession as DisSessionGeneralUse
|
||||||
|
|
||||||
sessions_class.append(DisSessionGeneralUse)
|
sessions[DisSessionGeneralUse.name()] = DisSessionGeneralUse
|
||||||
sessions_names.append(DisSessionGeneralUse.name())
|
|
||||||
|
|
||||||
from .sam import SamSession
|
from .sam import SamSession
|
||||||
|
|
||||||
sessions_class.append(SamSession)
|
sessions[SamSession.name()] = SamSession
|
||||||
sessions_names.append(SamSession.name())
|
|
||||||
|
|
||||||
from .silueta import SiluetaSession
|
from .silueta import SiluetaSession
|
||||||
|
|
||||||
sessions_class.append(SiluetaSession)
|
sessions[SiluetaSession.name()] = SiluetaSession
|
||||||
sessions_names.append(SiluetaSession.name())
|
|
||||||
|
|
||||||
from .u2net_cloth_seg import Unet2ClothSession
|
from .u2net_cloth_seg import Unet2ClothSession
|
||||||
|
|
||||||
sessions_class.append(Unet2ClothSession)
|
sessions[Unet2ClothSession.name()] = Unet2ClothSession
|
||||||
sessions_names.append(Unet2ClothSession.name())
|
|
||||||
|
|
||||||
from .u2net_custom import U2netCustomSession
|
from .u2net_custom import U2netCustomSession
|
||||||
|
|
||||||
sessions_class.append(U2netCustomSession)
|
sessions[U2netCustomSession.name()] = U2netCustomSession
|
||||||
sessions_names.append(U2netCustomSession.name())
|
|
||||||
|
|
||||||
from .u2net_human_seg import U2netHumanSegSession
|
from .u2net_human_seg import U2netHumanSegSession
|
||||||
|
|
||||||
sessions_class.append(U2netHumanSegSession)
|
sessions[U2netHumanSegSession.name()] = U2netHumanSegSession
|
||||||
sessions_names.append(U2netHumanSegSession.name())
|
|
||||||
|
|
||||||
from .u2net import U2netSession
|
from .u2net import U2netSession
|
||||||
|
|
||||||
sessions_class.append(U2netSession)
|
sessions[U2netSession.name()] = U2netSession
|
||||||
sessions_names.append(U2netSession.name())
|
|
||||||
|
|
||||||
from .u2netp import U2netpSession
|
from .u2netp import U2netpSession
|
||||||
|
|
||||||
sessions_class.append(U2netpSession)
|
sessions[U2netpSession.name()] = U2netpSession
|
||||||
sessions_names.append(U2netpSession.name())
|
|
||||||
|
|
||||||
from .bria_rmbg import BriaRmBgSession
|
from .bria_rmbg import BriaRmBgSession
|
||||||
|
|
||||||
sessions_class.append(BriaRmBgSession)
|
sessions[BriaRmBgSession.name()] = BriaRmBgSession
|
||||||
sessions_names.append(BriaRmBgSession.name())
|
|
||||||
|
sessions_names = list(sessions.keys())
|
||||||
|
sessions_class = list(sessions.values())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user