add u2net_human_seg model

This commit is contained in:
Daniel Gatis 2021-02-21 03:04:25 -03:00
parent 63f52a3a6a
commit fd3b5a5e8f
4 changed files with 21 additions and 3 deletions

View File

@ -72,6 +72,8 @@ def naive_cutout(img, mask):
def get_model(model_name):
if model_name == "u2netp":
return detect.load_model(model_name="u2netp")
if model_name == "u2net_human_seg":
return detect.load_model(model_name="u2net_human_seg")
else:
return detect.load_model(model_name="u2net")

View File

@ -17,7 +17,7 @@ def main():
"--model",
default="u2net",
type=str,
choices=("u2net", "u2netp"),
choices=("u2net", "u2net_human_seg", "u2netp"),
help="The model name.",
)

View File

@ -38,7 +38,7 @@ def index():
az = request.values.get("az", type=int, default=1000)
model = request.args.get("model", type=str, default="u2net")
if model not in ("u2net", "u2netp"):
if model not in ("u2net", "u2net_human_seg", "u2netp"):
return {"error": "invalid query param 'model'"}, 400
try:

View File

@ -85,8 +85,24 @@ def load_model(model_name: str = "u2net"):
"u2net.pth",
path,
)
elif model_name == "u2net_human_seg":
net = u2net.U2NET(3, 1)
path = os.environ.get(
"U2NET_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
)
if (
not os.path.exists(path)
or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a"
):
download_file_from_google_drive(
"1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
"u2net_human_seg.pth",
path,
)
else:
print("Choose between u2net or u2netp", file=sys.stderr)
print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr)
try:
if torch.cuda.is_available():