diff --git a/Dockerfile b/Dockerfile index 56936f8..dfd8596 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ WORKDIR /rembg COPY . . -RUN GPU=1 pip3 install . +RUN ["pip3", "install", ".[gpu]"] ENTRYPOINT ["rembg"] CMD [] diff --git a/README.md b/README.md index e1561ec..88bdb12 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ pip install rembg GPU support: ```bash -GPU=1 pip install rembg +pip install rembg[gpu] ``` ### Usage as a cli diff --git a/rembg/bg.py b/rembg/bg.py index ade45ad..b494670 100644 --- a/rembg/bg.py +++ b/rembg/bg.py @@ -18,13 +18,7 @@ def alpha_matting_cutout( foreground_threshold: int, background_threshold: int, erode_structure_size: int, - base_size: int, ) -> Image: - size = img.size - - img.thumbnail((base_size, base_size), Image.LANCZOS) - mask = mask.resize(img.size, Image.LANCZOS) - img = np.asarray(img) mask = np.asarray(mask) @@ -60,45 +54,37 @@ def alpha_matting_cutout( cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8) cutout = Image.fromarray(cutout) - cutout = cutout.resize(size, Image.LANCZOS) return cutout def naive_cutout(img: Image, mask: Image) -> Image: empty = Image.new("RGBA", (img.size), 0) - cutout = Image.composite(img, empty, mask.resize(img.size, Image.LANCZOS)) + cutout = Image.composite(img, empty, mask) return cutout -def resize_image(img: Image, width: Optional[int], height: Optional[int]) -> Image: - original_width, original_height = img.size - width = original_width if width is None else width - height = original_height if height is None else height - return img.resize((width, height)) - - def remove( data: bytes, alpha_matting: bool = False, alpha_matting_foreground_threshold: int = 240, alpha_matting_background_threshold: int = 10, alpha_matting_erode_size: int = 10, - alpha_matting_base_size: int = 1000, session: Optional[ort.InferenceSession] = None, - width: Optional[int] = None, - height: Optional[int] = None, + only_mask: bool = False, ) -> bytes: img = Image.open(io.BytesIO(data)).convert("RGB") - if width is not None or height is not None: - img = resize_image(img, width, height) if session is None: session = ort_session("u2net") mask = predict(session, np.array(img)).convert("L") + mask = mask.resize(img.size, Image.LANCZOS) - if alpha_matting: + if only_mask: + cutout = mask + + elif alpha_matting: try: cutout = alpha_matting_cutout( img, @@ -106,7 +92,6 @@ def remove( alpha_matting_foreground_threshold, alpha_matting_background_threshold, alpha_matting_erode_size, - alpha_matting_base_size, ) except Exception: cutout = naive_cutout(img, mask) diff --git a/rembg/cli.py b/rembg/cli.py index 40c82e1..629e9cf 100644 --- a/rembg/cli.py +++ b/rembg/cli.py @@ -66,28 +66,11 @@ def main(): help="erode size", ) @click.option( - "-az", - "--alpha-matting-base-size", - default=1000, - type=int, + "-om", + "--only-mask", + is_flag=True, show_default=True, - help="image base size", -) -@click.option( - "-w", - "--width", - default=None, - type=int, - show_default=True, - help="output image size", -) -@click.option( - "-h", - "--height", - default=None, - type=int, - show_default=True, - help="output image size", + help="output only the mask", ) @click.argument( "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb") @@ -143,28 +126,11 @@ def i(model: str, input: IO, output: IO, **kwargs): help="erode size", ) @click.option( - "-az", - "--alpha-matting-base-size", - default=1000, - type=int, + "-om", + "--only-mask", + is_flag=True, show_default=True, - help="image base size", -) -@click.option( - "-w", - "--width", - default=None, - type=int, - show_default=True, - help="output image size", -) -@click.option( - "-h", - "--height", - default=None, - type=int, - show_default=True, - help="output image size", + help="output only the mask", ) @click.argument( "input", @@ -240,18 +206,14 @@ def s(port: int, log_level: str): af: int = Query(240, ge=0), ab: int = Query(10, ge=0), ae: int = Query(10, ge=0), - az: int = Query(1000, ge=0), - width: Optional[int] = Query(None, gt=0), - height: Optional[int] = Query(None, gt=0), + om: bool = Query(False), ): self.model = model - self.width = width - self.height = height self.a = a self.af = af self.ab = ab self.ae = ae - self.az = az + self.om = om def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response: return Response( @@ -260,13 +222,11 @@ def s(port: int, log_level: str): session=sessions.setdefault( commons.model.value, ort_session(commons.model.value) ), - width=commons.width, - height=commons.height, alpha_matting=commons.a, alpha_matting_foreground_threshold=commons.af, alpha_matting_background_threshold=commons.ab, alpha_matting_erode_size=commons.ae, - alpha_matting_base_size=commons.az, + only_mask=commons.om, ), media_type="image/png", ) diff --git a/requirements-cpu.txt b/requirements-cpu.txt deleted file mode 100644 index 7be989e..0000000 --- a/requirements-cpu.txt +++ /dev/null @@ -1 +0,0 @@ -onnxruntime==1.10.0 diff --git a/requirements.txt b/requirements.txt index 7e673fd..a9a365b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ fastapi==0.72.0 filetype==1.0.9 gdown==4.2.0 numpy==1.21.5 +onnxruntime==1.10.0 pillow==9.0.0 pymatting==1.1.5 python-multipart==0.0.5 diff --git a/setup.py b/setup.py index 3a48a82..2a3458f 100644 --- a/setup.py +++ b/setup.py @@ -14,12 +14,8 @@ long_description = (here / "README.md").read_text(encoding="utf-8") with open("requirements.txt") as f: requireds = f.read().splitlines() -if os.getenv("GPU") is None: - with open("requirements-cpu.txt") as f: - requireds += f.read().splitlines() -else: - with open("requirements-gpu.txt") as f: - requireds += f.read().splitlines() +with open("requirements-gpu.txt") as f: + gpu_requireds = f.read().splitlines() setup( name="rembg", @@ -42,6 +38,9 @@ setup( "rembg=rembg.cli:main", ], }, + extras_require={ + 'gpu': gpu_requireds, + }, version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), )