mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-20 21:39:14 +08:00
add model var envs
This commit is contained in:
parent
0a96555440
commit
58fd707872
@ -10,3 +10,4 @@ requests==2.24.0
|
|||||||
scipy==1.5.4
|
scipy==1.5.4
|
||||||
pymatting==1.1.1
|
pymatting==1.1.1
|
||||||
filetype==1.0.7
|
filetype==1.0.7
|
||||||
|
hsh==1.1.0
|
||||||
|
2
setup.py
2
setup.py
@ -11,7 +11,7 @@ with open("requirements.txt") as f:
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="rembg",
|
name="rembg",
|
||||||
version="1.0.16",
|
version="1.0.17",
|
||||||
description="Remove image background",
|
description="Remove image background",
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
|
import functools
|
||||||
import io
|
import io
|
||||||
|
|
||||||
import functools
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
||||||
@ -12,7 +12,11 @@ from .u2net import detect
|
|||||||
|
|
||||||
|
|
||||||
def alpha_matting_cutout(
|
def alpha_matting_cutout(
|
||||||
img, mask, foreground_threshold, background_threshold, erode_structure_size,
|
img,
|
||||||
|
mask,
|
||||||
|
foreground_threshold,
|
||||||
|
background_threshold,
|
||||||
|
erode_structure_size,
|
||||||
):
|
):
|
||||||
base_size = (1000, 1000)
|
base_size = (1000, 1000)
|
||||||
size = img.size
|
size = img.size
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import filetype
|
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
|
|
||||||
|
import filetype
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ..bg import remove
|
from ..bg import remove
|
||||||
@ -55,7 +56,10 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
ap.add_argument(
|
ap.add_argument(
|
||||||
"-p", "--path", nargs="+", help="Path of a file or a folder of files.",
|
"-p",
|
||||||
|
"--path",
|
||||||
|
nargs="+",
|
||||||
|
help="Path of a file or a folder of files.",
|
||||||
)
|
)
|
||||||
|
|
||||||
ap.add_argument(
|
ap.add_argument(
|
||||||
@ -95,7 +99,7 @@ def main():
|
|||||||
|
|
||||||
if fi_type is None:
|
if fi_type is None:
|
||||||
continue
|
continue
|
||||||
elif fi_type.mime.find('image') < 0:
|
elif fi_type.mime.find("image") < 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
with open(fi, "rb") as input:
|
with open(fi, "rb") as input:
|
||||||
|
@ -36,7 +36,10 @@ def index():
|
|||||||
return {"error": "invalid query param 'model'"}, 400
|
return {"error": "invalid query param 'model'"}, 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return send_file(BytesIO(remove(file_content, model)), mimetype="image/png",)
|
return send_file(
|
||||||
|
BytesIO(remove(file_content, model)),
|
||||||
|
mimetype="image/png",
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
app.logger.exception(e, exc_info=True)
|
app.logger.exception(e, exc_info=True)
|
||||||
return {"error": "oops, something went wrong!"}, 500
|
return {"error": "oops, something went wrong!"}, 500
|
||||||
@ -46,11 +49,19 @@ def main():
|
|||||||
ap = argparse.ArgumentParser()
|
ap = argparse.ArgumentParser()
|
||||||
|
|
||||||
ap.add_argument(
|
ap.add_argument(
|
||||||
"-a", "--addr", default="0.0.0.0", type=str, help="The IP address to bind to.",
|
"-a",
|
||||||
|
"--addr",
|
||||||
|
default="0.0.0.0",
|
||||||
|
type=str,
|
||||||
|
help="The IP address to bind to.",
|
||||||
)
|
)
|
||||||
|
|
||||||
ap.add_argument(
|
ap.add_argument(
|
||||||
"-p", "--port", default=5000, type=int, help="The port to bind to.",
|
"-p",
|
||||||
|
"--port",
|
||||||
|
default=5000,
|
||||||
|
type=int,
|
||||||
|
help="The port to bind to.",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
|
@ -9,6 +9,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchvision
|
import torchvision
|
||||||
|
from hsh.library.hash import Hasher
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from skimage import transform
|
from skimage import transform
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
@ -18,8 +19,8 @@ from . import data_loader, u2net
|
|||||||
|
|
||||||
|
|
||||||
def download_file_from_google_drive(id, fname, destination):
|
def download_file_from_google_drive(id, fname, destination):
|
||||||
if os.path.exists(destination):
|
head, tail = os.path.split(destination)
|
||||||
return
|
os.makedirs(head, exist_ok=True)
|
||||||
|
|
||||||
URL = "https://docs.google.com/uc?export=download"
|
URL = "https://docs.google.com/uc?export=download"
|
||||||
|
|
||||||
@ -39,7 +40,11 @@ def download_file_from_google_drive(id, fname, destination):
|
|||||||
total = int(response.headers.get("content-length", 0))
|
total = int(response.headers.get("content-length", 0))
|
||||||
|
|
||||||
with open(destination, "wb") as file, tqdm(
|
with open(destination, "wb") as file, tqdm(
|
||||||
desc=fname, total=total, unit="iB", unit_scale=True, unit_divisor=1024,
|
desc=f"Downloading {tail} to {head}",
|
||||||
|
total=total,
|
||||||
|
unit="iB",
|
||||||
|
unit_scale=True,
|
||||||
|
unit_divisor=1024,
|
||||||
) as bar:
|
) as bar:
|
||||||
for data in response.iter_content(chunk_size=1024):
|
for data in response.iter_content(chunk_size=1024):
|
||||||
size = file.write(data)
|
size = file.write(data)
|
||||||
@ -47,20 +52,39 @@ def download_file_from_google_drive(id, fname, destination):
|
|||||||
|
|
||||||
|
|
||||||
def load_model(model_name: str = "u2net"):
|
def load_model(model_name: str = "u2net"):
|
||||||
os.makedirs(os.path.expanduser(os.path.join("~", ".u2net")), exist_ok=True)
|
hasher = Hasher()
|
||||||
|
|
||||||
if model_name == "u2netp":
|
if model_name == "u2netp":
|
||||||
net = u2net.U2NETP(3, 1)
|
net = u2net.U2NETP(3, 1)
|
||||||
path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
path = os.environ.get(
|
||||||
download_file_from_google_drive(
|
"U2NETP_PATH",
|
||||||
"1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", "u2netp.pth", path,
|
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
||||||
)
|
)
|
||||||
|
if (
|
||||||
|
not os.path.exists(path)
|
||||||
|
or hasher.md5(path) != "e4f636406ca4e2af789941e7f139ee2e"
|
||||||
|
):
|
||||||
|
download_file_from_google_drive(
|
||||||
|
"1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
|
||||||
|
"u2netp.pth",
|
||||||
|
path,
|
||||||
|
)
|
||||||
|
|
||||||
elif model_name == "u2net":
|
elif model_name == "u2net":
|
||||||
net = u2net.U2NET(3, 1)
|
net = u2net.U2NET(3, 1)
|
||||||
path = os.path.expanduser(os.path.join("~", ".u2net", model_name))
|
path = os.environ.get(
|
||||||
download_file_from_google_drive(
|
"U2NET_PATH",
|
||||||
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", "u2net.pth", path,
|
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
||||||
)
|
)
|
||||||
|
if (
|
||||||
|
not os.path.exists(path)
|
||||||
|
or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55"
|
||||||
|
):
|
||||||
|
download_file_from_google_drive(
|
||||||
|
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
|
||||||
|
"u2net.pth",
|
||||||
|
path,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print("Choose between u2net or u2netp", file=sys.stderr)
|
print("Choose between u2net or u2netp", file=sys.stderr)
|
||||||
|
|
||||||
@ -69,7 +93,12 @@ def load_model(model_name: str = "u2net"):
|
|||||||
net.load_state_dict(torch.load(path))
|
net.load_state_dict(torch.load(path))
|
||||||
net.to(torch.device("cuda"))
|
net.to(torch.device("cuda"))
|
||||||
else:
|
else:
|
||||||
net.load_state_dict(torch.load(path, map_location="cpu",))
|
net.load_state_dict(
|
||||||
|
torch.load(
|
||||||
|
path,
|
||||||
|
map_location="cpu",
|
||||||
|
)
|
||||||
|
)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"
|
errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user