diff --git a/Dockerfile b/Dockerfile index 29b94e0..4017928 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,30 +1,11 @@ -FROM nvidia/cuda:11.6.0-runtime-ubuntu18.04 - -ENV DEBIAN_FRONTEND noninteractive - -RUN rm /etc/apt/sources.list.d/cuda.list || true -RUN rm /etc/apt/sources.list.d/nvidia-ml.list || true -RUN apt-key del 7fa2af80 -RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub -RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub - -RUN apt update -y -RUN apt upgrade -y -RUN apt install -y curl wget software-properties-common -RUN add-apt-repository ppa:deadsnakes/ppa -RUN apt install -y python3.9 python3.9-distutils -RUN curl https://bootstrap.pypa.io/get-pip.py | python3.9 +FROM python:3.10-slim WORKDIR /rembg COPY . . -RUN python3.9 -m pip install .[gpu] - -RUN mkdir -p ~/.u2net -RUN wget https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx -O ~/.u2net/u2netp.onnx -RUN wget https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx -O ~/.u2net/u2net.onnx -RUN wget https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx -O ~/.u2net/u2net_human_seg.onnx -RUN wget https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx -O ~/.u2net/u2net_cloth_seg.onnx +RUN pip install --upgrade pip +RUN python -m pip install . +RUN python -c 'from rembg.bg import download_models; download_models()' EXPOSE 5000 ENTRYPOINT ["rembg"] diff --git a/rembg/bg.py b/rembg/bg.py index 414aa6d..b5ce82a 100644 --- a/rembg/bg.py +++ b/rembg/bg.py @@ -19,6 +19,7 @@ from pymatting.util.util import stack_images from scipy.ndimage import binary_erosion from .session_factory import new_session +from .sessions import sessions_class from .sessions.base import BaseSession kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3)) @@ -117,6 +118,11 @@ def fix_image_orientation(img: PILImage) -> PILImage: return ImageOps.exif_transpose(img) +def download_models() -> None: + for session in sessions_class: + session.download_models() + + def remove( data: Union[bytes, PILImage, np.ndarray], alpha_matting: bool = False,