diff --git a/requirements.txt b/requirements.txt index 7dafb25..c6ee0c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ requests==2.24.0 scipy==1.5.4 pymatting==1.1.1 filetype==1.0.7 +hsh==1.1.0 diff --git a/setup.py b/setup.py index 97a7051..ce8d0fb 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ with open("requirements.txt") as f: setup( name="rembg", - version="1.0.16", + version="1.0.17", description="Remove image background", long_description=long_description, long_description_content_type="text/markdown", diff --git a/src/rembg/bg.py b/src/rembg/bg.py index 899a63b..fb9d3ab 100644 --- a/src/rembg/bg.py +++ b/src/rembg/bg.py @@ -1,6 +1,6 @@ +import functools import io -import functools import numpy as np from PIL import Image from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf @@ -12,7 +12,11 @@ from .u2net import detect def alpha_matting_cutout( - img, mask, foreground_threshold, background_threshold, erode_structure_size, + img, + mask, + foreground_threshold, + background_threshold, + erode_structure_size, ): base_size = (1000, 1000) size = img.size diff --git a/src/rembg/cmd/cli.py b/src/rembg/cmd/cli.py index f0b89c4..45a4d6d 100644 --- a/src/rembg/cmd/cli.py +++ b/src/rembg/cmd/cli.py @@ -1,8 +1,9 @@ import argparse import glob import os -import filetype from distutils.util import strtobool + +import filetype from tqdm import tqdm from ..bg import remove @@ -55,7 +56,10 @@ def main(): ) ap.add_argument( - "-p", "--path", nargs="+", help="Path of a file or a folder of files.", + "-p", + "--path", + nargs="+", + help="Path of a file or a folder of files.", ) ap.add_argument( @@ -95,7 +99,7 @@ def main(): if fi_type is None: continue - elif fi_type.mime.find('image') < 0: + elif fi_type.mime.find("image") < 0: continue with open(fi, "rb") as input: diff --git a/src/rembg/cmd/server.py b/src/rembg/cmd/server.py index eff4f9e..6e8ff24 100644 --- a/src/rembg/cmd/server.py +++ b/src/rembg/cmd/server.py @@ -36,7 +36,10 @@ def index(): return {"error": "invalid query param 'model'"}, 400 try: - return send_file(BytesIO(remove(file_content, model)), mimetype="image/png",) + return send_file( + BytesIO(remove(file_content, model)), + mimetype="image/png", + ) except Exception as e: app.logger.exception(e, exc_info=True) return {"error": "oops, something went wrong!"}, 500 @@ -46,11 +49,19 @@ def main(): ap = argparse.ArgumentParser() ap.add_argument( - "-a", "--addr", default="0.0.0.0", type=str, help="The IP address to bind to.", + "-a", + "--addr", + default="0.0.0.0", + type=str, + help="The IP address to bind to.", ) ap.add_argument( - "-p", "--port", default=5000, type=int, help="The port to bind to.", + "-p", + "--port", + default=5000, + type=int, + help="The port to bind to.", ) args = ap.parse_args() diff --git a/src/rembg/u2net/detect.py b/src/rembg/u2net/detect.py index 81dd00c..4e2d88b 100644 --- a/src/rembg/u2net/detect.py +++ b/src/rembg/u2net/detect.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import torchvision +from hsh.library.hash import Hasher from PIL import Image from skimage import transform from torchvision import transforms @@ -18,8 +19,8 @@ from . import data_loader, u2net def download_file_from_google_drive(id, fname, destination): - if os.path.exists(destination): - return + head, tail = os.path.split(destination) + os.makedirs(head, exist_ok=True) URL = "https://docs.google.com/uc?export=download" @@ -39,7 +40,11 @@ def download_file_from_google_drive(id, fname, destination): total = int(response.headers.get("content-length", 0)) with open(destination, "wb") as file, tqdm( - desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024, + desc=f"Downloading {tail} to {head}", + total=total, + unit="iB", + unit_scale=True, + unit_divisor=1024, ) as bar: for data in response.iter_content(chunk_size=1024): size = file.write(data) @@ -47,20 +52,39 @@ def download_file_from_google_drive(id, fname, destination): def load_model(model_name: str = "u2net"): - os.makedirs(os.path.expanduser(os.path.join("~", ".u2net")), exist_ok=True) + hasher = Hasher() if model_name == "u2netp": net = u2net.U2NETP(3, 1) - path = os.path.expanduser(os.path.join("~", ".u2net", model_name)) - download_file_from_google_drive( - "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", "u2netp.pth", path, + path = os.environ.get( + "U2NETP_PATH", + os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")), ) + if ( + not os.path.exists(path) + or hasher.md5(path) != "e4f636406ca4e2af789941e7f139ee2e" + ): + download_file_from_google_drive( + "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", + "u2netp.pth", + path, + ) + elif model_name == "u2net": net = u2net.U2NET(3, 1) - path = os.path.expanduser(os.path.join("~", ".u2net", model_name)) - download_file_from_google_drive( - "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", "u2net.pth", path, + path = os.environ.get( + "U2NET_PATH", + os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")), ) + if ( + not os.path.exists(path) + or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55" + ): + download_file_from_google_drive( + "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", + "u2net.pth", + path, + ) else: print("Choose between u2net or u2netp", file=sys.stderr) @@ -69,7 +93,12 @@ def load_model(model_name: str = "u2net"): net.load_state_dict(torch.load(path)) net.to(torch.device("cuda")) else: - net.load_state_dict(torch.load(path, map_location="cpu",)) + net.load_state_dict( + torch.load( + path, + map_location="cpu", + ) + ) except FileNotFoundError: raise FileNotFoundError( errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"