This commit is contained in:
Daniel Gatis 2021-03-23 17:42:35 -03:00
parent 07a46d478a
commit ca4be25f90
4 changed files with 40 additions and 12 deletions

View File

@ -80,7 +80,15 @@ rembg -o path/to/output.png path/to/input.png
Remove the background from all images in a folder Remove the background from all images in a folder
```bash ```bash
rembg -p path/to/inputs rembg -p path/to/inputs path/to/output
```
### Add a custom model
Copy the `custom-model.pth` file to `~/.u2net` and run:
```bash
curl -s http://input.png | rembg -m custom-model > output.png
``` ```
### Usage as a server ### Usage as a server

View File

@ -11,7 +11,7 @@ with open("requirements.txt") as f:
setup( setup(
name="rembg", name="rembg",
version="1.0.24", version="1.0.25",
description="Remove image background", description="Remove image background",
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",

View File

@ -10,6 +10,13 @@ from ..bg import remove
def main(): def main():
model_path = os.environ.get(
"U2NETP_PATH",
os.path.expanduser(os.path.join("~", ".u2net")),
)
model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))]
ap = argparse.ArgumentParser() ap = argparse.ArgumentParser()
ap.add_argument( ap.add_argument(
@ -17,7 +24,7 @@ def main():
"--model", "--model",
default="u2net", default="u2net",
type=str, type=str,
choices=("u2net", "u2net_human_seg", "u2netp"), choices=model_choices,
help="The model name.", help="The model name.",
) )
@ -66,8 +73,8 @@ def main():
ap.add_argument( ap.add_argument(
"-p", "-p",
"--path", "--path",
nargs="+", nargs=2,
help="Path of a file or a folder of files.", help="An input folder and an output folder.",
) )
ap.add_argument( ap.add_argument(
@ -94,15 +101,20 @@ def main():
if args.path: if args.path:
full_paths = [os.path.abspath(path) for path in args.path] full_paths = [os.path.abspath(path) for path in args.path]
input_paths = [full_paths[0]]
output_path = full_paths[1]
if not os.path.exists(output_path):
os.makedirs(output_path)
files = set() files = set()
for path in full_paths: for path in input_paths:
if os.path.isfile(path): if os.path.isfile(path):
files.add(path) files.add(path)
else: else:
full_paths += set(glob.glob(path + "/*")) - set( input_paths += set(glob.glob(path + "/*"))
glob.glob(path + "/*.out.png")
)
for fi in tqdm(files): for fi in tqdm(files):
fi_type = filetype.guess(fi) fi_type = filetype.guess(fi)
@ -113,7 +125,7 @@ def main():
continue continue
with open(fi, "rb") as input: with open(fi, "rb") as input:
with open(os.path.splitext(fi)[0] + ".out.png", "wb") as output: with open(os.path.join(output_path, os.path.splitext(os.path.basename(fi))[0] + ".png"), "wb") as output:
w( w(
output, output,
remove( remove(

View File

@ -1,3 +1,5 @@
import os
import glob
import argparse import argparse
from io import BytesIO from io import BytesIO
from urllib.parse import unquote_plus from urllib.parse import unquote_plus
@ -38,8 +40,14 @@ def index():
az = request.values.get("az", type=int, default=1000) az = request.values.get("az", type=int, default=1000)
model = request.args.get("model", type=str, default="u2net") model = request.args.get("model", type=str, default="u2net")
if model not in ("u2net", "u2net_human_seg", "u2netp"): model_path = os.environ.get(
return {"error": "invalid query param 'model'"}, 400 "U2NETP_PATH",
os.path.expanduser(os.path.join("~", ".u2net")),
)
model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))]
if model not in model_choices:
return {"error": f"invalid query param 'model'. Available options are {model_choices}"}, 400
try: try:
return send_file( return send_file(