diff --git a/src/rembg/bg.py b/src/rembg/bg.py index 04a9ed2..2177ddf 100644 --- a/src/rembg/bg.py +++ b/src/rembg/bg.py @@ -72,6 +72,8 @@ def naive_cutout(img, mask): def get_model(model_name): if model_name == "u2netp": return detect.load_model(model_name="u2netp") + if model_name == "u2net_human_seg": + return detect.load_model(model_name="u2net_human_seg") else: return detect.load_model(model_name="u2net") diff --git a/src/rembg/cmd/cli.py b/src/rembg/cmd/cli.py index 08529f0..e51377d 100644 --- a/src/rembg/cmd/cli.py +++ b/src/rembg/cmd/cli.py @@ -17,7 +17,7 @@ def main(): "--model", default="u2net", type=str, - choices=("u2net", "u2netp"), + choices=("u2net", "u2net_human_seg", "u2netp"), help="The model name.", ) diff --git a/src/rembg/cmd/server.py b/src/rembg/cmd/server.py index 022d773..d686ec9 100644 --- a/src/rembg/cmd/server.py +++ b/src/rembg/cmd/server.py @@ -38,7 +38,7 @@ def index(): az = request.values.get("az", type=int, default=1000) model = request.args.get("model", type=str, default="u2net") - if model not in ("u2net", "u2netp"): + if model not in ("u2net", "u2net_human_seg", "u2netp"): return {"error": "invalid query param 'model'"}, 400 try: diff --git a/src/rembg/u2net/detect.py b/src/rembg/u2net/detect.py index 4e2d88b..6fdc635 100644 --- a/src/rembg/u2net/detect.py +++ b/src/rembg/u2net/detect.py @@ -85,8 +85,24 @@ def load_model(model_name: str = "u2net"): "u2net.pth", path, ) + + elif model_name == "u2net_human_seg": + net = u2net.U2NET(3, 1) + path = os.environ.get( + "U2NET_PATH", + os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")), + ) + if ( + not os.path.exists(path) + or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a" + ): + download_file_from_google_drive( + "1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P", + "u2net_human_seg.pth", + path, + ) else: - print("Choose between u2net or u2netp", file=sys.stderr) + print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr) try: if torch.cuda.is_available():