From b07192d16f69d4a9b6da8513a493c5d3c379828e Mon Sep 17 00:00:00 2001 From: Daniel Gatis Date: Sun, 4 Apr 2021 20:07:11 -0300 Subject: [PATCH] fix invalid model fresh install --- setup.py | 2 +- src/rembg/cmd/cli.py | 7 ++++--- src/rembg/cmd/server.py | 2 ++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 2525f12..cd333ea 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ with open("requirements.txt") as f: setup( name="rembg", - version="1.0.25", + version="1.0.26", description="Remove image background", long_description=long_description, long_description_content_type="text/markdown", diff --git a/src/rembg/cmd/cli.py b/src/rembg/cmd/cli.py index 6d13b82..adf34dc 100644 --- a/src/rembg/cmd/cli.py +++ b/src/rembg/cmd/cli.py @@ -15,7 +15,8 @@ def main(): os.path.expanduser(os.path.join("~", ".u2net")), ) model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))] - + if len(model_choices) == 0: + model_choices = ["u2net", "u2netp", "u2net_human_seg"] ap = argparse.ArgumentParser() @@ -101,8 +102,8 @@ def main(): if args.path: full_paths = [os.path.abspath(path) for path in args.path] - - input_paths = [full_paths[0]] + + input_paths = [full_paths[0]] output_path = full_paths[1] if not os.path.exists(output_path): diff --git a/src/rembg/cmd/server.py b/src/rembg/cmd/server.py index 6c9900b..9e3a926 100644 --- a/src/rembg/cmd/server.py +++ b/src/rembg/cmd/server.py @@ -45,6 +45,8 @@ def index(): os.path.expanduser(os.path.join("~", ".u2net")), ) model_choices = [os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*"))] + if len(model_choices) == 0: + model_choices = ["u2net", "u2netp", "u2net_human_seg"] if model not in model_choices: return {"error": f"invalid query param 'model'. Available options are {model_choices}"}, 400