This commit is contained in:
Daniel Gatis 2022-01-19 11:48:05 -03:00
parent 1ab362eb38
commit 3b18bad858
12 changed files with 113 additions and 1070 deletions

View File

@ -2,16 +2,16 @@ FROM nvidia/cuda:11.4.2-cudnn8-runtime-ubuntu20.04
RUN apt-get update &&\ RUN apt-get update &&\
apt-get install -y --no-install-recommends \ apt-get install -y --no-install-recommends \
python3 \ python3 \
python3-pip \ python3-pip \
python3-dev \ python3-dev \
build-essential build-essential
WORKDIR /rembg WORKDIR /rembg
COPY . . COPY . .
RUN pip3 install . RUN GPU=1 pip3 install .
# First run to compile AOT & download model # First run to compile AOT & download model
RUN rembg pixel.png >/dev/null RUN rembg pixel.png >/dev/null

View File

@ -36,38 +36,19 @@ Rembg is a tool to remove images background. That is it.
#### *** If you want to remove background from videos try this fork: https://github.com/ecsplendid/rembg-greenscreen *** #### *** If you want to remove background from videos try this fork: https://github.com/ecsplendid/rembg-greenscreen ***
### Requirements
* python 3.8 or newer
* torch and torchvision stable version (https://pytorch.org)
#### How to install torch/torchvision
Go to https://pytorch.org, scroll down to the `INSTALL PYTORCH` section, and follow the instructions.
For example:
```
PyTorch Build: Stable (1.7.1)
Your OS: Windows
Package: Pip
Language: Python
CUDA: None
```
The install cmd is:
```
pip install torch==1.7.1+cpu torchvision==0.8.2+cpu -f https://download.pytorch.org/whl/torch_stable.html
```
### Installation ### Installation
Install it from pypi CPU support:
```bash ```bash
pip install rembg pip install rembg
``` ```
GPU support:
```bash
GPU=1 pip install rembg
```
### Usage as a cli ### Usage as a cli
Remove the background from a remote image Remove the background from a remote image
@ -85,14 +66,6 @@ Remove the background from all images in a folder
rembg -p path/to/input path/to/output rembg -p path/to/input path/to/output
``` ```
### Add a custom model
Copy the `custom-model.pth` file to `~/.u2net` and run:
```bash
curl -s http://input.png | rembg -m custom-model > output.png
```
### Usage as a server ### Usage as a server
Start the server Start the server

View File

@ -1,4 +1,3 @@
import functools
import io import io
import numpy as np import numpy as np
@ -68,7 +67,6 @@ def naive_cutout(img, mask):
return cutout return cutout
@functools.lru_cache(maxsize=None)
def get_model(model_name): def get_model(model_name):
if model_name == "u2netp": if model_name == "u2netp":
return load_model(model_name="u2netp") return load_model(model_name="u2netp")

View File

@ -10,17 +10,6 @@ from .bg import remove
def main(): def main():
model_path = os.environ.get(
"U2NETP_PATH",
os.path.expanduser(os.path.join("~", ".u2net")),
)
model_choices = [
os.path.splitext(os.path.basename(x))[0]
for x in set(glob.glob(model_path + "/*"))
]
model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"]))
ap = argparse.ArgumentParser() ap = argparse.ArgumentParser()
ap.add_argument( ap.add_argument(
@ -28,7 +17,7 @@ def main():
"--model", "--model",
default="u2net", default="u2net",
type=str, type=str,
choices=model_choices, choices=["u2net", "u2netp", "u2net_human_seg"],
help="The model name.", help="The model name.",
) )

View File

@ -1,326 +0,0 @@
# data loader
import random
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from skimage import color, io, transform
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
# ==========================dataset load==========================
class RescaleT:
    """Resize a sample's image and label to a fixed target size.

    Args:
        output_size: Either an int ``s`` (both arrays are resized to
            ``s x s``, ignoring the aspect ratio) or a ``(height, width)``
            tuple.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with resized ``image`` and ``label``.

        ``sample`` must provide the keys ``"imidx"``, ``"image"`` and
        ``"label"``; ``"imidx"`` is passed through untouched.
        """
        imidx, image, label = sample["imidx"], sample["image"], sample["label"]

        # The previous code always passed (output_size, output_size) to
        # resize(), which produced a nested tuple (and therefore an invalid
        # shape) whenever output_size was itself a (height, width) tuple.
        if isinstance(self.output_size, int):
            target = (self.output_size, self.output_size)
        else:
            target = tuple(int(v) for v in self.output_size)

        img = transform.resize(image, target, mode="constant")
        # order=0 with preserve_range=True keeps the label values as they
        # are instead of interpolating/rescaling them.
        lbl = transform.resize(
            label,
            target,
            mode="constant",
            order=0,
            preserve_range=True,
        )

        return {"imidx": imidx, "image": img, "label": lbl}
class Rescale:
    """Randomly flip a sample, then resize its image and label.

    With an int ``output_size`` the shorter side becomes ``output_size`` and
    the longer side is scaled proportionally; with a tuple the arrays are
    resized to exactly that ``(height, width)``.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def _target_shape(self, height, width):
        """Compute the integer ``(height, width)`` to resize to."""
        size = self.output_size
        if not isinstance(size, int):
            rows, cols = size
        elif height > width:
            rows, cols = size * height / width, size
        else:
            rows, cols = size, size * width / height
        return int(rows), int(cols)

    def __call__(self, sample):
        imidx = sample["imidx"]
        image = sample["image"]
        label = sample["label"]

        # Half of the time, reverse both arrays along their first axis.
        flip = random.random() >= 0.5
        if flip:
            image, label = image[::-1], label[::-1]

        shape = self._target_shape(*image.shape[:2])

        resized_image = transform.resize(image, shape, mode="constant")
        resized_label = transform.resize(
            label, shape, mode="constant", order=0, preserve_range=True
        )
        return {"imidx": imidx, "image": resized_image, "label": resized_label}
class RandomCrop:
    """Randomly flip a sample and cut the same random window from image and label.

    Args:
        output_size: Crop size, either an int (square crop) or a
            ``(height, width)`` tuple.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict holding the cropped ``image`` and ``label``.

        Raises:
            ValueError: If the crop is larger than the image.
        """
        imidx, image, label = sample["imidx"], sample["image"], sample["label"]

        # Half of the time, reverse both arrays along their first axis.
        if random.random() >= 0.5:
            image = image[::-1]
            label = label[::-1]

        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                f"crop size {(new_h, new_w)} exceeds image size {(h, w)}"
            )

        # randint's upper bound is exclusive. Without the +1 the last valid
        # offset was unreachable and a crop equal to the image size raised
        # ValueError (low >= high).
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top : top + new_h, left : left + new_w]
        label = label[top : top + new_h, left : left + new_w]

        return {"imidx": imidx, "image": image, "label": label}
class ToTensor:
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        """Normalize the image, scale the label and return channel-first tensors.

        ``sample`` must provide ``"imidx"``, ``"image"`` (H x W x C with C of
        1 or 3) and ``"label"`` (H x W x C). The image is divided by its
        maximum and then shifted/scaled per channel with mean
        (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225); a single-channel
        image is replicated into three channels using the first mean/std.
        """
        imidx, image, label = sample["imidx"], sample["image"], sample["label"]

        # An all-zero image used to be divided by zero here, filling the
        # tensor with NaN; leave such an image unscaled instead.
        peak = np.max(image)
        if peak != 0:
            image = image / peak

        # Labels that are (numerically) empty are left untouched.
        label_peak = np.max(label)
        if label_peak >= 1e-6:
            label = label / label_peak

        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
        if image.shape[2] == 1:
            gray = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 0] = gray
            tmpImg[:, :, 1] = gray
            tmpImg[:, :, 2] = gray
        else:
            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
            tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225

        # Reorder axes from H x W x C to C x H x W.
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {
            "imidx": torch.from_numpy(imidx),
            "image": torch.from_numpy(tmpImg),
            "label": torch.from_numpy(tmpLbl),
        }
class ToTensorLab:
    """Convert ndarrays in sample to Tensors.

    Args:
        flag: Selects the color representation of the output image:
            ``2`` gives six channels (RGB followed by Lab), ``1`` gives three
            Lab channels, any other value gives three normalized RGB channels.
    """

    def __init__(self, flag=0):
        self.flag = flag

    @staticmethod
    def _as_three_channels(image):
        """Return ``image`` as H x W x 3, replicating a single channel if needed."""
        if image.shape[2] != 1:
            return image
        stacked = np.zeros((image.shape[0], image.shape[1], 3))
        for c in range(3):
            stacked[:, :, c] = image[:, :, 0]
        return stacked

    @staticmethod
    def _min_max(channel):
        """Rescale one channel with (x - min) / (max - min)."""
        # NOTE(review): a constant channel still divides by zero here,
        # exactly as the original code did.
        low = np.min(channel)
        return (channel - low) / (np.max(channel) - low)

    @staticmethod
    def _standardize(channel):
        """Rescale one channel with (x - mean) / std."""
        return (channel - np.mean(channel)) / np.std(channel)

    def __call__(self, sample):
        """Return ``sample`` with channel-first tensors for image and label."""
        imidx, image, label = sample["imidx"], sample["image"], sample["label"]

        # Labels that are (numerically) empty are left untouched.
        label_peak = np.max(label)
        if label_peak >= 1e-6:
            label = label / label_peak

        # change the color space
        if self.flag == 2:  # with rgb and Lab colors
            rgb = self._as_three_channels(image)
            lab = color.rgb2lab(rgb)
            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
            # Channels 0-2 hold min-max scaled RGB, channels 3-5 hold Lab.
            for c in range(3):
                tmpImg[:, :, c] = self._min_max(rgb[:, :, c])
                tmpImg[:, :, c + 3] = self._min_max(lab[:, :, c])
            for c in range(6):
                tmpImg[:, :, c] = self._standardize(tmpImg[:, :, c])

        elif self.flag == 1:  # with Lab color
            tmpImg = color.rgb2lab(self._as_three_channels(image))
            for c in range(3):
                tmpImg[:, :, c] = self._min_max(tmpImg[:, :, c])
            for c in range(3):
                tmpImg[:, :, c] = self._standardize(tmpImg[:, :, c])

        else:  # with rgb color
            # An all-zero image used to be divided by zero here, filling the
            # tensor with NaN; leave such an image unscaled instead.
            peak = np.max(image)
            if peak != 0:
                image = image / peak

            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            if image.shape[2] == 1:
                gray = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 0] = gray
                tmpImg[:, :, 1] = gray
                tmpImg[:, :, 2] = gray
            else:
                # mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)
                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
                tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225

        # Reorder axes from H x W x C to C x H x W.
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {
            "imidx": torch.from_numpy(imidx),
            "image": torch.from_numpy(tmpImg),
            "label": torch.from_numpy(tmpLbl),
        }
class SalObjDataset(Dataset):
    """Dataset that reads image/label file pairs into sample dicts.

    Each item is ``{"imidx", "image", "label"}``, optionally passed through
    ``transform``. When ``lbl_name_list`` is empty an all-zero label is used.
    """

    def __init__(self, img_name_list, lbl_name_list, transform=None):
        self.image_name_list = img_name_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self, idx):
        image = io.imread(self.image_name_list[idx])
        imidx = np.array([idx])

        # Without label files, fall back to zeros shaped like the image.
        if len(self.label_name_list) == 0:
            raw_label = np.zeros(image.shape)
        else:
            raw_label = io.imread(self.label_name_list[idx])

        # Reduce the label to two dimensions (first channel of a 3-D label).
        if raw_label.ndim == 3:
            label = raw_label[:, :, 0]
        elif raw_label.ndim == 2:
            label = raw_label
        else:
            label = np.zeros(raw_label.shape[0:2])

        # Give both arrays an explicit trailing channel axis.
        if label.ndim == 2:
            if image.ndim == 3:
                label = label[:, :, np.newaxis]
            elif image.ndim == 2:
                image = image[:, :, np.newaxis]
                label = label[:, :, np.newaxis]

        sample = {"imidx": imidx, "image": image, "label": label}
        if self.transform:
            sample = self.transform(sample)
        return sample

View File

@ -1,139 +1,103 @@
import errno
import os import os
import sys import sys
import urllib.request
from hashlib import md5
import gdown
import numpy as np import numpy as np
import requests import onnxruntime as ort
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image from PIL import Image
from skimage import transform from skimage import transform
from torchvision import transforms
from tqdm import tqdm
from .data_loader import RescaleT, ToTensorLab SESSIONS = {}
from .u2net import U2NET, U2NETP
def download_file_from_google_drive(id, fname, destination):
head, tail = os.path.split(destination)
os.makedirs(head, exist_ok=True)
URL = "https://docs.google.com/uc?export=download"
session = requests.Session()
response = session.get(URL, params={"id": id}, stream=True)
token = None
for key, value in response.cookies.items():
if key.startswith("download_warning"):
token = value
break
if token:
params = {"id": id, "confirm": token}
response = session.get(URL, params=params, stream=True)
total = int(response.headers.get("content-length", 0))
with open(destination, "wb") as file, tqdm(
desc=f"Downloading {tail} to {head}",
total=total,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar:
for data in response.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
def load_model(model_name: str = "u2net"): def load_model(model_name: str = "u2net"):
hashfile = lambda f: md5(open(f, "rb").read()).hexdigest() path = os.environ.get(
"U2NETP_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".onnx")),
)
if model_name == "u2netp": if model_name == "u2netp":
net = U2NETP(3, 1) md5 = "8e83ca70e441ab06c318d82300c84806"
path = os.environ.get( url = "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR"
"U2NETP_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
)
if (
not os.path.exists(path)
or hashfile(path) != "e4f636406ca4e2af789941e7f139ee2e"
):
download_file_from_google_drive(
"1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
"u2netp.pth",
path,
)
elif model_name == "u2net": elif model_name == "u2net":
net = U2NET(3, 1) md5 = "60024c5c889badc19c04ad937298a77b"
path = os.environ.get( url = "https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab"
"U2NET_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
)
if (
not os.path.exists(path)
or hashfile(path) != "347c3d51b01528e5c6c071e3cff1cb55"
):
download_file_from_google_drive(
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
"pth",
path,
)
elif model_name == "u2net_human_seg": elif model_name == "u2net_human_seg":
net = U2NET(3, 1) md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
path = os.environ.get( url = "https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j"
"U2NET_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
)
if (
not os.path.exists(path)
or hashfile(path) != "09fb4e49b7f785c9f855baf94916840a"
):
download_file_from_google_drive(
"1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
"u2net_human_seg.pth",
path,
)
else: else:
print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr) print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr)
try: if SESSIONS.get(md5) is None:
if torch.cuda.is_available(): gdown.cached_download(url, path, md5=md5, quiet=False)
net.load_state_dict(torch.load(path)) SESSIONS[md5] = ort.InferenceSession(path)
net.to(torch.device("cuda"))
else:
net.load_state_dict(
torch.load(
path,
map_location="cpu",
)
)
except FileNotFoundError:
raise FileNotFoundError(
errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"
)
net.eval() return SESSIONS[md5]
return net
def norm_pred(d): def norm_pred(d):
ma = torch.max(d) ma = np.max(d)
mi = torch.min(d) mi = np.min(d)
dn = (d - mi) / (ma - mi) dn = (d - mi) / (ma - mi)
return dn return dn
def rescale(sample, output_size):
    """Resize a sample's image and label to a fixed target size.

    Args:
        sample: Dict with the keys ``"imidx"``, ``"image"`` and ``"label"``.
        output_size: Either an int ``s`` (resize to ``s x s``, ignoring the
            aspect ratio) or a ``(height, width)`` tuple.

    Returns:
        A new dict with the resized ``image`` and ``label``; ``imidx`` is
        passed through untouched.
    """
    imidx, image, label = sample["imidx"], sample["image"], sample["label"]

    # The previous code always passed (output_size, output_size) to resize(),
    # which produced a nested tuple (an invalid shape) whenever output_size
    # was itself a (height, width) tuple.
    if isinstance(output_size, int):
        target = (output_size, output_size)
    else:
        target = tuple(int(v) for v in output_size)

    img = transform.resize(image, target, mode="constant")
    # order=0 with preserve_range=True keeps the label values as they are
    # instead of interpolating/rescaling them.
    lbl = transform.resize(
        label,
        target,
        mode="constant",
        order=0,
        preserve_range=True,
    )

    return {"imidx": imidx, "image": img, "label": lbl}
def color(sample):
    """Normalize a sample's image and label and move channels first.

    The image is divided by its maximum and then shifted/scaled per channel
    with mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225); a
    single-channel image is replicated into three channels using the first
    mean/std. The label is divided by its maximum unless that is below 1e-6.

    Args:
        sample: Dict with ``"imidx"``, ``"image"`` (H x W x C with C of 1 or
            3) and ``"label"`` (H x W x C).

    Returns:
        Dict with ``imidx`` unchanged and ``image``/``label`` as C x H x W
        arrays.
    """
    imidx, image, label = sample["imidx"], sample["image"], sample["label"]

    # Labels that are (numerically) empty are left untouched.
    label_peak = np.max(label)
    if label_peak >= 1e-6:
        label = label / label_peak

    # An all-zero image used to be divided by zero here, filling the result
    # with NaN; leave such an image unscaled instead.
    peak = np.max(image)
    if peak != 0:
        image = image / peak

    tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
    if image.shape[2] == 1:
        gray = (image[:, :, 0] - 0.485) / 0.229
        tmpImg[:, :, 0] = gray
        tmpImg[:, :, 1] = gray
        tmpImg[:, :, 2] = gray
    else:
        tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
        tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
        tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225

    # Reorder axes from H x W x C to C x H x W.
    tmpImg = tmpImg.transpose((2, 0, 1))
    tmpLbl = label.transpose((2, 0, 1))

    return {"imidx": imidx, "image": tmpImg, "label": tmpLbl}
def preprocess(image): def preprocess(image):
label_3 = np.zeros(image.shape) label_3 = np.zeros(image.shape)
label = np.zeros(label_3.shape[0:2]) label = np.zeros(label_3.shape[0:2])
@ -149,34 +113,23 @@ def preprocess(image):
image = image[:, :, np.newaxis] image = image[:, :, np.newaxis]
label = label[:, :, np.newaxis] label = label[:, :, np.newaxis]
transform = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)]) sample = {"imidx": np.array([0]), "image": image, "label": label}
sample = transform({"imidx": np.array([0]), "image": image, "label": label}) sample = rescale(sample, 320)
sample = color(sample)
return sample return sample
def predict(net, item): def predict(ort_session, item):
sample = preprocess(item) sample = preprocess(item)
inputs_test = np.expand_dims(sample["image"], 0).astype(np.float32)
with torch.no_grad(): ort_inputs = {ort_session.get_inputs()[0].name: inputs_test}
ort_outs = ort_session.run(None, ort_inputs)
if torch.cuda.is_available(): d1 = ort_outs[0]
inputs_test = torch.cuda.FloatTensor( pred = d1[:, 0, :, :]
sample["image"].unsqueeze(0).cuda().float() predict = np.squeeze(norm_pred(pred))
) img = Image.fromarray(predict * 255).convert("RGB")
else:
inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())
d1, d2, d3, d4, d5, d6, d7 = net(inputs_test) return img
pred = d1[:, 0, :, :]
predict = norm_pred(pred)
predict = predict.squeeze()
predict_np = predict.cpu().detach().numpy()
img = Image.fromarray(predict_np * 255).convert("RGB")
del d1, d2, d3, d4, d5, d6, d7, pred, predict, predict_np, inputs_test, sample
return img

View File

@ -46,16 +46,7 @@ def index():
height = request.args.get("height", type=int) height = request.args.get("height", type=int)
model = request.values.get("model", type=str, default="u2net") model = request.values.get("model", type=str, default="u2net")
model_path = os.environ.get( model_choices = ["u2net", "u2netp", "u2net_human_seg"]
"U2NETP_PATH",
os.path.expanduser(os.path.join("~", ".u2net")),
)
model_choices = [
os.path.splitext(os.path.basename(x))[0]
for x in set(glob.glob(model_path + "/*"))
]
model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"]))
if model not in model_choices: if model not in model_choices:
return { return {

View File

@ -1,541 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class REBNCONV(nn.Module):
    """3x3 convolution followed by batch normalization and an in-place ReLU.

    ``dirate`` is used as both dilation and padding, so with the default
    stride the spatial size of the input is preserved.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super().__init__()

        self.conv_s1 = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=3,
            padding=dirate,
            dilation=dirate,
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv_s1(x)
        features = self.bn_s1(features)
        return self.relu_s1(features)
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
return src
### RSU-7 ###
class RSU7(nn.Module):  # UNet07DRES(nn.Module):
    """Residual U-block with seven encoder convolutions and five poolings.

    The block's output is the decoder result added to the input projection
    produced by ``rebnconvin``.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()
        skip_ch = mid_ch * 2  # decoder input: upsampled features + skip

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated bottleneck
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder
        self.rebnconv6d = REBNCONV(skip_ch, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(skip_ch, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(skip_ch, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(skip_ch, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(skip_ch, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(skip_ch, out_ch, dirate=1)

    def forward(self, x):
        residual = self.rebnconvin(x)

        # encoder: convolve, then pool before the next level
        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        e4 = self.rebnconv4(self.pool3(e3))
        e5 = self.rebnconv5(self.pool4(e4))
        e6 = self.rebnconv6(self.pool5(e5))
        e7 = self.rebnconv7(e6)

        # decoder: concatenate with the matching encoder level, upsample
        d6 = self.rebnconv6d(torch.cat((e7, e6), 1))
        d5 = self.rebnconv5d(torch.cat((_upsample_like(d6, e5), e5), 1))
        d4 = self.rebnconv4d(torch.cat((_upsample_like(d5, e4), e4), 1))
        d3 = self.rebnconv3d(torch.cat((_upsample_like(d4, e3), e3), 1))
        d2 = self.rebnconv2d(torch.cat((_upsample_like(d3, e2), e2), 1))
        d1 = self.rebnconv1d(torch.cat((_upsample_like(d2, e1), e1), 1))

        return d1 + residual
### RSU-6 ###
class RSU6(nn.Module):  # UNet06DRES(nn.Module):
    """Residual U-block with six encoder convolutions and four poolings.

    The block's output is the decoder result added to the input projection
    produced by ``rebnconvin``.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()
        skip_ch = mid_ch * 2  # decoder input: upsampled features + skip

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated bottleneck
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder
        self.rebnconv5d = REBNCONV(skip_ch, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(skip_ch, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(skip_ch, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(skip_ch, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(skip_ch, out_ch, dirate=1)

    def forward(self, x):
        residual = self.rebnconvin(x)

        # encoder: convolve, then pool before the next level
        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        e4 = self.rebnconv4(self.pool3(e3))
        e5 = self.rebnconv5(self.pool4(e4))
        e6 = self.rebnconv6(e5)

        # decoder: concatenate with the matching encoder level, upsample
        d5 = self.rebnconv5d(torch.cat((e6, e5), 1))
        d4 = self.rebnconv4d(torch.cat((_upsample_like(d5, e4), e4), 1))
        d3 = self.rebnconv3d(torch.cat((_upsample_like(d4, e3), e3), 1))
        d2 = self.rebnconv2d(torch.cat((_upsample_like(d3, e2), e2), 1))
        d1 = self.rebnconv1d(torch.cat((_upsample_like(d2, e1), e1), 1))

        return d1 + residual
### RSU-5 ###
class RSU5(nn.Module):  # UNet05DRES(nn.Module):
    """Residual U-block with five encoder convolutions and three poolings.

    The block's output is the decoder result added to the input projection
    produced by ``rebnconvin``.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()
        skip_ch = mid_ch * 2  # decoder input: upsampled features + skip

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated bottleneck
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder
        self.rebnconv4d = REBNCONV(skip_ch, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(skip_ch, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(skip_ch, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(skip_ch, out_ch, dirate=1)

    def forward(self, x):
        residual = self.rebnconvin(x)

        # encoder: convolve, then pool before the next level
        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        e4 = self.rebnconv4(self.pool3(e3))
        e5 = self.rebnconv5(e4)

        # decoder: concatenate with the matching encoder level, upsample
        d4 = self.rebnconv4d(torch.cat((e5, e4), 1))
        d3 = self.rebnconv3d(torch.cat((_upsample_like(d4, e3), e3), 1))
        d2 = self.rebnconv2d(torch.cat((_upsample_like(d3, e2), e2), 1))
        d1 = self.rebnconv1d(torch.cat((_upsample_like(d2, e1), e1), 1))

        return d1 + residual
### RSU-4 ###
class RSU4(nn.Module):  # UNet04DRES(nn.Module):
    """Residual U-block with four encoder convolutions and two poolings.

    The block's output is the decoder result added to the input projection
    produced by ``rebnconvin``.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()
        skip_ch = mid_ch * 2  # decoder input: upsampled features + skip

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # encoder
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # dilated bottleneck
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # decoder
        self.rebnconv3d = REBNCONV(skip_ch, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(skip_ch, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(skip_ch, out_ch, dirate=1)

    def forward(self, x):
        residual = self.rebnconvin(x)

        # encoder: convolve, then pool before the next level
        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        e4 = self.rebnconv4(e3)

        # decoder: concatenate with the matching encoder level, upsample
        d3 = self.rebnconv3d(torch.cat((e4, e3), 1))
        d2 = self.rebnconv2d(torch.cat((_upsample_like(d3, e2), e2), 1))
        d1 = self.rebnconv1d(torch.cat((_upsample_like(d2, e1), e1), 1))

        return d1 + residual
### RSU-4F ###
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
    """Residual block built from dilated convolutions, without pooling.

    Dilation grows 1, 2, 4, 8 on the way in and shrinks 4, 2, 1 on the way
    out; no pooling or upsampling is applied. The output is the last
    convolution's result added to the input projection from ``rebnconvin``.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()
        skip_ch = mid_ch * 2  # decoder input: previous features + skip

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # increasing dilation
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # decreasing dilation
        self.rebnconv3d = REBNCONV(skip_ch, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(skip_ch, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(skip_ch, out_ch, dirate=1)

    def forward(self, x):
        residual = self.rebnconvin(x)

        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(e1)
        e3 = self.rebnconv3(e2)
        e4 = self.rebnconv4(e3)

        # Each decoder step concatenates with the matching encoder output.
        d3 = self.rebnconv3d(torch.cat((e4, e3), 1))
        d2 = self.rebnconv2d(torch.cat((d3, e2), 1))
        d1 = self.rebnconv1d(torch.cat((d2, e1), 1))

        return d1 + residual
##### U^2-Net ####
class U2NET(nn.Module):
    """U^2-Net: six encoder stages, five decoder stages and six side outputs.

    ``forward`` returns a 7-tuple of sigmoid maps: the fused map first,
    followed by the six side outputs (all upsampled to the first one's size).
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        # encoder
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # one 3x3 side head per decoder level (plus the deepest encoder)
        self.side1 = nn.Conv2d(64, out_ch, kernel_size=3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, kernel_size=3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, kernel_size=3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, kernel_size=3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, kernel_size=3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, kernel_size=3, padding=1)

        # 1x1 convolution fusing the six concatenated side maps
        self.outconv = nn.Conv2d(6, out_ch, kernel_size=1)

    def forward(self, x):
        # -------------------- encoder --------------------
        enc1 = self.stage1(x)
        enc2 = self.stage2(self.pool12(enc1))
        enc3 = self.stage3(self.pool23(enc2))
        enc4 = self.stage4(self.pool34(enc3))
        enc5 = self.stage5(self.pool45(enc4))
        enc6 = self.stage6(self.pool56(enc5))

        # -------------------- decoder --------------------
        dec5 = self.stage5d(torch.cat((_upsample_like(enc6, enc5), enc5), 1))
        dec4 = self.stage4d(torch.cat((_upsample_like(dec5, enc4), enc4), 1))
        dec3 = self.stage3d(torch.cat((_upsample_like(dec4, enc3), enc3), 1))
        dec2 = self.stage2d(torch.cat((_upsample_like(dec3, enc2), enc2), 1))
        dec1 = self.stage1d(torch.cat((_upsample_like(dec2, enc1), enc1), 1))

        # side outputs, all brought to the size of the first one
        side1 = self.side1(dec1)
        side2 = _upsample_like(self.side2(dec2), side1)
        side3 = _upsample_like(self.side3(dec3), side1)
        side4 = _upsample_like(self.side4(dec4), side1)
        side5 = _upsample_like(self.side5(dec5), side1)
        side6 = _upsample_like(self.side6(enc6), side1)

        side_maps = (side1, side2, side3, side4, side5, side6)
        fused = self.outconv(torch.cat(side_maps, 1))

        return tuple(torch.sigmoid(m) for m in (fused,) + side_maps)
### U^2-Net small ###
class U2NETP(nn.Module):
    """Small U^2-Net variant: same layout, 16 mid and 64 output channels per stage.

    ``forward`` returns a 7-tuple of sigmoid maps: the fused map first,
    followed by the six side outputs (all upsampled to the first one's size).
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        # encoder
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        # one 3x3 side head per decoder level (plus the deepest encoder)
        self.side1 = nn.Conv2d(64, out_ch, kernel_size=3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, kernel_size=3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, kernel_size=3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, kernel_size=3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, kernel_size=3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, kernel_size=3, padding=1)

        # 1x1 convolution fusing the six concatenated side maps
        self.outconv = nn.Conv2d(6, out_ch, kernel_size=1)

    def forward(self, x):
        # -------------------- encoder --------------------
        enc1 = self.stage1(x)
        enc2 = self.stage2(self.pool12(enc1))
        enc3 = self.stage3(self.pool23(enc2))
        enc4 = self.stage4(self.pool34(enc3))
        enc5 = self.stage5(self.pool45(enc4))
        enc6 = self.stage6(self.pool56(enc5))

        # -------------------- decoder --------------------
        dec5 = self.stage5d(torch.cat((_upsample_like(enc6, enc5), enc5), 1))
        dec4 = self.stage4d(torch.cat((_upsample_like(dec5, enc4), enc4), 1))
        dec3 = self.stage3d(torch.cat((_upsample_like(dec4, enc3), enc3), 1))
        dec2 = self.stage2d(torch.cat((_upsample_like(dec3, enc2), enc2), 1))
        dec1 = self.stage1d(torch.cat((_upsample_like(dec2, enc1), enc1), 1))

        # side outputs, all brought to the size of the first one
        side1 = self.side1(dec1)
        side2 = _upsample_like(self.side2(dec2), side1)
        side3 = _upsample_like(self.side3(dec3), side1)
        side4 = _upsample_like(self.side4(dec4), side1)
        side5 = _upsample_like(self.side5(dec5), side1)
        side6 = _upsample_like(self.side6(enc6), side1)

        side_maps = (side1, side2, side3, side4, side5, side6)
        fused = self.outconv(torch.cat(side_maps, 1))

        return tuple(torch.sigmoid(m) for m in (fused,) + side_maps)

1
requirements-cpu.txt Normal file
View File

@ -0,0 +1 @@
onnxruntime==1.10.0

1
requirements-gpu.txt Normal file
View File

@ -0,0 +1 @@
onnxruntime-gpu==1.10.0

View File

@ -1,13 +1,10 @@
filetype==1.0.7
flask==1.1.2 flask==1.1.2
gdown==4.2.0
numpy==1.20.0 numpy==1.20.0
pillow==8.3.2 pillow==8.3.2
pymatting==1.1.5
scikit-image==0.19.1 scikit-image==0.19.1
torch==1.9.1
torchvision==0.10.1
waitress==1.4.4
tqdm==4.51.0
requests==2.24.0
scipy==1.5.4 scipy==1.5.4
pymatting==1.1.1 tqdm==4.51.0
filetype==1.0.7 waitress==1.4.4
matplotlib==3.5.1

View File

@ -14,6 +14,13 @@ long_description = (here / "README.md").read_text(encoding="utf-8")
with open("requirements.txt") as f: with open("requirements.txt") as f:
requireds = f.read().splitlines() requireds = f.read().splitlines()
# Append the runtime requirements: the CPU list unless the GPU environment
# variable is set (to any value), in which case the GPU list is used.
with open(
    "requirements-cpu.txt" if os.getenv("GPU") is None else "requirements-gpu.txt"
) as f:
    requireds += f.read().splitlines()
setup( setup(
name="rembg", name="rembg",
description="Remove image background", description="Remove image background",