mirror of
https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-15 01:15:58 +08:00
fix project layout
This commit is contained in:
parent
39e45f2e86
commit
0fd1236db4
10
rembg/bg.py
10
rembg/bg.py
@ -8,7 +8,7 @@ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
|||||||
from pymatting.util.util import stack_images
|
from pymatting.util.util import stack_images
|
||||||
from scipy.ndimage.morphology import binary_erosion
|
from scipy.ndimage.morphology import binary_erosion
|
||||||
|
|
||||||
from .u2net import detect
|
from .detect import load_model, predict
|
||||||
|
|
||||||
|
|
||||||
def alpha_matting_cutout(
|
def alpha_matting_cutout(
|
||||||
@ -71,11 +71,11 @@ def naive_cutout(img, mask):
|
|||||||
@functools.lru_cache(maxsize=None)
|
@functools.lru_cache(maxsize=None)
|
||||||
def get_model(model_name):
|
def get_model(model_name):
|
||||||
if model_name == "u2netp":
|
if model_name == "u2netp":
|
||||||
return detect.load_model(model_name="u2netp")
|
return load_model(model_name="u2netp")
|
||||||
if model_name == "u2net_human_seg":
|
if model_name == "u2net_human_seg":
|
||||||
return detect.load_model(model_name="u2net_human_seg")
|
return load_model(model_name="u2net_human_seg")
|
||||||
else:
|
else:
|
||||||
return detect.load_model(model_name="u2net")
|
return load_model(model_name="u2net")
|
||||||
|
|
||||||
|
|
||||||
def resize_image(img, width, height):
|
def resize_image(img, width, height):
|
||||||
@ -105,7 +105,7 @@ def remove(
|
|||||||
img = resize_image(img, width, height)
|
img = resize_image(img, width, height)
|
||||||
|
|
||||||
model = get_model(model_name)
|
model = get_model(model_name)
|
||||||
mask = detect.predict(model, np.array(img)).convert("L")
|
mask = predict(model, np.array(img)).convert("L")
|
||||||
|
|
||||||
if alpha_matting:
|
if alpha_matting:
|
||||||
try:
|
try:
|
||||||
|
@ -15,8 +15,8 @@ from skimage import transform
|
|||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from . import data_loader, u2net
|
from .data_loader import RescaleT, ToTensorLab
|
||||||
|
from .u2net import U2NETP, U2NET
|
||||||
|
|
||||||
def download_file_from_google_drive(id, fname, destination):
|
def download_file_from_google_drive(id, fname, destination):
|
||||||
head, tail = os.path.split(destination)
|
head, tail = os.path.split(destination)
|
||||||
@ -55,7 +55,7 @@ def load_model(model_name: str = "u2net"):
|
|||||||
hashfile = lambda f: md5(open(f, "rb").read()).hexdigest()
|
hashfile = lambda f: md5(open(f, "rb").read()).hexdigest()
|
||||||
|
|
||||||
if model_name == "u2netp":
|
if model_name == "u2netp":
|
||||||
net = u2net.U2NETP(3, 1)
|
net = U2NETP(3, 1)
|
||||||
path = os.environ.get(
|
path = os.environ.get(
|
||||||
"U2NETP_PATH",
|
"U2NETP_PATH",
|
||||||
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
||||||
@ -71,7 +71,7 @@ def load_model(model_name: str = "u2net"):
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif model_name == "u2net":
|
elif model_name == "u2net":
|
||||||
net = u2net.U2NET(3, 1)
|
net = U2NET(3, 1)
|
||||||
path = os.environ.get(
|
path = os.environ.get(
|
||||||
"U2NET_PATH",
|
"U2NET_PATH",
|
||||||
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
||||||
@ -82,12 +82,12 @@ def load_model(model_name: str = "u2net"):
|
|||||||
):
|
):
|
||||||
download_file_from_google_drive(
|
download_file_from_google_drive(
|
||||||
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
|
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
|
||||||
"u2net.pth",
|
"pth",
|
||||||
path,
|
path,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif model_name == "u2net_human_seg":
|
elif model_name == "u2net_human_seg":
|
||||||
net = u2net.U2NET(3, 1)
|
net = U2NET(3, 1)
|
||||||
path = os.environ.get(
|
path = os.environ.get(
|
||||||
"U2NET_PATH",
|
"U2NET_PATH",
|
||||||
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
|
||||||
@ -149,7 +149,7 @@ def preprocess(image):
|
|||||||
label = label[:, :, np.newaxis]
|
label = label[:, :, np.newaxis]
|
||||||
|
|
||||||
transform = transforms.Compose(
|
transform = transforms.Compose(
|
||||||
[data_loader.RescaleT(320), data_loader.ToTensorLab(flag=0)]
|
[RescaleT(320), ToTensorLab(flag=0)]
|
||||||
)
|
)
|
||||||
sample = transform({"imidx": np.array([0]), "image": image, "label": label})
|
sample = transform({"imidx": np.array([0]), "image": image, "label": label})
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user