mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-15 00:25:59 +08:00
add u2net_human_seg model
This commit is contained in:
parent
63f52a3a6a
commit
fd3b5a5e8f
@ -72,6 +72,8 @@ def naive_cutout(img, mask):
|
||||
def get_model(model_name):
|
||||
if model_name == "u2netp":
|
||||
return detect.load_model(model_name="u2netp")
|
||||
if model_name == "u2net_human_seg":
|
||||
return detect.load_model(model_name="u2net_human_seg")
|
||||
else:
|
||||
return detect.load_model(model_name="u2net")
|
||||
|
||||
|
@ -17,7 +17,7 @@ def main():
|
||||
"--model",
|
||||
default="u2net",
|
||||
type=str,
|
||||
choices=("u2net", "u2netp"),
|
||||
choices=("u2net", "u2net_human_seg", "u2netp"),
|
||||
help="The model name.",
|
||||
)
|
||||
|
||||
|
@ -38,7 +38,7 @@ def index():
|
||||
az = request.values.get("az", type=int, default=1000)
|
||||
|
||||
model = request.args.get("model", type=str, default="u2net")
|
||||
if model not in ("u2net", "u2netp"):
|
||||
if model not in ("u2net", "u2net_human_seg", "u2netp"):
|
||||
return {"error": "invalid query param 'model'"}, 400
|
||||
|
||||
try:
|
||||
|
@ -85,8 +85,24 @@ def load_model(model_name: str = "u2net"):
|
||||
"u2net.pth",
|
||||
path,
|
||||
)
|
||||
|
||||
elif model_name == "u2net_human_seg":
|
||||
net = u2net.U2NET(3, 1)
|
||||
path = os.environ.get(
|
||||
"U2NET_PATH",
|
||||
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
||||
)
|
||||
if (
|
||||
not os.path.exists(path)
|
||||
or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a"
|
||||
):
|
||||
download_file_from_google_drive(
|
||||
"1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
|
||||
"u2net_human_seg.pth",
|
||||
path,
|
||||
)
|
||||
else:
|
||||
print("Choose between u2net or u2netp", file=sys.stderr)
|
||||
print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr)
|
||||
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
|
Loading…
x
Reference in New Issue
Block a user