mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-17 01:55:59 +08:00
add a cache for models
This commit is contained in:
parent
20fa62da03
commit
8d5e4aba5d
2
setup.py
2
setup.py
@ -11,7 +11,7 @@ with open("requirements.txt") as f:
|
||||
|
||||
setup(
|
||||
name="rembg",
|
||||
version="1.0.13",
|
||||
version="1.0.14",
|
||||
description="Remove image background",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
|
@ -1,5 +1,6 @@
|
||||
import io
|
||||
|
||||
import functools
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
||||
@ -9,9 +10,6 @@ from scipy.ndimage.morphology import binary_erosion
|
||||
|
||||
from .u2net import detect
|
||||
|
||||
model_u2net = detect.load_model(model_name="u2net")
|
||||
model_u2netp = detect.load_model(model_name="u2netp")
|
||||
|
||||
|
||||
def alpha_matting_cutout(
|
||||
img, mask, foreground_threshold, background_threshold, erode_structure_size,
|
||||
@ -66,6 +64,14 @@ def naive_cutout(img, mask):
|
||||
return cutout
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_model(model_name):
|
||||
if model_name == "u2netp":
|
||||
return detect.load_model(model_name="u2netp")
|
||||
else:
|
||||
return detect.load_model(model_name="u2net")
|
||||
|
||||
|
||||
def remove(
|
||||
data,
|
||||
model_name="u2net",
|
||||
@ -74,11 +80,7 @@ def remove(
|
||||
alpha_matting_background_threshold=10,
|
||||
alpha_matting_erode_structure_size=10,
|
||||
):
|
||||
model = model_u2net
|
||||
|
||||
if model == "u2netp":
|
||||
model = model_u2netp
|
||||
|
||||
model = get_model(model_name)
|
||||
img = Image.open(io.BytesIO(data)).convert("RGB")
|
||||
mask = detect.predict(model, np.array(img)).convert("L")
|
||||
|
||||
|
@ -17,16 +17,31 @@ from tqdm import tqdm
|
||||
from . import data_loader, u2net
|
||||
|
||||
|
||||
def download(url, fname, path):
|
||||
if os.path.exists(path):
|
||||
def download_file_from_google_drive(id, fname, destination):
|
||||
if os.path.exists(destination):
|
||||
return
|
||||
|
||||
resp = requests.get(url, stream=True)
|
||||
total = int(resp.headers.get("content-length", 0))
|
||||
with open(path, "wb") as file, tqdm(
|
||||
URL = "https://docs.google.com/uc?export=download"
|
||||
|
||||
session = requests.Session()
|
||||
response = session.get(URL, params={"id": id}, stream=True)
|
||||
|
||||
token = None
|
||||
for key, value in response.cookies.items():
|
||||
if key.startswith("download_warning"):
|
||||
token = value
|
||||
break
|
||||
|
||||
if token:
|
||||
params = {"id": id, "confirm": token}
|
||||
response = session.get(URL, params=params, stream=True)
|
||||
|
||||
total = int(response.headers.get("content-length", 0))
|
||||
|
||||
with open(destination, "wb") as file, tqdm(
|
||||
desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024,
|
||||
) as bar:
|
||||
for data in resp.iter_content(chunk_size=1024):
|
||||
for data in response.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
|
||||
@ -37,18 +52,14 @@ def load_model(model_name: str = "u2net"):
|
||||
if model_name == "u2netp":
|
||||
net = u2net.U2NETP(3, 1)
|
||||
path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
||||
download(
|
||||
"https://www.dropbox.com/s/usb1fyiuh8as5gi/u2netp.pth?dl=1",
|
||||
"u2netp.pth",
|
||||
path,
|
||||
download_file_from_google_drive(
|
||||
"1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", "u2netp.pth", path,
|
||||
)
|
||||
elif model_name == "u2net":
|
||||
net = u2net.U2NET(3, 1)
|
||||
path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
||||
download(
|
||||
"https://www.dropbox.com/s/kdu5mhose1clds0/u2net.pth?dl=1",
|
||||
"u2net.pth",
|
||||
path,
|
||||
download_file_from_google_drive(
|
||||
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", "u2net.pth", path,
|
||||
)
|
||||
else:
|
||||
print("Choose between u2net or u2netp", file=sys.stderr)
|
||||
|
Loading…
x
Reference in New Issue
Block a user