add u2net_human_seg model

This commit is contained in:
Daniel Gatis 2021-02-21 03:04:25 -03:00
parent 63f52a3a6a
commit fd3b5a5e8f
4 changed files with 21 additions and 3 deletions

View File

@ -72,6 +72,8 @@ def naive_cutout(img, mask):
def get_model(model_name): def get_model(model_name):
if model_name == "u2netp": if model_name == "u2netp":
return detect.load_model(model_name="u2netp") return detect.load_model(model_name="u2netp")
if model_name == "u2net_human_seg":
return detect.load_model(model_name="u2net_human_seg")
else: else:
return detect.load_model(model_name="u2net") return detect.load_model(model_name="u2net")

View File

@ -17,7 +17,7 @@ def main():
"--model", "--model",
default="u2net", default="u2net",
type=str, type=str,
choices=("u2net", "u2netp"), choices=("u2net", "u2net_human_seg", "u2netp"),
help="The model name.", help="The model name.",
) )

View File

@ -38,7 +38,7 @@ def index():
az = request.values.get("az", type=int, default=1000) az = request.values.get("az", type=int, default=1000)
model = request.args.get("model", type=str, default="u2net") model = request.args.get("model", type=str, default="u2net")
if model not in ("u2net", "u2netp"): if model not in ("u2net", "u2net_human_seg", "u2netp"):
return {"error": "invalid query param 'model'"}, 400 return {"error": "invalid query param 'model'"}, 400
try: try:

View File

@ -85,8 +85,24 @@ def load_model(model_name: str = "u2net"):
"u2net.pth", "u2net.pth",
path, path,
) )
elif model_name == "u2net_human_seg":
net = u2net.U2NET(3, 1)
path = os.environ.get(
"U2NET_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
)
if (
not os.path.exists(path)
or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a"
):
download_file_from_google_drive(
"1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
"u2net_human_seg.pth",
path,
)
else: else:
print("Choose between u2net or u2netp", file=sys.stderr) print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr)
try: try:
if torch.cuda.is_available(): if torch.cuda.is_available():