diff --git a/README.md b/README.md index 2dfd393..99cf851 100644 --- a/README.md +++ b/README.md @@ -53,24 +53,24 @@ GPU=1 pip install rembg Remove the background from a remote image ```bash -curl -s http://input.png | rembg > output.png +curl -s http://input.png | rembg i > output.png ``` Remove the background from a local file ```bash -rembg -o path/to/output.png path/to/input.png +rembg i path/to/input.png path/to/output.png ``` Remove the background from all images in a folder ```bash -rembg -p path/to/input path/to/output +rembg p path/to/input path/to/output ``` ### Usage as a server Start the server ```bash -rembg-server +rembg s ``` Open your browser to @@ -140,14 +140,14 @@ docker build . -t rembg Then run with: ``` -docker run --rm -i rembg out.png +docker run --rm -i rembg i in.png out.png ``` ### Advance usage Sometimes it is possible to achieve better results by turning on alpha matting. Example: ```bash -curl -s http://input.png | rembg -a -ae 15 > output.png +curl -s http://input.png | rembg i -a -ae 15 > output.png ``` diff --git a/rembg/bg.py b/rembg/bg.py index c584829..ade45ad 100644 --- a/rembg/bg.py +++ b/rembg/bg.py @@ -75,11 +75,7 @@ def resize_image(img: Image, width: Optional[int], height: Optional[int]) -> Ima original_width, original_height = img.size width = original_width if width is None else width height = original_height if height is None else height - return ( - img.resize((width, height)) - if original_width != width or original_height != height - else img - ) + return img.resize((width, height)) def remove( @@ -87,7 +83,7 @@ def remove( alpha_matting: bool = False, alpha_matting_foreground_threshold: int = 240, alpha_matting_background_threshold: int = 10, - alpha_matting_erode_structure_size: int = 10, + alpha_matting_erode_size: int = 10, alpha_matting_base_size: int = 1000, session: Optional[ort.InferenceSession] = None, width: Optional[int] = None, @@ -109,7 +105,7 @@ def remove( mask, alpha_matting_foreground_threshold, alpha_matting_background_threshold, - alpha_matting_erode_structure_size, + alpha_matting_erode_size, alpha_matting_base_size, ) except Exception: diff --git a/rembg/cli.py b/rembg/cli.py index 28c0709..a637c47 100644 --- a/rembg/cli.py +++ b/rembg/cli.py @@ -1,156 +1,284 @@ -import argparse -import glob -import os -from distutils.util import strtobool -from typing import BinaryIO +import pathlib import sys -from pathlib import Path +from enum import Enum +from typing import IO, Optional +import click import filetype -from tqdm import tqdm import onnxruntime as ort +import requests +import uvicorn +from fastapi import Depends, FastAPI, File, Query +from starlette.responses import Response +from tqdm import tqdm from .bg import remove from .detect import ort_session -sessions: dict[str, ort.InferenceSession] = {} - +@click.group() +@click.version_option() def main(): - ap = argparse.ArgumentParser() + pass - ap.add_argument( - "-m", - "--model", - default="u2net", - type=str, - choices=["u2net", "u2netp", "u2net_human_seg"], - help="The model name.", - ) - ap.add_argument( - "-a", - "--alpha-matting", - nargs="?", - const=True, - default=False, - type=lambda x: bool(strtobool(x)), - help="When true use alpha matting cutout.", - ) +@main.command(help="for a file as input") +@click.option( + "-m", + "--model", + default="u2net", + type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]), + show_default=True, + show_choices=True, + help="model name", +) +@click.option( + "-a", + "--alpha-matting", + is_flag=True, + show_default=True, + help="use alpha matting", +) +@click.option( + "-af", + "--alpha-matting-foreground-threshold", + default=240, + type=int, + show_default=True, + help="trimap fg threshold", +) +@click.option( + "-ab", + "--alpha-matting-background-threshold", + default=10, + type=int, + show_default=True, + help="trimap bg threshold", +) +@click.option( + "-ae", + "--alpha-matting-erode-size", + default=10, + type=int, + show_default=True, + help="erode size", +) +@click.option( + "-az", + "--alpha-matting-base-size", + default=1000, + type=int, + show_default=True, + help="image base size", +) +@click.option( + "-w", + "--width", + default=None, + type=int, + show_default=True, + help="output image size", +) +@click.option( + "-h", + "--height", + default=None, + type=int, + show_default=True, + help="output image size", +) +@click.argument( + "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb") +) +@click.argument( + "output", + default=(None if sys.stdin.isatty() else "-"), + type=click.File("wb", lazy=True), +) +def i(model: str, input: IO, output: IO, **kwargs: dict): + output.write(remove(input.read(), session=ort_session(model), **kwargs)) - ap.add_argument( - "-af", - "--alpha-matting-foreground-threshold", - default=240, - type=int, - help="The trimap foreground threshold.", - ) - ap.add_argument( - "-ab", - "--alpha-matting-background-threshold", - default=10, - type=int, - help="The trimap background threshold.", - ) +@main.command(help="for a folder as input") +@click.option( + "-m", + "--model", + default="u2net", + type=click.Choice(["u2net", "u2netp", "u2net_human_seg"]), + show_default=True, + show_choices=True, + help="model name", +) +@click.option( + "-a", + "--alpha-matting", + is_flag=True, + show_default=True, + help="use alpha matting", +) +@click.option( + "-af", + "--alpha-matting-foreground-threshold", + default=240, + type=int, + show_default=True, + help="trimap fg threshold", +) +@click.option( + "-ab", + "--alpha-matting-background-threshold", + default=10, + type=int, + show_default=True, + help="trimap bg threshold", +) +@click.option( + "-ae", + "--alpha-matting-erode-size", + default=10, + type=int, + show_default=True, + help="erode size", +) +@click.option( + "-az", + "--alpha-matting-base-size", + default=1000, + type=int, + show_default=True, + help="image base size", +) +@click.option( + "-w", + "--width", + default=None, + type=int, + show_default=True, + help="output image size", +) +@click.option( + "-h", + "--height", + default=None, + type=int, + show_default=True, + help="output image size", +) +@click.argument( + "input", + type=click.Path( + exists=True, + path_type=pathlib.Path, + file_okay=False, + dir_okay=True, + readable=True, + ), +) +@click.argument( + "output", + type=click.Path( + exists=False, + path_type=pathlib.Path, + file_okay=False, + dir_okay=True, + writable=True, + ), +) +def p(model: str, input: pathlib.Path, output: pathlib.Path, **kwargs: dict): + session = ort_session(model) + for each_input in tqdm(list(input.glob("**/*"))): + if each_input.is_dir(): + continue - ap.add_argument( - "-ae", - "--alpha-matting-erode-size", - default=10, - type=int, - help="Size of element used for the erosion.", - ) + mimetype = filetype.guess(each_input) + if mimetype is None: + continue + if mimetype.mime.find("image") < 0: + continue - ap.add_argument( - "-az", - "--alpha-matting-base-size", - default=1000, - type=int, - help="The image base size.", - ) + each_output = (output / each_input.name).with_suffix(".png") + each_output.parents[0].mkdir(parents=True, exist_ok=True) - ap.add_argument( - "-p", - "--path", - nargs=2, - help="An input folder and an output folder.", - ) - - ap.add_argument( - "input", - nargs=(None if sys.stdin.isatty() else "?"), - default=(None if sys.stdin.isatty() else sys.stdin.buffer), - type=argparse.FileType("rb"), - help="Path to the input image.", - ) - - ap.add_argument( - "output", - nargs=(None if sys.stdin.isatty() else "?"), - default=(None if sys.stdin.isatty() else sys.stdout.buffer), - type=argparse.FileType("wb"), - help="Path to the output png image.", - ) - - args = ap.parse_args() - session = sessions.setdefault(args.model, ort_session(args.model)) - - if args.path: - full_paths = [os.path.abspath(path) for path in args.path] - - input_paths = [full_paths[0]] - output_path = full_paths[1] - - if not os.path.exists(output_path): - os.makedirs(output_path) - - input_files = set() - - for input_path in input_paths: - if os.path.isfile(path): - input_files.add(path) - else: - input_paths += set(glob.glob(input_path + "/*")) - - for input_file in tqdm(input_files): - input_file_type = filetype.guess(input_file) - - if input_file_type is None: - continue - - if input_file_type.mime.find("image") < 0: - continue - - out_file = os.path.join( - output_path, os.path.splitext(os.path.basename(input_file))[0] + ".png" - ) - - Path(out_file).write_bytes( - remove( - Path(input_file).read_bytes(), - session=session, - alpha_matting=args.alpha_matting, - alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold, - alpha_matting_background_threshold=args.alpha_matting_background_threshold, - alpha_matting_erode_structure_size=args.alpha_matting_erode_size, - alpha_matting_base_size=args.alpha_matting_base_size, - ) - ) - - else: - args.output.write( - remove( - args.input.read(), - session=session, - alpha_matting=args.alpha_matting, - alpha_matting_foreground_threshold=args.alpha_matting_foreground_threshold, - alpha_matting_background_threshold=args.alpha_matting_background_threshold, - alpha_matting_erode_structure_size=args.alpha_matting_erode_size, - alpha_matting_base_size=args.alpha_matting_base_size, - ) + each_output.write_bytes( + remove(each_input.read_bytes(), session=session, **kwargs) ) +@main.command(help="for a http server") +@click.option( + "-p", + "--port", + default=5000, + type=int, + show_default=True, + help="port", +) +@click.option( + "-l", + "--log_level", + default="info", + type=str, + show_default=True, + help="log level", +) +def s(port: int, log_level: str): + sessions: dict[str, ort.InferenceSession] = {} + app = FastAPI() + + class ModelType(str, Enum): + u2net = "u2net" + u2netp = "u2netp" + u2net_human_seg = "u2net_human_seg" + + class CommonQueryParams: + def __init__( + self, + model: ModelType = Query(ModelType.u2net), + a: bool = Query(False), + af: int = Query(240, ge=0), + ab: int = Query(10, ge=0), + ae: int = Query(10, ge=0), + az: int = Query(1000, ge=0), + width: Optional[int] = Query(None, gt=0), + height: Optional[int] = Query(None, gt=0), + ): + self.model = model + self.width = width + self.height = height + self.a = a + self.af = af + self.ab = ab + self.ae = ae + self.az = az + + def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response: + return Response( + remove( + content, + session=sessions.setdefault( + commons.model.value, ort_session(commons.model.value) + ), + width=commons.width, + height=commons.height, + alpha_matting=commons.a, + alpha_matting_foreground_threshold=commons.af, + alpha_matting_background_threshold=commons.ab, + alpha_matting_erode_size=commons.ae, + alpha_matting_base_size=commons.az, + ), + media_type="image/png", + ) + + @app.get("/") + def get_index(url: str, commons: CommonQueryParams = Depends()): + return im_without_bg(requests.get(url).content, commons) + + @app.post("/") + def post_index(file: bytes = File(...), commons: CommonQueryParams = Depends()): + return im_without_bg(file, commons) + + uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level) + + if __name__ == "__main__": main() diff --git a/rembg/server.py b/rembg/server.py deleted file mode 100644 index 551e9d7..0000000 --- a/rembg/server.py +++ /dev/null @@ -1,110 +0,0 @@ -import argparse -from enum import Enum -from typing import Optional - -import requests -import uvicorn -from fastapi import Depends, FastAPI, File, Form, Query, UploadFile -from PIL import Image -from starlette.responses import Response -import onnxruntime as ort - -from .bg import remove -from .detect import ort_session - -sessions: dict[str, ort.InferenceSession] = {} -app = FastAPI() - - -class ModelType(str, Enum): - u2net = "u2net" - u2netp = "u2netp" - u2net_human_seg = "u2net_human_seg" - - -class CommonQueryParams: - def __init__( - self, - model: ModelType = Query(ModelType.u2net), - a: bool = Query(False), - af: int = Query(240, ge=0), - ab: int = Query(10, ge=0), - ae: int = Query(10, ge=0), - az: int = Query(1000, ge=0), - width: Optional[int] = Query(None, gt=0), - height: Optional[int] = Query(None, gt=0), - ): - self.model = model - self.width = width - self.height = height - self.a = a - self.af = af - self.ab = ab - self.ae = ae - self.az = az - - -def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response: - return Response( - remove( - content, - session=sessions.setdefault( - commons.model.value, ort_session(commons.model.value) - ), - width=commons.width, - height=commons.height, - alpha_matting=commons.a, - alpha_matting_foreground_threshold=commons.af, - alpha_matting_background_threshold=commons.ab, - alpha_matting_erode_structure_size=commons.ae, - alpha_matting_base_size=commons.az, - ), - media_type="image/png", - ) - - -@app.get("/") -def get_index(url: str, commons: CommonQueryParams = Depends()): - return im_without_bg(requests.get(url).content, commons) - - -@app.post("/") -def post_index(file: bytes = File(...), commons: CommonQueryParams = Depends()): - return im_without_bg(file, commons) - - -def main(): - ap = argparse.ArgumentParser() - - ap.add_argument( - "-a", - "--addr", - default="0.0.0.0", - type=str, - help="The IP address to bind to.", - ) - - ap.add_argument( - "-p", - "--port", - default=5000, - type=int, - help="The port to bind to.", - ) - - ap.add_argument( - "-l", - "--log_level", - default="info", - type=str, - help="The log level.", - ) - - args = ap.parse_args() - uvicorn.run( - "rembg.server:app", host=args.addr, port=args.port, log_level=args.log_level - ) - - -if __name__ == "__main__": - main() diff --git a/requirements.txt b/requirements.txt index a383fbc..96c4728 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ scikit-image==0.19.1 scipy==1.7.3 tqdm==4.62.3 uvicorn==0.17.0 +click==8.0.3 diff --git a/setup.py b/setup.py index 54dedf1..3a48a82 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,6 @@ setup( entry_points={ "console_scripts": [ "rembg=rembg.cli:main", - "rembg-server=rembg.server:main", ], }, version=versioneer.get_version(),