From 3b18bad8587016011fc1ea27d229213ab173a31b Mon Sep 17 00:00:00 2001 From: Daniel Gatis Date: Wed, 19 Jan 2022 11:48:05 -0300 Subject: [PATCH] add onnx --- Dockerfile | 10 +- README.md | 39 +--- rembg/bg.py | 2 - rembg/cli.py | 13 +- rembg/data_loader.py | 326 -------------------------- rembg/detect.py | 219 +++++++----------- rembg/server.py | 11 +- rembg/u2net.py | 541 ------------------------------------------- requirements-cpu.txt | 1 + requirements-gpu.txt | 1 + requirements.txt | 13 +- setup.py | 7 + 12 files changed, 113 insertions(+), 1070 deletions(-) delete mode 100644 rembg/data_loader.py delete mode 100644 rembg/u2net.py create mode 100644 requirements-cpu.txt create mode 100644 requirements-gpu.txt diff --git a/Dockerfile b/Dockerfile index c417829..048fc54 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,16 +2,16 @@ FROM nvidia/cuda:11.4.2-cudnn8-runtime-ubuntu20.04 RUN apt-get update &&\ apt-get install -y --no-install-recommends \ - python3 \ - python3-pip \ - python3-dev \ - build-essential + python3 \ + python3-pip \ + python3-dev \ + build-essential WORKDIR /rembg COPY . . -RUN pip3 install . +RUN GPU=1 pip3 install . # First run to compile AOT & download model RUN rembg pixel.png >/dev/null diff --git a/README.md b/README.md index 646bf7c..2dfd393 100644 --- a/README.md +++ b/README.md @@ -36,38 +36,19 @@ Rembg is a tool to remove images background. That is it. #### *** If you want to remove background from videos try this this fork: https://github.com/ecsplendid/rembg-greenscreen *** -### Requirements - -* python 3.8 or newer - -* torch and torchvision stable version (https://pytorch.org) - -#### How to install torch/torchvision - -Go to https://pytorch.org and scrool down to `INSTALL PYTORCH` section and follow the instructions. - -For example: -``` -PyTorch Build: Stable (1.7.1) -Your OS: Windows -Package: Pip -Language: Python -CUDA: None -``` - -The install cmd is: -``` -pip install torch==1.7.1+cpu torchvision==0.8.2+cpu -f https://download.pytorch.org/whl/torch_stable.html -``` ### Installation -Install it from pypi - +CPU support: ```bash pip install rembg ``` +GPU support: +```bash +GPU=1 pip install rembg +``` + ### Usage as a cli Remove the background from a remote image @@ -85,14 +66,6 @@ Remove the background from all images in a folder rembg -p path/to/input path/to/output ``` -### Add a custom model - -Copy the `custom-model.pth` file to `~/.u2net` and run: - -```bash -curl -s http://input.png | rembg -m custom-model > output.png -``` - ### Usage as a server Start the server diff --git a/rembg/bg.py b/rembg/bg.py index 3911177..bacbadf 100644 --- a/rembg/bg.py +++ b/rembg/bg.py @@ -1,4 +1,3 @@ -import functools import io import numpy as np @@ -68,7 +67,6 @@ def naive_cutout(img, mask): return cutout -@functools.lru_cache(maxsize=None) def get_model(model_name): if model_name == "u2netp": return load_model(model_name="u2netp") diff --git a/rembg/cli.py b/rembg/cli.py index be95ed2..40d7b71 100644 --- a/rembg/cli.py +++ b/rembg/cli.py @@ -10,17 +10,6 @@ from .bg import remove def main(): - model_path = os.environ.get( - "U2NETP_PATH", - os.path.expanduser(os.path.join("~", ".u2net")), - ) - model_choices = [ - os.path.splitext(os.path.basename(x))[0] - for x in set(glob.glob(model_path + "/*")) - ] - - model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"])) - ap = argparse.ArgumentParser() ap.add_argument( @@ -28,7 +17,7 @@ def main(): "--model", default="u2net", type=str, - choices=model_choices, + choices=["u2net", "u2netp", "u2net_human_seg"], help="The model name.", ) diff --git a/rembg/data_loader.py b/rembg/data_loader.py deleted file mode 100644 index 96215eb..0000000 --- a/rembg/data_loader.py +++ /dev/null @@ -1,326 +0,0 @@ -# data loader - -import random - -import matplotlib.pyplot as plt -import numpy as np -import torch -from PIL import Image -from skimage import color, io, transform -from torch.utils.data import DataLoader, Dataset -from torchvision import transforms, utils - - -# ==========================dataset load========================== -class RescaleT: - def __init__(self, output_size): - assert isinstance(output_size, (int, tuple)) - self.output_size = output_size - - def __call__(self, sample): - imidx, image, label = sample["imidx"], sample["image"], sample["label"] - - h, w = image.shape[:2] - - if isinstance(self.output_size, int): - if h > w: - new_h, new_w = self.output_size * h / w, self.output_size - else: - new_h, new_w = self.output_size, self.output_size * w / h - else: - new_h, new_w = self.output_size - - new_h, new_w = int(new_h), int(new_w) - - # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1] - # img = transform.resize(image,(new_h,new_w),mode='constant') - # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True) - - img = transform.resize( - image, (self.output_size, self.output_size), mode="constant" - ) - lbl = transform.resize( - label, - (self.output_size, self.output_size), - mode="constant", - order=0, - preserve_range=True, - ) - - return {"imidx": imidx, "image": img, "label": lbl} - - -class Rescale: - def __init__(self, output_size): - assert isinstance(output_size, (int, tuple)) - self.output_size = output_size - - def __call__(self, sample): - imidx, image, label = sample["imidx"], sample["image"], sample["label"] - - if random.random() >= 0.5: - image = image[::-1] - label = label[::-1] - - h, w = image.shape[:2] - - if isinstance(self.output_size, int): - if h > w: - new_h, new_w = self.output_size * h / w, self.output_size - else: - new_h, new_w = self.output_size, self.output_size * w / h - else: - new_h, new_w = self.output_size - - new_h, new_w = int(new_h), int(new_w) - - # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1] - img = transform.resize(image, (new_h, new_w), mode="constant") - lbl = transform.resize( - label, (new_h, new_w), mode="constant", order=0, preserve_range=True - ) - - return {"imidx": imidx, "image": img, "label": lbl} - - -class RandomCrop: - def __init__(self, output_size): - assert isinstance(output_size, (int, tuple)) - if isinstance(output_size, int): - self.output_size = (output_size, output_size) - else: - assert len(output_size) == 2 - self.output_size = output_size - - def __call__(self, sample): - imidx, image, label = sample["imidx"], sample["image"], sample["label"] - - if random.random() >= 0.5: - image = image[::-1] - label = label[::-1] - - h, w = image.shape[:2] - new_h, new_w = self.output_size - - top = np.random.randint(0, h - new_h) - left = np.random.randint(0, w - new_w) - - image = image[top : top + new_h, left : left + new_w] - label = label[top : top + new_h, left : left + new_w] - - return {"imidx": imidx, "image": image, "label": label} - - -class ToTensor: - """Convert ndarrays in sample to Tensors.""" - - def __call__(self, sample): - - imidx, image, label = sample["imidx"], sample["image"], sample["label"] - - tmpImg = np.zeros((image.shape[0], image.shape[1], 3)) - tmpLbl = np.zeros(label.shape) - - image = image / np.max(image) - if np.max(label) < 1e-6: - label = label - else: - label = label / np.max(label) - - if image.shape[2] == 1: - tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 - tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229 - tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229 - else: - tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 - tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224 - tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225 - - tmpLbl[:, :, 0] = label[:, :, 0] - - # change the r,g,b to b,r,g from [0,255] to [0,1] - # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)) - tmpImg = tmpImg.transpose((2, 0, 1)) - tmpLbl = label.transpose((2, 0, 1)) - - return { - "imidx": torch.from_numpy(imidx), - "image": torch.from_numpy(tmpImg), - "label": torch.from_numpy(tmpLbl), - } - - -class ToTensorLab: - """Convert ndarrays in sample to Tensors.""" - - def __init__(self, flag=0): - self.flag = flag - - def __call__(self, sample): - - imidx, image, label = sample["imidx"], sample["image"], sample["label"] - - tmpLbl = np.zeros(label.shape) - - if np.max(label) < 1e-6: - label = label - else: - label = label / np.max(label) - - # change the color space - if self.flag == 2: # with rgb and Lab colors - tmpImg = np.zeros((image.shape[0], image.shape[1], 6)) - tmpImgt = np.zeros((image.shape[0], image.shape[1], 3)) - if image.shape[2] == 1: - tmpImgt[:, :, 0] = image[:, :, 0] - tmpImgt[:, :, 1] = image[:, :, 0] - tmpImgt[:, :, 2] = image[:, :, 0] - else: - tmpImgt = image - tmpImgtl = color.rgb2lab(tmpImgt) - - # nomalize image to range [0,1] - tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / ( - np.max(tmpImgt[:, :, 0]) - np.min(tmpImgt[:, :, 0]) - ) - tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / ( - np.max(tmpImgt[:, :, 1]) - np.min(tmpImgt[:, :, 1]) - ) - tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / ( - np.max(tmpImgt[:, :, 2]) - np.min(tmpImgt[:, :, 2]) - ) - tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / ( - np.max(tmpImgtl[:, :, 0]) - np.min(tmpImgtl[:, :, 0]) - ) - tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / ( - np.max(tmpImgtl[:, :, 1]) - np.min(tmpImgtl[:, :, 1]) - ) - tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / ( - np.max(tmpImgtl[:, :, 2]) - np.min(tmpImgtl[:, :, 2]) - ) - - # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg)) - - tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std( - tmpImg[:, :, 0] - ) - tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std( - tmpImg[:, :, 1] - ) - tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std( - tmpImg[:, :, 2] - ) - tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std( - tmpImg[:, :, 3] - ) - tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std( - tmpImg[:, :, 4] - ) - tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std( - tmpImg[:, :, 5] - ) - - elif self.flag == 1: # with Lab color - tmpImg = np.zeros((image.shape[0], image.shape[1], 3)) - - if image.shape[2] == 1: - tmpImg[:, :, 0] = image[:, :, 0] - tmpImg[:, :, 1] = image[:, :, 0] - tmpImg[:, :, 2] = image[:, :, 0] - else: - tmpImg = image - - tmpImg = color.rgb2lab(tmpImg) - - # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg)) - - tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / ( - np.max(tmpImg[:, :, 0]) - np.min(tmpImg[:, :, 0]) - ) - tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / ( - np.max(tmpImg[:, :, 1]) - np.min(tmpImg[:, :, 1]) - ) - tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / ( - np.max(tmpImg[:, :, 2]) - np.min(tmpImg[:, :, 2]) - ) - - tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std( - tmpImg[:, :, 0] - ) - tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std( - tmpImg[:, :, 1] - ) - tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std( - tmpImg[:, :, 2] - ) - - else: # with rgb color - tmpImg = np.zeros((image.shape[0], image.shape[1], 3)) - image = image / np.max(image) - if image.shape[2] == 1: - tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 - tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229 - tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229 - else: - tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 - tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224 - tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225 - - tmpLbl[:, :, 0] = label[:, :, 0] - - # change the r,g,b to b,r,g from [0,255] to [0,1] - # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)) - tmpImg = tmpImg.transpose((2, 0, 1)) - tmpLbl = label.transpose((2, 0, 1)) - - return { - "imidx": torch.from_numpy(imidx), - "image": torch.from_numpy(tmpImg), - "label": torch.from_numpy(tmpLbl), - } - - -class SalObjDataset(Dataset): - def __init__(self, img_name_list, lbl_name_list, transform=None): - # self.root_dir = root_dir - # self.image_name_list = glob.glob(image_dir+'*.png') - # self.label_name_list = glob.glob(label_dir+'*.png') - self.image_name_list = img_name_list - self.label_name_list = lbl_name_list - self.transform = transform - - def __len__(self): - return len(self.image_name_list) - - def __getitem__(self, idx): - - # image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx]) - # label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx]) - - image = io.imread(self.image_name_list[idx]) - imname = self.image_name_list[idx] - imidx = np.array([idx]) - - if 0 == len(self.label_name_list): - label_3 = np.zeros(image.shape) - else: - label_3 = io.imread(self.label_name_list[idx]) - - label = np.zeros(label_3.shape[0:2]) - if 3 == len(label_3.shape): - label = label_3[:, :, 0] - elif 2 == len(label_3.shape): - label = label_3 - - if 3 == len(image.shape) and 2 == len(label.shape): - label = label[:, :, np.newaxis] - elif 2 == len(image.shape) and 2 == len(label.shape): - image = image[:, :, np.newaxis] - label = label[:, :, np.newaxis] - - sample = {"imidx": imidx, "image": image, "label": label} - - if self.transform: - sample = self.transform(sample) - - return sample diff --git a/rembg/detect.py b/rembg/detect.py index b6fbcac..e8e0575 100644 --- a/rembg/detect.py +++ b/rembg/detect.py @@ -1,139 +1,103 @@ -import errno import os import sys -import urllib.request -from hashlib import md5 +import gdown import numpy as np -import requests -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision +import onnxruntime as ort from PIL import Image from skimage import transform -from torchvision import transforms -from tqdm import tqdm -from .data_loader import RescaleT, ToTensorLab -from .u2net import U2NET, U2NETP - - -def download_file_from_google_drive(id, fname, destination): - head, tail = os.path.split(destination) - os.makedirs(head, exist_ok=True) - - URL = "https://docs.google.com/uc?export=download" - - session = requests.Session() - response = session.get(URL, params={"id": id}, stream=True) - - token = None - for key, value in response.cookies.items(): - if key.startswith("download_warning"): - token = value - break - - if token: - params = {"id": id, "confirm": token} - response = session.get(URL, params=params, stream=True) - - total = int(response.headers.get("content-length", 0)) - - with open(destination, "wb") as file, tqdm( - desc=f"Downloading {tail} to {head}", - total=total, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as bar: - for data in response.iter_content(chunk_size=1024): - size = file.write(data) - bar.update(size) +SESSIONS = {} def load_model(model_name: str = "u2net"): - hashfile = lambda f: md5(open(f, "rb").read()).hexdigest() + path = os.environ.get( + "U2NETP_PATH", + os.path.expanduser(os.path.join("~", ".u2net", model_name + ".onnx")), + ) if model_name == "u2netp": - net = U2NETP(3, 1) - path = os.environ.get( - "U2NETP_PATH", - os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")), - ) - if ( - not os.path.exists(path) - or hashfile(path) != "e4f636406ca4e2af789941e7f139ee2e" - ): - download_file_from_google_drive( - "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy", - "u2netp.pth", - path, - ) - + md5 = "8e83ca70e441ab06c318d82300c84806" + url = "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR" elif model_name == "u2net": - net = U2NET(3, 1) - path = os.environ.get( - "U2NET_PATH", - os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")), - ) - if ( - not os.path.exists(path) - or hashfile(path) != "347c3d51b01528e5c6c071e3cff1cb55" - ): - download_file_from_google_drive( - "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ", - "pth", - path, - ) - + md5 = "60024c5c889badc19c04ad937298a77b" + url = "https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab" elif model_name == "u2net_human_seg": - net = U2NET(3, 1) - path = os.environ.get( - "U2NET_PATH", - os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")), - ) - if ( - not os.path.exists(path) - or hashfile(path) != "09fb4e49b7f785c9f855baf94916840a" - ): - download_file_from_google_drive( - "1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P", - "u2net_human_seg.pth", - path, - ) + md5 = "c09ddc2e0104f800e3e1bb4652583d1f" + url = "https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j" else: print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr) - try: - if torch.cuda.is_available(): - net.load_state_dict(torch.load(path)) - net.to(torch.device("cuda")) - else: - net.load_state_dict( - torch.load( - path, - map_location="cpu", - ) - ) - except FileNotFoundError: - raise FileNotFoundError( - errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth" - ) + if SESSIONS.get(md5) is None: + gdown.cached_download(url, path, md5=md5, quiet=False) + SESSIONS[md5] = ort.InferenceSession(path) - net.eval() - - return net + return SESSIONS[md5] def norm_pred(d): - ma = torch.max(d) - mi = torch.min(d) + ma = np.max(d) + mi = np.min(d) dn = (d - mi) / (ma - mi) return dn +def rescale(sample, output_size): + imidx, image, label = sample["imidx"], sample["image"], sample["label"] + + h, w = image.shape[:2] + + if isinstance(output_size, int): + if h > w: + new_h, new_w = output_size * h / w, output_size + else: + new_h, new_w = output_size, output_size * w / h + else: + new_h, new_w = output_size + + new_h, new_w = int(new_h), int(new_w) + + img = transform.resize(image, (output_size, output_size), mode="constant") + lbl = transform.resize( + label, + (output_size, output_size), + mode="constant", + order=0, + preserve_range=True, + ) + + return {"imidx": imidx, "image": img, "label": lbl} + + +def color(sample): + imidx, image, label = sample["imidx"], sample["image"], sample["label"] + + tmpLbl = np.zeros(label.shape) + + if np.max(label) < 1e-6: + label = label + else: + label = label / np.max(label) + + tmpImg = np.zeros((image.shape[0], image.shape[1], 3)) + image = image / np.max(image) + if image.shape[2] == 1: + tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 + tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229 + tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229 + else: + tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 + tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224 + tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225 + + tmpLbl[:, :, 0] = label[:, :, 0] + tmpImg = tmpImg.transpose((2, 0, 1)) + tmpLbl = label.transpose((2, 0, 1)) + + return {"imidx": imidx, "image": tmpImg, "label": tmpLbl} + + def preprocess(image): label_3 = np.zeros(image.shape) label = np.zeros(label_3.shape[0:2]) @@ -149,34 +113,23 @@ def preprocess(image): image = image[:, :, np.newaxis] label = label[:, :, np.newaxis] - transform = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)]) - sample = transform({"imidx": np.array([0]), "image": image, "label": label}) + sample = {"imidx": np.array([0]), "image": image, "label": label} + sample = rescale(sample, 320) + sample = color(sample) return sample -def predict(net, item): - +def predict(ort_session, item): sample = preprocess(item) + inputs_test = np.expand_dims(sample["image"], 0).astype(np.float32) - with torch.no_grad(): + ort_inputs = {ort_session.get_inputs()[0].name: inputs_test} + ort_outs = ort_session.run(None, ort_inputs) - if torch.cuda.is_available(): - inputs_test = torch.cuda.FloatTensor( - sample["image"].unsqueeze(0).cuda().float() - ) - else: - inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float()) + d1 = ort_outs[0] + pred = d1[:, 0, :, :] + predict = np.squeeze(norm_pred(pred)) + img = Image.fromarray(predict * 255).convert("RGB") - d1, d2, d3, d4, d5, d6, d7 = net(inputs_test) - - pred = d1[:, 0, :, :] - predict = norm_pred(pred) - - predict = predict.squeeze() - predict_np = predict.cpu().detach().numpy() - img = Image.fromarray(predict_np * 255).convert("RGB") - - del d1, d2, d3, d4, d5, d6, d7, pred, predict, predict_np, inputs_test, sample - - return img + return img diff --git a/rembg/server.py b/rembg/server.py index a3fecc2..f432820 100644 --- a/rembg/server.py +++ b/rembg/server.py @@ -46,16 +46,7 @@ def index(): height = request.args.get("height", type=int) model = request.values.get("model", type=str, default="u2net") - model_path = os.environ.get( - "U2NETP_PATH", - os.path.expanduser(os.path.join("~", ".u2net")), - ) - model_choices = [ - os.path.splitext(os.path.basename(x))[0] - for x in set(glob.glob(model_path + "/*")) - ] - - model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"])) + model_choices = ["u2net", "u2netp", "u2net_human_seg"] if model not in model_choices: return { diff --git a/rembg/u2net.py b/rembg/u2net.py deleted file mode 100644 index 07f93db..0000000 --- a/rembg/u2net.py +++ /dev/null @@ -1,541 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchvision import models - - -class REBNCONV(nn.Module): - def __init__(self, in_ch=3, out_ch=3, dirate=1): - super().__init__() - - self.conv_s1 = nn.Conv2d( - in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate - ) - self.bn_s1 = nn.BatchNorm2d(out_ch) - self.relu_s1 = nn.ReLU(inplace=True) - - def forward(self, x): - - hx = x - xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) - - return xout - - -## upsample tensor 'src' to have the same spatial size with tensor 'tar' -def _upsample_like(src, tar): - - src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False) - - return src - - -### RSU-7 ### -class RSU7(nn.Module): # UNet07DRES(nn.Module): - def __init__(self, in_ch=3, mid_ch=12, out_ch=3): - super().__init__() - - self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) - - self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) - self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) - self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) - self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) - self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) - self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1) - - self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2) - - self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) - self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) - self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) - self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) - self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) - self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) - - def forward(self, x): - - hx = x - hxin = self.rebnconvin(hx) - - hx1 = self.rebnconv1(hxin) - hx = self.pool1(hx1) - - hx2 = self.rebnconv2(hx) - hx = self.pool2(hx2) - - hx3 = self.rebnconv3(hx) - hx = self.pool3(hx3) - - hx4 = self.rebnconv4(hx) - hx = self.pool4(hx4) - - hx5 = self.rebnconv5(hx) - hx = self.pool5(hx5) - - hx6 = self.rebnconv6(hx) - - hx7 = self.rebnconv7(hx6) - - hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1)) - hx6dup = _upsample_like(hx6d, hx5) - - hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1)) - hx5dup = _upsample_like(hx5d, hx4) - - hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) - hx4dup = _upsample_like(hx4d, hx3) - - hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) - hx3dup = _upsample_like(hx3d, hx2) - - hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) - hx2dup = _upsample_like(hx2d, hx1) - - hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) - - return hx1d + hxin - - -### RSU-6 ### -class RSU6(nn.Module): # UNet06DRES(nn.Module): - def __init__(self, in_ch=3, mid_ch=12, out_ch=3): - super().__init__() - - self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) - - self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) - self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) - self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) - self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) - self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) - - self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2) - - self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) - self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) - self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) - self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) - self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) - - def forward(self, x): - - hx = x - - hxin = self.rebnconvin(hx) - - hx1 = self.rebnconv1(hxin) - hx = self.pool1(hx1) - - hx2 = self.rebnconv2(hx) - hx = self.pool2(hx2) - - hx3 = self.rebnconv3(hx) - hx = self.pool3(hx3) - - hx4 = self.rebnconv4(hx) - hx = self.pool4(hx4) - - hx5 = self.rebnconv5(hx) - - hx6 = self.rebnconv6(hx5) - - hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1)) - hx5dup = _upsample_like(hx5d, hx4) - - hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) - hx4dup = _upsample_like(hx4d, hx3) - - hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) - hx3dup = _upsample_like(hx3d, hx2) - - hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) - hx2dup = _upsample_like(hx2d, hx1) - - hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) - - return hx1d + hxin - - -### RSU-5 ### -class RSU5(nn.Module): # UNet05DRES(nn.Module): - def __init__(self, in_ch=3, mid_ch=12, out_ch=3): - super().__init__() - - self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) - - self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) - self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) - self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) - self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) - - self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2) - - self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) - self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) - self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) - self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) - - def forward(self, x): - - hx = x - - hxin = self.rebnconvin(hx) - - hx1 = self.rebnconv1(hxin) - hx = self.pool1(hx1) - - hx2 = self.rebnconv2(hx) - hx = self.pool2(hx2) - - hx3 = self.rebnconv3(hx) - hx = self.pool3(hx3) - - hx4 = self.rebnconv4(hx) - - hx5 = self.rebnconv5(hx4) - - hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1)) - hx4dup = _upsample_like(hx4d, hx3) - - hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) - hx3dup = _upsample_like(hx3d, hx2) - - hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) - hx2dup = _upsample_like(hx2d, hx1) - - hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) - - return hx1d + hxin - - -### RSU-4 ### -class RSU4(nn.Module): # UNet04DRES(nn.Module): - def __init__(self, in_ch=3, mid_ch=12, out_ch=3): - super().__init__() - - self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) - - self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) - self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) - self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) - - self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2) - - self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) - self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) - self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) - - def forward(self, x): - - hx = x - - hxin = self.rebnconvin(hx) - - hx1 = self.rebnconv1(hxin) - hx = self.pool1(hx1) - - hx2 = self.rebnconv2(hx) - hx = self.pool2(hx2) - - hx3 = self.rebnconv3(hx) - - hx4 = self.rebnconv4(hx3) - - hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) - hx3dup = _upsample_like(hx3d, hx2) - - hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) - hx2dup = _upsample_like(hx2d, hx1) - - hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) - - return hx1d + hxin - - -### RSU-4F ### -class RSU4F(nn.Module): # UNet04FRES(nn.Module): - def __init__(self, in_ch=3, mid_ch=12, out_ch=3): - super().__init__() - - self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) - - self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) - self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2) - self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4) - - self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8) - - self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4) - self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2) - self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) - - def forward(self, x): - - hx = x - - hxin = self.rebnconvin(hx) - - hx1 = self.rebnconv1(hxin) - hx2 = self.rebnconv2(hx1) - hx3 = self.rebnconv3(hx2) - - hx4 = self.rebnconv4(hx3) - - hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) - hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1)) - hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1)) - - return hx1d + hxin - - -##### U^2-Net #### -class U2NET(nn.Module): - def __init__(self, in_ch=3, out_ch=1): - super().__init__() - - self.stage1 = RSU7(in_ch, 32, 64) - self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.stage2 = RSU6(64, 32, 128) - self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.stage3 = RSU5(128, 64, 256) - self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.stage4 = RSU4(256, 128, 512) - self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.stage5 = RSU4F(512, 256, 512) - self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.stage6 = RSU4F(512, 256, 512) - - # decoder - self.stage5d = RSU4F(1024, 256, 512) - self.stage4d = RSU4(1024, 128, 256) - self.stage3d = RSU5(512, 64, 128) - self.stage2d = RSU6(256, 32, 64) - self.stage1d = RSU7(128, 16, 64) - - self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) - self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) - self.side3 = nn.Conv2d(128, out_ch, 3, padding=1) - self.side4 = nn.Conv2d(256, out_ch, 3, padding=1) - self.side5 = nn.Conv2d(512, out_ch, 3, padding=1) - self.side6 = nn.Conv2d(512, out_ch, 3, padding=1) - - self.outconv = nn.Conv2d(6, out_ch, 1) - - def forward(self, x): - - hx = x - - # stage 1 - hx1 = self.stage1(hx) - hx = self.pool12(hx1) - - # stage 2 - hx2 = self.stage2(hx) - hx = self.pool23(hx2) - - # stage 3 - hx3 = self.stage3(hx) - hx = self.pool34(hx3) - - # stage 4 - hx4 = self.stage4(hx) - hx = self.pool45(hx4) - - # stage 5 - hx5 = self.stage5(hx) - hx = self.pool56(hx5) - - # stage 6 - hx6 = self.stage6(hx) - hx6up = _upsample_like(hx6, hx5) - - # -------------------- decoder -------------------- - hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) - hx5dup = _upsample_like(hx5d, hx4) - - hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) - hx4dup = _upsample_like(hx4d, hx3) - - hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) - hx3dup = _upsample_like(hx3d, hx2) - - hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) - hx2dup = _upsample_like(hx2d, hx1) - - hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) - - # side output - d1 = self.side1(hx1d) - - d2 = self.side2(hx2d) - d2 = _upsample_like(d2, d1) - - d3 = self.side3(hx3d) - d3 = _upsample_like(d3, d1) - - d4 = self.side4(hx4d) - d4 = _upsample_like(d4, d1) - - d5 = self.side5(hx5d) - d5 = _upsample_like(d5, d1) - - d6 = self.side6(hx6) - d6 = _upsample_like(d6, d1) - - d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1)) - - return ( - torch.sigmoid(d0), - torch.sigmoid(d1), - torch.sigmoid(d2), - torch.sigmoid(d3), - torch.sigmoid(d4), - torch.sigmoid(d5), - torch.sigmoid(d6), - ) - - -### U^2-Net small ### -class U2NETP(nn.Module): - def __init__(self, in_ch=3, out_ch=1): - super().__init__() - - self.stage1 = RSU7(in_ch, 16, 64) - self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.stage2 = RSU6(64, 16, 64) - self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.stage3 = RSU5(64, 16, 64) - self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.stage4 = RSU4(64, 16, 64) - self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.stage5 = RSU4F(64, 16, 64) - self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) - - self.stage6 = RSU4F(64, 16, 64) - - # decoder - self.stage5d = RSU4F(128, 16, 64) - self.stage4d = RSU4(128, 16, 64) - self.stage3d = RSU5(128, 16, 64) - self.stage2d = RSU6(128, 16, 64) - self.stage1d = RSU7(128, 16, 64) - - self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) - self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) - self.side3 = nn.Conv2d(64, out_ch, 3, padding=1) - self.side4 = nn.Conv2d(64, out_ch, 3, padding=1) - self.side5 = nn.Conv2d(64, out_ch, 3, padding=1) - self.side6 = nn.Conv2d(64, out_ch, 3, padding=1) - - self.outconv = nn.Conv2d(6, out_ch, 1) - - def forward(self, x): - - hx = x - - # stage 1 - hx1 = self.stage1(hx) - hx = self.pool12(hx1) - - # stage 2 - hx2 = self.stage2(hx) - hx = self.pool23(hx2) - - # stage 3 - hx3 = self.stage3(hx) - hx = self.pool34(hx3) - - # stage 4 - hx4 = self.stage4(hx) - hx = self.pool45(hx4) - - # stage 5 - hx5 = self.stage5(hx) - hx = self.pool56(hx5) - - # stage 6 - hx6 = self.stage6(hx) - hx6up = _upsample_like(hx6, hx5) - - # decoder - hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) - hx5dup = _upsample_like(hx5d, hx4) - - hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) - hx4dup = _upsample_like(hx4d, hx3) - - hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) - hx3dup = _upsample_like(hx3d, hx2) - - hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) - hx2dup = _upsample_like(hx2d, hx1) - - hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) - - # side output - d1 = self.side1(hx1d) - - d2 = self.side2(hx2d) - d2 = _upsample_like(d2, d1) - - d3 = self.side3(hx3d) - d3 = _upsample_like(d3, d1) - - d4 = self.side4(hx4d) - d4 = _upsample_like(d4, d1) - - d5 = self.side5(hx5d) - d5 = _upsample_like(d5, d1) - - d6 = self.side6(hx6) - d6 = _upsample_like(d6, d1) - - d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1)) - - return ( - torch.sigmoid(d0), - torch.sigmoid(d1), - torch.sigmoid(d2), - torch.sigmoid(d3), - torch.sigmoid(d4), - torch.sigmoid(d5), - torch.sigmoid(d6), - ) diff --git a/requirements-cpu.txt b/requirements-cpu.txt new file mode 100644 index 0000000..7be989e --- /dev/null +++ b/requirements-cpu.txt @@ -0,0 +1 @@ +onnxruntime==1.10.0 diff --git a/requirements-gpu.txt b/requirements-gpu.txt new file mode 100644 index 0000000..3710249 --- /dev/null +++ b/requirements-gpu.txt @@ -0,0 +1 @@ +onnxruntime-gpu==1.10.0 diff --git a/requirements.txt b/requirements.txt index f2e89b8..281e61a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,10 @@ +filetype==1.0.7 flask==1.1.2 +gdown==4.2.0 numpy==1.20.0 pillow==8.3.2 +pymatting==1.1.5 scikit-image==0.19.1 -torch==1.9.1 -torchvision==0.10.1 -waitress==1.4.4 -tqdm==4.51.0 -requests==2.24.0 scipy==1.5.4 -pymatting==1.1.1 -filetype==1.0.7 -matplotlib==3.5.1 +tqdm==4.51.0 +waitress==1.4.4 diff --git a/setup.py b/setup.py index d4fb7a4..54dedf1 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,13 @@ long_description = (here / "README.md").read_text(encoding="utf-8") with open("requirements.txt") as f: requireds = f.read().splitlines() +if os.getenv("GPU") is None: + with open("requirements-cpu.txt") as f: + requireds += f.read().splitlines() +else: + with open("requirements-gpu.txt") as f: + requireds += f.read().splitlines() + setup( name="rembg", description="Remove image background",