add model var envs

This commit is contained in:
Daniel Gatis 2020-12-12 18:56:09 -03:00
parent 0a96555440
commit 58fd707872
6 changed files with 69 additions and 20 deletions

View File

@ -10,3 +10,4 @@ requests==2.24.0
scipy==1.5.4 scipy==1.5.4
pymatting==1.1.1 pymatting==1.1.1
filetype==1.0.7 filetype==1.0.7
hsh==1.1.0

View File

@ -11,7 +11,7 @@ with open("requirements.txt") as f:
setup( setup(
name="rembg", name="rembg",
version="1.0.16", version="1.0.17",
description="Remove image background", description="Remove image background",
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",

View File

@ -1,6 +1,6 @@
import functools
import io import io
import functools
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
@ -12,7 +12,11 @@ from .u2net import detect
def alpha_matting_cutout( def alpha_matting_cutout(
img, mask, foreground_threshold, background_threshold, erode_structure_size, img,
mask,
foreground_threshold,
background_threshold,
erode_structure_size,
): ):
base_size = (1000, 1000) base_size = (1000, 1000)
size = img.size size = img.size

View File

@ -1,8 +1,9 @@
import argparse import argparse
import glob import glob
import os import os
import filetype
from distutils.util import strtobool from distutils.util import strtobool
import filetype
from tqdm import tqdm from tqdm import tqdm
from ..bg import remove from ..bg import remove
@ -55,7 +56,10 @@ def main():
) )
ap.add_argument( ap.add_argument(
"-p", "--path", nargs="+", help="Path of a file or a folder of files.", "-p",
"--path",
nargs="+",
help="Path of a file or a folder of files.",
) )
ap.add_argument( ap.add_argument(
@ -95,7 +99,7 @@ def main():
if fi_type is None: if fi_type is None:
continue continue
elif fi_type.mime.find('image') < 0: elif fi_type.mime.find("image") < 0:
continue continue
with open(fi, "rb") as input: with open(fi, "rb") as input:

View File

@ -36,7 +36,10 @@ def index():
return {"error": "invalid query param 'model'"}, 400 return {"error": "invalid query param 'model'"}, 400
try: try:
return send_file(BytesIO(remove(file_content, model)), mimetype="image/png",) return send_file(
BytesIO(remove(file_content, model)),
mimetype="image/png",
)
except Exception as e: except Exception as e:
app.logger.exception(e, exc_info=True) app.logger.exception(e, exc_info=True)
return {"error": "oops, something went wrong!"}, 500 return {"error": "oops, something went wrong!"}, 500
@ -46,11 +49,19 @@ def main():
ap = argparse.ArgumentParser() ap = argparse.ArgumentParser()
ap.add_argument( ap.add_argument(
"-a", "--addr", default="0.0.0.0", type=str, help="The IP address to bind to.", "-a",
"--addr",
default="0.0.0.0",
type=str,
help="The IP address to bind to.",
) )
ap.add_argument( ap.add_argument(
"-p", "--port", default=5000, type=int, help="The port to bind to.", "-p",
"--port",
default=5000,
type=int,
help="The port to bind to.",
) )
args = ap.parse_args() args = ap.parse_args()

View File

@ -9,6 +9,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision import torchvision
from hsh.library.hash import Hasher
from PIL import Image from PIL import Image
from skimage import transform from skimage import transform
from torchvision import transforms from torchvision import transforms
@ -18,8 +19,8 @@ from . import data_loader, u2net
def download_file_from_google_drive(id, fname, destination): def download_file_from_google_drive(id, fname, destination):
if os.path.exists(destination): head, tail = os.path.split(destination)
return os.makedirs(head, exist_ok=True)
URL = "https://docs.google.com/uc?export=download" URL = "https://docs.google.com/uc?export=download"
@ -39,7 +40,11 @@ def download_file_from_google_drive(id, fname, destination):
total = int(response.headers.get("content-length", 0)) total = int(response.headers.get("content-length", 0))
with open(destination, "wb") as file, tqdm( with open(destination, "wb") as file, tqdm(
desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024, desc=f"Downloading {tail} to {head}",
total=total,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar: ) as bar:
for data in response.iter_content(chunk_size=1024): for data in response.iter_content(chunk_size=1024):
size = file.write(data) size = file.write(data)
@ -47,20 +52,39 @@ def download_file_from_google_drive(id, fname, destination):
def load_model(model_name: str = "u2net"): def load_model(model_name: str = "u2net"):
os.makedirs(os.path.expanduser(os.path.join("~", ".u2net")), exist_ok=True) hasher = Hasher()
if model_name == "u2netp": if model_name == "u2netp":
net = u2net.U2NETP(3, 1) net = u2net.U2NETP(3, 1)
path = os.path.expanduser(os.path.join("~", ".u2net", model_name)) path = os.environ.get(
download_file_from_google_drive( "U2NETP_PATH",
"1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", "u2netp.pth", path, os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
) )
if (
not os.path.exists(path)
or hasher.md5(path) != "e4f636406ca4e2af789941e7f139ee2e"
):
download_file_from_google_drive(
"1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
"u2netp.pth",
path,
)
elif model_name == "u2net": elif model_name == "u2net":
net = u2net.U2NET(3, 1) net = u2net.U2NET(3, 1)
path = os.path.expanduser(os.path.join("~", ".u2net", model_name)) path = os.environ.get(
download_file_from_google_drive( "U2NET_PATH",
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", "u2net.pth", path, os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
) )
if (
not os.path.exists(path)
or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55"
):
download_file_from_google_drive(
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
"u2net.pth",
path,
)
else: else:
print("Choose between u2net or u2netp", file=sys.stderr) print("Choose between u2net or u2netp", file=sys.stderr)
@ -69,7 +93,12 @@ def load_model(model_name: str = "u2net"):
net.load_state_dict(torch.load(path)) net.load_state_dict(torch.load(path))
net.to(torch.device("cuda")) net.to(torch.device("cuda"))
else: else:
net.load_state_dict(torch.load(path, map_location="cpu",)) net.load_state_dict(
torch.load(
path,
map_location="cpu",
)
)
except FileNotFoundError: except FileNotFoundError:
raise FileNotFoundError( raise FileNotFoundError(
errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth" errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"