mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-15 05:06:10 +08:00
add u2net_human_seg model
This commit is contained in:
parent
63f52a3a6a
commit
fd3b5a5e8f
@ -72,6 +72,8 @@ def naive_cutout(img, mask):
|
|||||||
def get_model(model_name):
|
def get_model(model_name):
|
||||||
if model_name == "u2netp":
|
if model_name == "u2netp":
|
||||||
return detect.load_model(model_name="u2netp")
|
return detect.load_model(model_name="u2netp")
|
||||||
|
if model_name == "u2net_human_seg":
|
||||||
|
return detect.load_model(model_name="u2net_human_seg")
|
||||||
else:
|
else:
|
||||||
return detect.load_model(model_name="u2net")
|
return detect.load_model(model_name="u2net")
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ def main():
|
|||||||
"--model",
|
"--model",
|
||||||
default="u2net",
|
default="u2net",
|
||||||
type=str,
|
type=str,
|
||||||
choices=("u2net", "u2netp"),
|
choices=("u2net", "u2net_human_seg", "u2netp"),
|
||||||
help="The model name.",
|
help="The model name.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ def index():
|
|||||||
az = request.values.get("az", type=int, default=1000)
|
az = request.values.get("az", type=int, default=1000)
|
||||||
|
|
||||||
model = request.args.get("model", type=str, default="u2net")
|
model = request.args.get("model", type=str, default="u2net")
|
||||||
if model not in ("u2net", "u2netp"):
|
if model not in ("u2net", "u2net_human_seg", "u2netp"):
|
||||||
return {"error": "invalid query param 'model'"}, 400
|
return {"error": "invalid query param 'model'"}, 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -85,8 +85,24 @@ def load_model(model_name: str = "u2net"):
|
|||||||
"u2net.pth",
|
"u2net.pth",
|
||||||
path,
|
path,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif model_name == "u2net_human_seg":
|
||||||
|
net = u2net.U2NET(3, 1)
|
||||||
|
path = os.environ.get(
|
||||||
|
"U2NET_PATH",
|
||||||
|
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
not os.path.exists(path)
|
||||||
|
or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a"
|
||||||
|
):
|
||||||
|
download_file_from_google_drive(
|
||||||
|
"1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
|
||||||
|
"u2net_human_seg.pth",
|
||||||
|
path,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print("Choose between u2net or u2netp", file=sys.stderr)
|
print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user