add model var envs

This commit is contained in:
Daniel Gatis 2020-12-12 18:56:09 -03:00
parent 0a96555440
commit 58fd707872
6 changed files with 69 additions and 20 deletions

View File

@ -10,3 +10,4 @@ requests==2.24.0
scipy==1.5.4
pymatting==1.1.1
filetype==1.0.7
hsh==1.1.0

View File

@ -11,7 +11,7 @@ with open("requirements.txt") as f:
setup(
name="rembg",
version="1.0.16",
version="1.0.17",
description="Remove image background",
long_description=long_description,
long_description_content_type="text/markdown",

View File

@ -1,6 +1,6 @@
import functools
import io
import functools
import numpy as np
from PIL import Image
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
@ -12,7 +12,11 @@ from .u2net import detect
def alpha_matting_cutout(
img, mask, foreground_threshold, background_threshold, erode_structure_size,
img,
mask,
foreground_threshold,
background_threshold,
erode_structure_size,
):
base_size = (1000, 1000)
size = img.size

View File

@ -1,8 +1,9 @@
import argparse
import glob
import os
import filetype
from distutils.util import strtobool
import filetype
from tqdm import tqdm
from ..bg import remove
@ -55,7 +56,10 @@ def main():
)
ap.add_argument(
"-p", "--path", nargs="+", help="Path of a file or a folder of files.",
"-p",
"--path",
nargs="+",
help="Path of a file or a folder of files.",
)
ap.add_argument(
@ -95,7 +99,7 @@ def main():
if fi_type is None:
continue
elif fi_type.mime.find('image') < 0:
elif fi_type.mime.find("image") < 0:
continue
with open(fi, "rb") as input:

View File

@ -36,7 +36,10 @@ def index():
return {"error": "invalid query param 'model'"}, 400
try:
return send_file(BytesIO(remove(file_content, model)), mimetype="image/png",)
return send_file(
BytesIO(remove(file_content, model)),
mimetype="image/png",
)
except Exception as e:
app.logger.exception(e, exc_info=True)
return {"error": "oops, something went wrong!"}, 500
@ -46,11 +49,19 @@ def main():
ap = argparse.ArgumentParser()
ap.add_argument(
"-a", "--addr", default="0.0.0.0", type=str, help="The IP address to bind to.",
"-a",
"--addr",
default="0.0.0.0",
type=str,
help="The IP address to bind to.",
)
ap.add_argument(
"-p", "--port", default=5000, type=int, help="The port to bind to.",
"-p",
"--port",
default=5000,
type=int,
help="The port to bind to.",
)
args = ap.parse_args()

View File

@ -9,6 +9,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from hsh.library.hash import Hasher
from PIL import Image
from skimage import transform
from torchvision import transforms
@ -18,8 +19,8 @@ from . import data_loader, u2net
def download_file_from_google_drive(id, fname, destination):
if os.path.exists(destination):
return
head, tail = os.path.split(destination)
os.makedirs(head, exist_ok=True)
URL = "https://docs.google.com/uc?export=download"
@ -39,7 +40,11 @@ def download_file_from_google_drive(id, fname, destination):
total = int(response.headers.get("content-length", 0))
with open(destination, "wb") as file, tqdm(
desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024,
desc=f"Downloading {tail} to {head}",
total=total,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar:
for data in response.iter_content(chunk_size=1024):
size = file.write(data)
@ -47,20 +52,39 @@ def download_file_from_google_drive(id, fname, destination):
def load_model(model_name: str = "u2net"):
os.makedirs(os.path.expanduser(os.path.join("~", ".u2net")), exist_ok=True)
hasher = Hasher()
if model_name == "u2netp":
net = u2net.U2NETP(3, 1)
path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
download_file_from_google_drive(
"1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", "u2netp.pth", path,
path = os.environ.get(
"U2NETP_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
)
if (
not os.path.exists(path)
or hasher.md5(path) != "e4f636406ca4e2af789941e7f139ee2e"
):
download_file_from_google_drive(
"1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
"u2netp.pth",
path,
)
elif model_name == "u2net":
net = u2net.U2NET(3, 1)
path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
download_file_from_google_drive(
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", "u2net.pth", path,
path = os.environ.get(
"U2NET_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
)
if (
not os.path.exists(path)
or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55"
):
download_file_from_google_drive(
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
"u2net.pth",
path,
)
else:
print("Choose between u2net or u2netp", file=sys.stderr)
@ -69,7 +93,12 @@ def load_model(model_name: str = "u2net"):
net.load_state_dict(torch.load(path))
net.to(torch.device("cuda"))
else:
net.load_state_dict(torch.load(path, map_location="cpu",))
net.load_state_dict(
torch.load(
path,
map_location="cpu",
)
)
except FileNotFoundError:
raise FileNotFoundError(
errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"