add a cache for models

This commit is contained in:
Daniel Gatis 2020-11-07 16:49:54 -03:00
parent 20fa62da03
commit 8d5e4aba5d
3 changed files with 36 additions and 23 deletions

View File

@ -11,7 +11,7 @@ with open("requirements.txt") as f:
setup( setup(
name="rembg", name="rembg",
version="1.0.13", version="1.0.14",
description="Remove image background", description="Remove image background",
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",

View File

@ -1,5 +1,6 @@
import io import io
import functools
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
@ -9,9 +10,6 @@ from scipy.ndimage.morphology import binary_erosion
from .u2net import detect from .u2net import detect
model_u2net = detect.load_model(model_name="u2net")
model_u2netp = detect.load_model(model_name="u2netp")
def alpha_matting_cutout( def alpha_matting_cutout(
img, mask, foreground_threshold, background_threshold, erode_structure_size, img, mask, foreground_threshold, background_threshold, erode_structure_size,
@ -66,6 +64,14 @@ def naive_cutout(img, mask):
return cutout return cutout
@functools.lru_cache
def get_model(model_name):
if model_name == "u2netp":
return detect.load_model(model_name="u2netp")
else:
return detect.load_model(model_name="u2net")
def remove( def remove(
data, data,
model_name="u2net", model_name="u2net",
@ -74,11 +80,7 @@ def remove(
alpha_matting_background_threshold=10, alpha_matting_background_threshold=10,
alpha_matting_erode_structure_size=10, alpha_matting_erode_structure_size=10,
): ):
model = model_u2net model = get_model(model_name)
if model == "u2netp":
model = model_u2netp
img = Image.open(io.BytesIO(data)).convert("RGB") img = Image.open(io.BytesIO(data)).convert("RGB")
mask = detect.predict(model, np.array(img)).convert("L") mask = detect.predict(model, np.array(img)).convert("L")

View File

@ -17,16 +17,31 @@ from tqdm import tqdm
from . import data_loader, u2net from . import data_loader, u2net
def download(url, fname, path): def download_file_from_google_drive(id, fname, destination):
if os.path.exists(path): if os.path.exists(destination):
return return
resp = requests.get(url, stream=True) URL = "https://docs.google.com/uc?export=download"
total = int(resp.headers.get("content-length", 0))
with open(path, "wb") as file, tqdm( session = requests.Session()
response = session.get(URL, params={"id": id}, stream=True)
token = None
for key, value in response.cookies.items():
if key.startswith("download_warning"):
token = value
break
if token:
params = {"id": id, "confirm": token}
response = session.get(URL, params=params, stream=True)
total = int(response.headers.get("content-length", 0))
with open(destination, "wb") as file, tqdm(
desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024, desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024,
) as bar: ) as bar:
for data in resp.iter_content(chunk_size=1024): for data in response.iter_content(chunk_size=1024):
size = file.write(data) size = file.write(data)
bar.update(size) bar.update(size)
@ -37,18 +52,14 @@ def load_model(model_name: str = "u2net"):
if model_name == "u2netp": if model_name == "u2netp":
net = u2net.U2NETP(3, 1) net = u2net.U2NETP(3, 1)
path = os.path.expanduser(os.path.join("~", ".u2net", model_name)) path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
download( download_file_from_google_drive(
"https://www.dropbox.com/s/usb1fyiuh8as5gi/u2netp.pth?dl=1", "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", "u2netp.pth", path,
"u2netp.pth",
path,
) )
elif model_name == "u2net": elif model_name == "u2net":
net = u2net.U2NET(3, 1) net = u2net.U2NET(3, 1)
path = os.path.expanduser(os.path.join("~", ".u2net", model_name)) path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
download( download_file_from_google_drive(
"https://www.dropbox.com/s/kdu5mhose1clds0/u2net.pth?dl=1", "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", "u2net.pth", path,
"u2net.pth",
path,
) )
else: else:
print("Choose between u2net or u2netp", file=sys.stderr) print("Choose between u2net or u2netp", file=sys.stderr)