mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-18 01:35:57 +08:00
add a cache for models
This commit is contained in:
parent
20fa62da03
commit
8d5e4aba5d
2
setup.py
2
setup.py
@ -11,7 +11,7 @@ with open("requirements.txt") as f:
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="rembg",
|
name="rembg",
|
||||||
version="1.0.13",
|
version="1.0.14",
|
||||||
description="Remove image background",
|
description="Remove image background",
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import io
|
import io
|
||||||
|
|
||||||
|
import functools
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
||||||
@ -9,9 +10,6 @@ from scipy.ndimage.morphology import binary_erosion
|
|||||||
|
|
||||||
from .u2net import detect
|
from .u2net import detect
|
||||||
|
|
||||||
model_u2net = detect.load_model(model_name="u2net")
|
|
||||||
model_u2netp = detect.load_model(model_name="u2netp")
|
|
||||||
|
|
||||||
|
|
||||||
def alpha_matting_cutout(
|
def alpha_matting_cutout(
|
||||||
img, mask, foreground_threshold, background_threshold, erode_structure_size,
|
img, mask, foreground_threshold, background_threshold, erode_structure_size,
|
||||||
@ -66,6 +64,14 @@ def naive_cutout(img, mask):
|
|||||||
return cutout
|
return cutout
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache
|
||||||
|
def get_model(model_name):
|
||||||
|
if model_name == "u2netp":
|
||||||
|
return detect.load_model(model_name="u2netp")
|
||||||
|
else:
|
||||||
|
return detect.load_model(model_name="u2net")
|
||||||
|
|
||||||
|
|
||||||
def remove(
|
def remove(
|
||||||
data,
|
data,
|
||||||
model_name="u2net",
|
model_name="u2net",
|
||||||
@ -74,11 +80,7 @@ def remove(
|
|||||||
alpha_matting_background_threshold=10,
|
alpha_matting_background_threshold=10,
|
||||||
alpha_matting_erode_structure_size=10,
|
alpha_matting_erode_structure_size=10,
|
||||||
):
|
):
|
||||||
model = model_u2net
|
model = get_model(model_name)
|
||||||
|
|
||||||
if model == "u2netp":
|
|
||||||
model = model_u2netp
|
|
||||||
|
|
||||||
img = Image.open(io.BytesIO(data)).convert("RGB")
|
img = Image.open(io.BytesIO(data)).convert("RGB")
|
||||||
mask = detect.predict(model, np.array(img)).convert("L")
|
mask = detect.predict(model, np.array(img)).convert("L")
|
||||||
|
|
||||||
|
@ -17,16 +17,31 @@ from tqdm import tqdm
|
|||||||
from . import data_loader, u2net
|
from . import data_loader, u2net
|
||||||
|
|
||||||
|
|
||||||
def download(url, fname, path):
|
def download_file_from_google_drive(id, fname, destination):
|
||||||
if os.path.exists(path):
|
if os.path.exists(destination):
|
||||||
return
|
return
|
||||||
|
|
||||||
resp = requests.get(url, stream=True)
|
URL = "https://docs.google.com/uc?export=download"
|
||||||
total = int(resp.headers.get("content-length", 0))
|
|
||||||
with open(path, "wb") as file, tqdm(
|
session = requests.Session()
|
||||||
|
response = session.get(URL, params={"id": id}, stream=True)
|
||||||
|
|
||||||
|
token = None
|
||||||
|
for key, value in response.cookies.items():
|
||||||
|
if key.startswith("download_warning"):
|
||||||
|
token = value
|
||||||
|
break
|
||||||
|
|
||||||
|
if token:
|
||||||
|
params = {"id": id, "confirm": token}
|
||||||
|
response = session.get(URL, params=params, stream=True)
|
||||||
|
|
||||||
|
total = int(response.headers.get("content-length", 0))
|
||||||
|
|
||||||
|
with open(destination, "wb") as file, tqdm(
|
||||||
desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024,
|
desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024,
|
||||||
) as bar:
|
) as bar:
|
||||||
for data in resp.iter_content(chunk_size=1024):
|
for data in response.iter_content(chunk_size=1024):
|
||||||
size = file.write(data)
|
size = file.write(data)
|
||||||
bar.update(size)
|
bar.update(size)
|
||||||
|
|
||||||
@ -37,18 +52,14 @@ def load_model(model_name: str = "u2net"):
|
|||||||
if model_name == "u2netp":
|
if model_name == "u2netp":
|
||||||
net = u2net.U2NETP(3, 1)
|
net = u2net.U2NETP(3, 1)
|
||||||
path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
||||||
download(
|
download_file_from_google_drive(
|
||||||
"https://www.dropbox.com/s/usb1fyiuh8as5gi/u2netp.pth?dl=1",
|
"1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", "u2netp.pth", path,
|
||||||
"u2netp.pth",
|
|
||||||
path,
|
|
||||||
)
|
)
|
||||||
elif model_name == "u2net":
|
elif model_name == "u2net":
|
||||||
net = u2net.U2NET(3, 1)
|
net = u2net.U2NET(3, 1)
|
||||||
path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
||||||
download(
|
download_file_from_google_drive(
|
||||||
"https://www.dropbox.com/s/kdu5mhose1clds0/u2net.pth?dl=1",
|
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", "u2net.pth", path,
|
||||||
"u2net.pth",
|
|
||||||
path,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print("Choose between u2net or u2netp", file=sys.stderr)
|
print("Choose between u2net or u2netp", file=sys.stderr)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user