From 9405b96d80ce819a6b72a1f5c3ace16642ca7248 Mon Sep 17 00:00:00 2001 From: Daniel Gatis Date: Sat, 9 Oct 2021 10:40:22 -0300 Subject: [PATCH] fix model choices --- requirements.txt | 4 ++-- setup.py | 2 +- src/rembg/cmd/cli.py | 4 ++-- src/rembg/cmd/server.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/requirements.txt b/requirements.txt index 0140da8..4b3abe1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,8 +2,8 @@ flask>=1.1.2 numpy>=1.19.5 pillow>=8.0.1 scikit-image>=0.17.2 -torch>=1.7.0 -torchvision>=0.8.1 +torch>=1.9.1 +torchvision>=0.10.1 waitress>=1.4.4 tqdm>=4.51.0 requests>=2.24.0 diff --git a/setup.py b/setup.py index 40c3c35..4d7e36d 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ with open("requirements.txt") as f: setup( name="rembg", - version="1.0.27", + version="1.0.28", description="Remove image background", long_description=long_description, long_description_content_type="text/markdown", diff --git a/src/rembg/cmd/cli.py b/src/rembg/cmd/cli.py index 73c69da..319a9dd 100644 --- a/src/rembg/cmd/cli.py +++ b/src/rembg/cmd/cli.py @@ -18,8 +18,8 @@ def main(): os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*")) ] - if len(model_choices) == 0: - model_choices = ["u2net", "u2netp", "u2net_human_seg"] + + model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"])) ap = argparse.ArgumentParser() diff --git a/src/rembg/cmd/server.py b/src/rembg/cmd/server.py index 98cc7cd..61a048c 100644 --- a/src/rembg/cmd/server.py +++ b/src/rembg/cmd/server.py @@ -52,8 +52,8 @@ def index(): os.path.splitext(os.path.basename(x))[0] for x in set(glob.glob(model_path + "/*")) ] - if len(model_choices) == 0: - model_choices = ["u2net", "u2netp", "u2net_human_seg"] + + model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"])) if model not in model_choices: return {