diff --git a/README.md b/README.md index feda4eb..74ad461 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,15 @@ rembg -o path/to/output.png path/to/input.png Remove the background from all images in a folder ```bash -rembg -p path/to/inputs +rembg -p path/to/inputs path/to/output +``` + +### Add a custom model + +Copy the `custom-model.pth` file to `~/.u2net` and run: + +```bash +curl -s http://input.png | rembg -m custom-model > output.png ``` ### Usage as a server diff --git a/setup.py b/setup.py index f732612..2525f12 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ with open("requirements.txt") as f: setup( name="rembg", - version="1.0.24", + version="1.0.25", description="Remove image background", long_description=long_description, long_description_content_type="text/markdown", diff --git a/src/rembg/cmd/cli.py b/src/rembg/cmd/cli.py index e51377d..6d13b82 100644 --- a/src/rembg/cmd/cli.py +++ b/src/rembg/cmd/cli.py @@ -10,6 +10,13 @@ from ..bg import remove def main(): + model_path = os.environ.get( + "U2NETP_PATH", + os.path.expanduser(os.path.join("~", ".u2net")), + ) + model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))] + + ap = argparse.ArgumentParser() ap.add_argument( @@ -17,7 +24,7 @@ def main(): "--model", default="u2net", type=str, - choices=("u2net", "u2net_human_seg", "u2netp"), + choices=model_choices, help="The model name.", ) @@ -66,8 +73,8 @@ def main(): ap.add_argument( "-p", "--path", - nargs="+", - help="Path of a file or a folder of files.", + nargs=2, + help="An input folder and an output folder.", ) ap.add_argument( @@ -94,15 +101,20 @@ def main(): if args.path: full_paths = [os.path.abspath(path) for path in args.path] + + input_paths = [full_paths[0]] + output_path = full_paths[1] + + if not os.path.exists(output_path): + os.makedirs(output_path) + files = set() - for path in full_paths: + for path in input_paths: if os.path.isfile(path): files.add(path) else: - full_paths += set(glob.glob(path + "/*")) - set( - glob.glob(path + "/*.out.png") - ) + input_paths += set(glob.glob(path + "/*")) for fi in tqdm(files): fi_type = filetype.guess(fi) @@ -113,7 +125,7 @@ def main(): continue with open(fi, "rb") as input: - with open(os.path.splitext(fi)[0] + ".out.png", "wb") as output: + with open(os.path.join(output_path, os.path.splitext(os.path.basename(fi))[0] + ".png"), "wb") as output: w( output, remove( diff --git a/src/rembg/cmd/server.py b/src/rembg/cmd/server.py index d686ec9..6c9900b 100644 --- a/src/rembg/cmd/server.py +++ b/src/rembg/cmd/server.py @@ -1,3 +1,5 @@ +import os +import glob import argparse from io import BytesIO from urllib.parse import unquote_plus @@ -38,8 +40,14 @@ def index(): az = request.values.get("az", type=int, default=1000) model = request.args.get("model", type=str, default="u2net") - if model not in ("u2net", "u2net_human_seg", "u2netp"): - return {"error": "invalid query param 'model'"}, 400 + model_path = os.environ.get( + "U2NETP_PATH", + os.path.expanduser(os.path.join("~", ".u2net")), + ) + model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))] + + if model not in model_choices: + return {"error": f"invalid query param 'model'. Available options are {model_choices}"}, 400 try: return send_file(