diff --git a/rembg/bg.py b/rembg/bg.py index 6354a03..3911177 100644 --- a/rembg/bg.py +++ b/rembg/bg.py @@ -8,7 +8,7 @@ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml from pymatting.util.util import stack_images from scipy.ndimage.morphology import binary_erosion -from .u2net import detect +from .detect import load_model, predict def alpha_matting_cutout( @@ -71,11 +71,11 @@ def naive_cutout(img, mask): @functools.lru_cache(maxsize=None) def get_model(model_name): if model_name == "u2netp": - return detect.load_model(model_name="u2netp") + return load_model(model_name="u2netp") if model_name == "u2net_human_seg": - return detect.load_model(model_name="u2net_human_seg") + return load_model(model_name="u2net_human_seg") else: - return detect.load_model(model_name="u2net") + return load_model(model_name="u2net") def resize_image(img, width, height): @@ -105,7 +105,7 @@ def remove( img = resize_image(img, width, height) model = get_model(model_name) - mask = detect.predict(model, np.array(img)).convert("L") + mask = predict(model, np.array(img)).convert("L") if alpha_matting: try: diff --git a/rembg/u2net/data_loader.py b/rembg/data_loader.py similarity index 100% rename from rembg/u2net/data_loader.py rename to rembg/data_loader.py diff --git a/rembg/u2net/detect.py b/rembg/detect.py similarity index 95% rename from rembg/u2net/detect.py rename to rembg/detect.py index 62c442d..9b73c8d 100644 --- a/rembg/u2net/detect.py +++ b/rembg/detect.py @@ -15,8 +15,8 @@ from skimage import transform from torchvision import transforms from tqdm import tqdm -from . import data_loader, u2net - +from .data_loader import RescaleT, ToTensorLab +from .u2net import U2NETP, U2NET def download_file_from_google_drive(id, fname, destination): head, tail = os.path.split(destination) @@ -55,7 +55,7 @@ def load_model(model_name: str = "u2net"): hashfile = lambda f: md5(open(f, "rb").read()).hexdigest() if model_name == "u2netp": - net = u2net.U2NETP(3, 1) + net = U2NETP(3, 1) path = os.environ.get( "U2NETP_PATH", os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")), @@ -71,7 +71,7 @@ def load_model(model_name: str = "u2net"): ) elif model_name == "u2net": - net = u2net.U2NET(3, 1) + net = U2NET(3, 1) path = os.environ.get( "U2NET_PATH", os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")), @@ -82,12 +82,12 @@ def load_model(model_name: str = "u2net"): ): download_file_from_google_drive( "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", - "u2net.pth", + "pth", path, ) elif model_name == "u2net_human_seg": - net = u2net.U2NET(3, 1) + net = U2NET(3, 1) path = os.environ.get( "U2NET_PATH", os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")), @@ -149,7 +149,7 @@ def preprocess(image): label = label[:, :, np.newaxis] transform = transforms.Compose( - [data_loader.RescaleT(320), data_loader.ToTensorLab(flag=0)] + [RescaleT(320), ToTensorLab(flag=0)] ) sample = transform({"imidx": np.array([0]), "image": image, "label": label}) diff --git a/rembg/u2net/u2net.py b/rembg/u2net.py similarity index 100% rename from rembg/u2net/u2net.py rename to rembg/u2net.py diff --git a/rembg/u2net/__init__.py b/rembg/u2net/__init__.py deleted file mode 100644 index e69de29..0000000