From 8d5e4aba5db7df90fe87f44fee605229c122984d Mon Sep 17 00:00:00 2001 From: Daniel Gatis Date: Sat, 7 Nov 2020 16:49:54 -0300 Subject: [PATCH] add a cache for models --- setup.py | 2 +- src/rembg/bg.py | 18 ++++++++++-------- src/rembg/u2net/detect.py | 39 +++++++++++++++++++++++++-------------- 3 files changed, 36 insertions(+), 23 deletions(-) diff --git a/setup.py b/setup.py index a0a5034..5b1e3e4 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ with open("requirements.txt") as f: setup( name="rembg", - version="1.0.13", + version="1.0.14", description="Remove image background", long_description=long_description, long_description_content_type="text/markdown", diff --git a/src/rembg/bg.py b/src/rembg/bg.py index f364eeb..1b3c228 100644 --- a/src/rembg/bg.py +++ b/src/rembg/bg.py @@ -1,5 +1,6 @@ import io +import functools import numpy as np from PIL import Image from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf @@ -9,9 +10,6 @@ from scipy.ndimage.morphology import binary_erosion from .u2net import detect -model_u2net = detect.load_model(model_name="u2net") -model_u2netp = detect.load_model(model_name="u2netp") - def alpha_matting_cutout( img, mask, foreground_threshold, background_threshold, erode_structure_size, @@ -66,6 +64,14 @@ def naive_cutout(img, mask): return cutout +@functools.lru_cache +def get_model(model_name): + if model_name == "u2netp": + return detect.load_model(model_name="u2netp") + else: + return detect.load_model(model_name="u2net") + + def remove( data, model_name="u2net", @@ -74,11 +80,7 @@ def remove( alpha_matting_background_threshold=10, alpha_matting_erode_structure_size=10, ): - model = model_u2net - - if model == "u2netp": - model = model_u2netp - + model = get_model(model_name) img = Image.open(io.BytesIO(data)).convert("RGB") mask = detect.predict(model, np.array(img)).convert("L") diff --git a/src/rembg/u2net/detect.py b/src/rembg/u2net/detect.py index 5afa475..81dd00c 100644 --- a/src/rembg/u2net/detect.py +++ b/src/rembg/u2net/detect.py @@ -17,16 +17,31 @@ from tqdm import tqdm from . import data_loader, u2net -def download(url, fname, path): - if os.path.exists(path): +def download_file_from_google_drive(id, fname, destination): + if os.path.exists(destination): return - resp = requests.get(url, stream=True) - total = int(resp.headers.get("content-length", 0)) - with open(path, "wb") as file, tqdm( + URL = "https://docs.google.com/uc?export=download" + + session = requests.Session() + response = session.get(URL, params={"id": id}, stream=True) + + token = None + for key, value in response.cookies.items(): + if key.startswith("download_warning"): + token = value + break + + if token: + params = {"id": id, "confirm": token} + response = session.get(URL, params=params, stream=True) + + total = int(response.headers.get("content-length", 0)) + + with open(destination, "wb") as file, tqdm( desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024, ) as bar: - for data in resp.iter_content(chunk_size=1024): + for data in response.iter_content(chunk_size=1024): size = file.write(data) bar.update(size) @@ -37,18 +52,14 @@ def load_model(model_name: str = "u2net"): if model_name == "u2netp": net = u2net.U2NETP(3, 1) path = os.path.expanduser(os.path.join("~", ".u2net", model_name)) - download( - "https://www.dropbox.com/s/usb1fyiuh8as5gi/u2netp.pth?dl=1", - "u2netp.pth", - path, + download_file_from_google_drive( + "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", "u2netp.pth", path, ) elif model_name == "u2net": net = u2net.U2NET(3, 1) path = os.path.expanduser(os.path.join("~", ".u2net", model_name)) - download( - "https://www.dropbox.com/s/kdu5mhose1clds0/u2net.pth?dl=1", - "u2net.pth", - path, + download_file_from_google_drive( + "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", "u2net.pth", path, ) else: print("Choose between u2net or u2netp", file=sys.stderr)