mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-20 14:59:06 +08:00
add model var envs
This commit is contained in:
parent
0a96555440
commit
58fd707872
@ -10,3 +10,4 @@ requests==2.24.0
|
||||
scipy==1.5.4
|
||||
pymatting==1.1.1
|
||||
filetype==1.0.7
|
||||
hsh==1.1.0
|
||||
|
2
setup.py
2
setup.py
@ -11,7 +11,7 @@ with open("requirements.txt") as f:
|
||||
|
||||
setup(
|
||||
name="rembg",
|
||||
version="1.0.16",
|
||||
version="1.0.17",
|
||||
description="Remove image background",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
|
@ -1,6 +1,6 @@
|
||||
import functools
|
||||
import io
|
||||
|
||||
import functools
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
||||
@ -12,7 +12,11 @@ from .u2net import detect
|
||||
|
||||
|
||||
def alpha_matting_cutout(
|
||||
img, mask, foreground_threshold, background_threshold, erode_structure_size,
|
||||
img,
|
||||
mask,
|
||||
foreground_threshold,
|
||||
background_threshold,
|
||||
erode_structure_size,
|
||||
):
|
||||
base_size = (1000, 1000)
|
||||
size = img.size
|
||||
|
@ -1,8 +1,9 @@
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import filetype
|
||||
from distutils.util import strtobool
|
||||
|
||||
import filetype
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..bg import remove
|
||||
@ -55,7 +56,10 @@ def main():
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-p", "--path", nargs="+", help="Path of a file or a folder of files.",
|
||||
"-p",
|
||||
"--path",
|
||||
nargs="+",
|
||||
help="Path of a file or a folder of files.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
@ -95,7 +99,7 @@ def main():
|
||||
|
||||
if fi_type is None:
|
||||
continue
|
||||
elif fi_type.mime.find('image') < 0:
|
||||
elif fi_type.mime.find("image") < 0:
|
||||
continue
|
||||
|
||||
with open(fi, "rb") as input:
|
||||
|
@ -36,7 +36,10 @@ def index():
|
||||
return {"error": "invalid query param 'model'"}, 400
|
||||
|
||||
try:
|
||||
return send_file(BytesIO(remove(file_content, model)), mimetype="image/png",)
|
||||
return send_file(
|
||||
BytesIO(remove(file_content, model)),
|
||||
mimetype="image/png",
|
||||
)
|
||||
except Exception as e:
|
||||
app.logger.exception(e, exc_info=True)
|
||||
return {"error": "oops, something went wrong!"}, 500
|
||||
@ -46,11 +49,19 @@ def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
|
||||
ap.add_argument(
|
||||
"-a", "--addr", default="0.0.0.0", type=str, help="The IP address to bind to.",
|
||||
"-a",
|
||||
"--addr",
|
||||
default="0.0.0.0",
|
||||
type=str,
|
||||
help="The IP address to bind to.",
|
||||
)
|
||||
|
||||
ap.add_argument(
|
||||
"-p", "--port", default=5000, type=int, help="The port to bind to.",
|
||||
"-p",
|
||||
"--port",
|
||||
default=5000,
|
||||
type=int,
|
||||
help="The port to bind to.",
|
||||
)
|
||||
|
||||
args = ap.parse_args()
|
||||
|
@ -9,6 +9,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
from hsh.library.hash import Hasher
|
||||
from PIL import Image
|
||||
from skimage import transform
|
||||
from torchvision import transforms
|
||||
@ -18,8 +19,8 @@ from . import data_loader, u2net
|
||||
|
||||
|
||||
def download_file_from_google_drive(id, fname, destination):
|
||||
if os.path.exists(destination):
|
||||
return
|
||||
head, tail = os.path.split(destination)
|
||||
os.makedirs(head, exist_ok=True)
|
||||
|
||||
URL = "https://docs.google.com/uc?export=download"
|
||||
|
||||
@ -39,7 +40,11 @@ def download_file_from_google_drive(id, fname, destination):
|
||||
total = int(response.headers.get("content-length", 0))
|
||||
|
||||
with open(destination, "wb") as file, tqdm(
|
||||
desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024,
|
||||
desc=f"Downloading {tail} to {head}",
|
||||
total=total,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as bar:
|
||||
for data in response.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
@ -47,20 +52,39 @@ def download_file_from_google_drive(id, fname, destination):
|
||||
|
||||
|
||||
def load_model(model_name: str = "u2net"):
|
||||
os.makedirs(os.path.expanduser(os.path.join("~", ".u2net")), exist_ok=True)
|
||||
hasher = Hasher()
|
||||
|
||||
if model_name == "u2netp":
|
||||
net = u2net.U2NETP(3, 1)
|
||||
path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
||||
download_file_from_google_drive(
|
||||
"1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", "u2netp.pth", path,
|
||||
path = os.environ.get(
|
||||
"U2NETP_PATH",
|
||||
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
||||
)
|
||||
if (
|
||||
not os.path.exists(path)
|
||||
or hasher.md5(path) != "e4f636406ca4e2af789941e7f139ee2e"
|
||||
):
|
||||
download_file_from_google_drive(
|
||||
"1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
|
||||
"u2netp.pth",
|
||||
path,
|
||||
)
|
||||
|
||||
elif model_name == "u2net":
|
||||
net = u2net.U2NET(3, 1)
|
||||
path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
||||
download_file_from_google_drive(
|
||||
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", "u2net.pth", path,
|
||||
path = os.environ.get(
|
||||
"U2NET_PATH",
|
||||
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
||||
)
|
||||
if (
|
||||
not os.path.exists(path)
|
||||
or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55"
|
||||
):
|
||||
download_file_from_google_drive(
|
||||
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
|
||||
"u2net.pth",
|
||||
path,
|
||||
)
|
||||
else:
|
||||
print("Choose between u2net or u2netp", file=sys.stderr)
|
||||
|
||||
@ -69,7 +93,12 @@ def load_model(model_name: str = "u2net"):
|
||||
net.load_state_dict(torch.load(path))
|
||||
net.to(torch.device("cuda"))
|
||||
else:
|
||||
net.load_state_dict(torch.load(path, map_location="cpu",))
|
||||
net.load_state_dict(
|
||||
torch.load(
|
||||
path,
|
||||
map_location="cpu",
|
||||
)
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(
|
||||
errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"
|
||||
|
Loading…
x
Reference in New Issue
Block a user