This commit is contained in:
Daniel Gatis 2022-01-19 11:48:05 -03:00
parent 1ab362eb38
commit 3b18bad858
12 changed files with 113 additions and 1070 deletions

View File

@ -2,16 +2,16 @@ FROM nvidia/cuda:11.4.2-cudnn8-runtime-ubuntu20.04
RUN apt-get update &&\
apt-get install -y --no-install-recommends \
python3 \
python3-pip \
python3-dev \
build-essential
python3 \
python3-pip \
python3-dev \
build-essential
WORKDIR /rembg
COPY . .
RUN pip3 install .
RUN GPU=1 pip3 install .
# First run to compile AOT & download model
RUN rembg pixel.png >/dev/null

View File

@ -36,38 +36,19 @@ Rembg is a tool to remove images background. That is it.
#### *** If you want to remove background from videos try this this fork: https://github.com/ecsplendid/rembg-greenscreen ***
### Requirements
* python 3.8 or newer
* torch and torchvision stable version (https://pytorch.org)
#### How to install torch/torchvision
Go to https://pytorch.org and scrool down to `INSTALL PYTORCH` section and follow the instructions.
For example:
```
PyTorch Build: Stable (1.7.1)
Your OS: Windows
Package: Pip
Language: Python
CUDA: None
```
The install cmd is:
```
pip install torch==1.7.1+cpu torchvision==0.8.2+cpu -f https://download.pytorch.org/whl/torch_stable.html
```
### Installation
Install it from pypi
CPU support:
```bash
pip install rembg
```
GPU support:
```bash
GPU=1 pip install rembg
```
### Usage as a cli
Remove the background from a remote image
@ -85,14 +66,6 @@ Remove the background from all images in a folder
rembg -p path/to/input path/to/output
```
### Add a custom model
Copy the `custom-model.pth` file to `~/.u2net` and run:
```bash
curl -s http://input.png | rembg -m custom-model > output.png
```
### Usage as a server
Start the server

View File

@ -1,4 +1,3 @@
import functools
import io
import numpy as np
@ -68,7 +67,6 @@ def naive_cutout(img, mask):
return cutout
@functools.lru_cache(maxsize=None)
def get_model(model_name):
if model_name == "u2netp":
return load_model(model_name="u2netp")

View File

@ -10,17 +10,6 @@ from .bg import remove
def main():
model_path = os.environ.get(
"U2NETP_PATH",
os.path.expanduser(os.path.join("~", ".u2net")),
)
model_choices = [
os.path.splitext(os.path.basename(x))[0]
for x in set(glob.glob(model_path + "/*"))
]
model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"]))
ap = argparse.ArgumentParser()
ap.add_argument(
@ -28,7 +17,7 @@ def main():
"--model",
default="u2net",
type=str,
choices=model_choices,
choices=["u2net", "u2netp", "u2net_human_seg"],
help="The model name.",
)

View File

@ -1,326 +0,0 @@
# data loader
import random
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from skimage import color, io, transform
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils
# ==========================dataset load==========================
class RescaleT:
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
self.output_size = output_size
def __call__(self, sample):
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
# #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
# img = transform.resize(image,(new_h,new_w),mode='constant')
# lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
img = transform.resize(
image, (self.output_size, self.output_size), mode="constant"
)
lbl = transform.resize(
label,
(self.output_size, self.output_size),
mode="constant",
order=0,
preserve_range=True,
)
return {"imidx": imidx, "image": img, "label": lbl}
class Rescale:
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
self.output_size = output_size
def __call__(self, sample):
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
if random.random() >= 0.5:
image = image[::-1]
label = label[::-1]
h, w = image.shape[:2]
if isinstance(self.output_size, int):
if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
# #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
img = transform.resize(image, (new_h, new_w), mode="constant")
lbl = transform.resize(
label, (new_h, new_w), mode="constant", order=0, preserve_range=True
)
return {"imidx": imidx, "image": img, "label": lbl}
class RandomCrop:
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size
def __call__(self, sample):
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
if random.random() >= 0.5:
image = image[::-1]
label = label[::-1]
h, w = image.shape[:2]
new_h, new_w = self.output_size
top = np.random.randint(0, h - new_h)
left = np.random.randint(0, w - new_w)
image = image[top : top + new_h, left : left + new_w]
label = label[top : top + new_h, left : left + new_w]
return {"imidx": imidx, "image": image, "label": label}
class ToTensor:
"""Convert ndarrays in sample to Tensors."""
def __call__(self, sample):
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
tmpLbl = np.zeros(label.shape)
image = image / np.max(image)
if np.max(label) < 1e-6:
label = label
else:
label = label / np.max(label)
if image.shape[2] == 1:
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
else:
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
tmpLbl[:, :, 0] = label[:, :, 0]
# change the r,g,b to b,r,g from [0,255] to [0,1]
# transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
tmpImg = tmpImg.transpose((2, 0, 1))
tmpLbl = label.transpose((2, 0, 1))
return {
"imidx": torch.from_numpy(imidx),
"image": torch.from_numpy(tmpImg),
"label": torch.from_numpy(tmpLbl),
}
class ToTensorLab:
"""Convert ndarrays in sample to Tensors."""
def __init__(self, flag=0):
self.flag = flag
def __call__(self, sample):
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
tmpLbl = np.zeros(label.shape)
if np.max(label) < 1e-6:
label = label
else:
label = label / np.max(label)
# change the color space
if self.flag == 2: # with rgb and Lab colors
tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
tmpImgt = np.zeros((image.shape[0], image.shape[1], 3))
if image.shape[2] == 1:
tmpImgt[:, :, 0] = image[:, :, 0]
tmpImgt[:, :, 1] = image[:, :, 0]
tmpImgt[:, :, 2] = image[:, :, 0]
else:
tmpImgt = image
tmpImgtl = color.rgb2lab(tmpImgt)
# nomalize image to range [0,1]
tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / (
np.max(tmpImgt[:, :, 0]) - np.min(tmpImgt[:, :, 0])
)
tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / (
np.max(tmpImgt[:, :, 1]) - np.min(tmpImgt[:, :, 1])
)
tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / (
np.max(tmpImgt[:, :, 2]) - np.min(tmpImgt[:, :, 2])
)
tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / (
np.max(tmpImgtl[:, :, 0]) - np.min(tmpImgtl[:, :, 0])
)
tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / (
np.max(tmpImgtl[:, :, 1]) - np.min(tmpImgtl[:, :, 1])
)
tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / (
np.max(tmpImgtl[:, :, 2]) - np.min(tmpImgtl[:, :, 2])
)
# tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(
tmpImg[:, :, 0]
)
tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(
tmpImg[:, :, 1]
)
tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(
tmpImg[:, :, 2]
)
tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std(
tmpImg[:, :, 3]
)
tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std(
tmpImg[:, :, 4]
)
tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std(
tmpImg[:, :, 5]
)
elif self.flag == 1: # with Lab color
tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
if image.shape[2] == 1:
tmpImg[:, :, 0] = image[:, :, 0]
tmpImg[:, :, 1] = image[:, :, 0]
tmpImg[:, :, 2] = image[:, :, 0]
else:
tmpImg = image
tmpImg = color.rgb2lab(tmpImg)
# tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / (
np.max(tmpImg[:, :, 0]) - np.min(tmpImg[:, :, 0])
)
tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / (
np.max(tmpImg[:, :, 1]) - np.min(tmpImg[:, :, 1])
)
tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / (
np.max(tmpImg[:, :, 2]) - np.min(tmpImg[:, :, 2])
)
tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(
tmpImg[:, :, 0]
)
tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(
tmpImg[:, :, 1]
)
tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(
tmpImg[:, :, 2]
)
else: # with rgb color
tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
image = image / np.max(image)
if image.shape[2] == 1:
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
else:
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
tmpLbl[:, :, 0] = label[:, :, 0]
# change the r,g,b to b,r,g from [0,255] to [0,1]
# transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
tmpImg = tmpImg.transpose((2, 0, 1))
tmpLbl = label.transpose((2, 0, 1))
return {
"imidx": torch.from_numpy(imidx),
"image": torch.from_numpy(tmpImg),
"label": torch.from_numpy(tmpLbl),
}
class SalObjDataset(Dataset):
def __init__(self, img_name_list, lbl_name_list, transform=None):
# self.root_dir = root_dir
# self.image_name_list = glob.glob(image_dir+'*.png')
# self.label_name_list = glob.glob(label_dir+'*.png')
self.image_name_list = img_name_list
self.label_name_list = lbl_name_list
self.transform = transform
def __len__(self):
return len(self.image_name_list)
def __getitem__(self, idx):
# image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
# label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])
image = io.imread(self.image_name_list[idx])
imname = self.image_name_list[idx]
imidx = np.array([idx])
if 0 == len(self.label_name_list):
label_3 = np.zeros(image.shape)
else:
label_3 = io.imread(self.label_name_list[idx])
label = np.zeros(label_3.shape[0:2])
if 3 == len(label_3.shape):
label = label_3[:, :, 0]
elif 2 == len(label_3.shape):
label = label_3
if 3 == len(image.shape) and 2 == len(label.shape):
label = label[:, :, np.newaxis]
elif 2 == len(image.shape) and 2 == len(label.shape):
image = image[:, :, np.newaxis]
label = label[:, :, np.newaxis]
sample = {"imidx": imidx, "image": image, "label": label}
if self.transform:
sample = self.transform(sample)
return sample

View File

@ -1,139 +1,103 @@
import errno
import os
import sys
import urllib.request
from hashlib import md5
import gdown
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import onnxruntime as ort
from PIL import Image
from skimage import transform
from torchvision import transforms
from tqdm import tqdm
from .data_loader import RescaleT, ToTensorLab
from .u2net import U2NET, U2NETP
def download_file_from_google_drive(id, fname, destination):
head, tail = os.path.split(destination)
os.makedirs(head, exist_ok=True)
URL = "https://docs.google.com/uc?export=download"
session = requests.Session()
response = session.get(URL, params={"id": id}, stream=True)
token = None
for key, value in response.cookies.items():
if key.startswith("download_warning"):
token = value
break
if token:
params = {"id": id, "confirm": token}
response = session.get(URL, params=params, stream=True)
total = int(response.headers.get("content-length", 0))
with open(destination, "wb") as file, tqdm(
desc=f"Downloading {tail} to {head}",
total=total,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar:
for data in response.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
SESSIONS = {}
def load_model(model_name: str = "u2net"):
hashfile = lambda f: md5(open(f, "rb").read()).hexdigest()
path = os.environ.get(
"U2NETP_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".onnx")),
)
if model_name == "u2netp":
net = U2NETP(3, 1)
path = os.environ.get(
"U2NETP_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
)
if (
not os.path.exists(path)
or hashfile(path) != "e4f636406ca4e2af789941e7f139ee2e"
):
download_file_from_google_drive(
"1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
"u2netp.pth",
path,
)
md5 = "8e83ca70e441ab06c318d82300c84806"
url = "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR"
elif model_name == "u2net":
net = U2NET(3, 1)
path = os.environ.get(
"U2NET_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
)
if (
not os.path.exists(path)
or hashfile(path) != "347c3d51b01528e5c6c071e3cff1cb55"
):
download_file_from_google_drive(
"1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
"pth",
path,
)
md5 = "60024c5c889badc19c04ad937298a77b"
url = "https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab"
elif model_name == "u2net_human_seg":
net = U2NET(3, 1)
path = os.environ.get(
"U2NET_PATH",
os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
)
if (
not os.path.exists(path)
or hashfile(path) != "09fb4e49b7f785c9f855baf94916840a"
):
download_file_from_google_drive(
"1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
"u2net_human_seg.pth",
path,
)
md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
url = "https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j"
else:
print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr)
try:
if torch.cuda.is_available():
net.load_state_dict(torch.load(path))
net.to(torch.device("cuda"))
else:
net.load_state_dict(
torch.load(
path,
map_location="cpu",
)
)
except FileNotFoundError:
raise FileNotFoundError(
errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"
)
if SESSIONS.get(md5) is None:
gdown.cached_download(url, path, md5=md5, quiet=False)
SESSIONS[md5] = ort.InferenceSession(path)
net.eval()
return net
return SESSIONS[md5]
def norm_pred(d):
ma = torch.max(d)
mi = torch.min(d)
ma = np.max(d)
mi = np.min(d)
dn = (d - mi) / (ma - mi)
return dn
def rescale(sample, output_size):
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
h, w = image.shape[:2]
if isinstance(output_size, int):
if h > w:
new_h, new_w = output_size * h / w, output_size
else:
new_h, new_w = output_size, output_size * w / h
else:
new_h, new_w = output_size
new_h, new_w = int(new_h), int(new_w)
img = transform.resize(image, (output_size, output_size), mode="constant")
lbl = transform.resize(
label,
(output_size, output_size),
mode="constant",
order=0,
preserve_range=True,
)
return {"imidx": imidx, "image": img, "label": lbl}
def color(sample):
imidx, image, label = sample["imidx"], sample["image"], sample["label"]
tmpLbl = np.zeros(label.shape)
if np.max(label) < 1e-6:
label = label
else:
label = label / np.max(label)
tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
image = image / np.max(image)
if image.shape[2] == 1:
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
else:
tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
tmpLbl[:, :, 0] = label[:, :, 0]
tmpImg = tmpImg.transpose((2, 0, 1))
tmpLbl = label.transpose((2, 0, 1))
return {"imidx": imidx, "image": tmpImg, "label": tmpLbl}
def preprocess(image):
label_3 = np.zeros(image.shape)
label = np.zeros(label_3.shape[0:2])
@ -149,34 +113,23 @@ def preprocess(image):
image = image[:, :, np.newaxis]
label = label[:, :, np.newaxis]
transform = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
sample = transform({"imidx": np.array([0]), "image": image, "label": label})
sample = {"imidx": np.array([0]), "image": image, "label": label}
sample = rescale(sample, 320)
sample = color(sample)
return sample
def predict(net, item):
def predict(ort_session, item):
sample = preprocess(item)
inputs_test = np.expand_dims(sample["image"], 0).astype(np.float32)
with torch.no_grad():
ort_inputs = {ort_session.get_inputs()[0].name: inputs_test}
ort_outs = ort_session.run(None, ort_inputs)
if torch.cuda.is_available():
inputs_test = torch.cuda.FloatTensor(
sample["image"].unsqueeze(0).cuda().float()
)
else:
inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())
d1 = ort_outs[0]
pred = d1[:, 0, :, :]
predict = np.squeeze(norm_pred(pred))
img = Image.fromarray(predict * 255).convert("RGB")
d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
pred = d1[:, 0, :, :]
predict = norm_pred(pred)
predict = predict.squeeze()
predict_np = predict.cpu().detach().numpy()
img = Image.fromarray(predict_np * 255).convert("RGB")
del d1, d2, d3, d4, d5, d6, d7, pred, predict, predict_np, inputs_test, sample
return img
return img

View File

@ -46,16 +46,7 @@ def index():
height = request.args.get("height", type=int)
model = request.values.get("model", type=str, default="u2net")
model_path = os.environ.get(
"U2NETP_PATH",
os.path.expanduser(os.path.join("~", ".u2net")),
)
model_choices = [
os.path.splitext(os.path.basename(x))[0]
for x in set(glob.glob(model_path + "/*"))
]
model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"]))
model_choices = ["u2net", "u2netp", "u2net_human_seg"]
if model not in model_choices:
return {

View File

@ -1,541 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
class REBNCONV(nn.Module):
def __init__(self, in_ch=3, out_ch=3, dirate=1):
super().__init__()
self.conv_s1 = nn.Conv2d(
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
)
self.bn_s1 = nn.BatchNorm2d(out_ch)
self.relu_s1 = nn.ReLU(inplace=True)
def forward(self, x):
hx = x
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
return xout
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
return src
### RSU-7 ###
class RSU7(nn.Module): # UNet07DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super().__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx = self.pool5(hx5)
hx6 = self.rebnconv6(hx)
hx7 = self.rebnconv7(hx6)
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
hx6dup = _upsample_like(hx6d, hx5)
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-6 ###
class RSU6(nn.Module): # UNet06DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super().__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx6 = self.rebnconv6(hx5)
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-5 ###
class RSU5(nn.Module): # UNet05DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super().__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-4 ###
class RSU4(nn.Module): # UNet04DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super().__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
### RSU-4F ###
class RSU4F(nn.Module): # UNet04FRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super().__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx2 = self.rebnconv2(hx1)
hx3 = self.rebnconv3(hx2)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
return hx1d + hxin
##### U^2-Net ####
class U2NET(nn.Module):
def __init__(self, in_ch=3, out_ch=1):
super().__init__()
self.stage1 = RSU7(in_ch, 32, 64)
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage2 = RSU6(64, 32, 128)
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage3 = RSU5(128, 64, 256)
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage4 = RSU4(256, 128, 512)
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage5 = RSU4F(512, 256, 512)
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage6 = RSU4F(512, 256, 512)
# decoder
self.stage5d = RSU4F(1024, 256, 512)
self.stage4d = RSU4(1024, 128, 256)
self.stage3d = RSU5(512, 64, 128)
self.stage2d = RSU6(256, 32, 64)
self.stage1d = RSU7(128, 16, 64)
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
self.outconv = nn.Conv2d(6, out_ch, 1)
def forward(self, x):
hx = x
# stage 1
hx1 = self.stage1(hx)
hx = self.pool12(hx1)
# stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
# stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
# stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
# stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
# stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6, hx5)
# -------------------- decoder --------------------
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
# side output
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2, d1)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3, d1)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4, d1)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5, d1)
d6 = self.side6(hx6)
d6 = _upsample_like(d6, d1)
d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
return (
torch.sigmoid(d0),
torch.sigmoid(d1),
torch.sigmoid(d2),
torch.sigmoid(d3),
torch.sigmoid(d4),
torch.sigmoid(d5),
torch.sigmoid(d6),
)
### U^2-Net small ###
class U2NETP(nn.Module):
def __init__(self, in_ch=3, out_ch=1):
super().__init__()
self.stage1 = RSU7(in_ch, 16, 64)
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage2 = RSU6(64, 16, 64)
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage3 = RSU5(64, 16, 64)
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage4 = RSU4(64, 16, 64)
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage5 = RSU4F(64, 16, 64)
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage6 = RSU4F(64, 16, 64)
# decoder
self.stage5d = RSU4F(128, 16, 64)
self.stage4d = RSU4(128, 16, 64)
self.stage3d = RSU5(128, 16, 64)
self.stage2d = RSU6(128, 16, 64)
self.stage1d = RSU7(128, 16, 64)
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
self.outconv = nn.Conv2d(6, out_ch, 1)
def forward(self, x):
hx = x
# stage 1
hx1 = self.stage1(hx)
hx = self.pool12(hx1)
# stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
# stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
# stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
# stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
# stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6, hx5)
# decoder
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
# side output
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2, d1)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3, d1)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4, d1)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5, d1)
d6 = self.side6(hx6)
d6 = _upsample_like(d6, d1)
d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
return (
torch.sigmoid(d0),
torch.sigmoid(d1),
torch.sigmoid(d2),
torch.sigmoid(d3),
torch.sigmoid(d4),
torch.sigmoid(d5),
torch.sigmoid(d6),
)

1
requirements-cpu.txt Normal file
View File

@ -0,0 +1 @@
onnxruntime==1.10.0

1
requirements-gpu.txt Normal file
View File

@ -0,0 +1 @@
onnxruntime-gpu==1.10.0

View File

@ -1,13 +1,10 @@
filetype==1.0.7
flask==1.1.2
gdown==4.2.0
numpy==1.20.0
pillow==8.3.2
pymatting==1.1.5
scikit-image==0.19.1
torch==1.9.1
torchvision==0.10.1
waitress==1.4.4
tqdm==4.51.0
requests==2.24.0
scipy==1.5.4
pymatting==1.1.1
filetype==1.0.7
matplotlib==3.5.1
tqdm==4.51.0
waitress==1.4.4

View File

@ -14,6 +14,13 @@ long_description = (here / "README.md").read_text(encoding="utf-8")
with open("requirements.txt") as f:
requireds = f.read().splitlines()
if os.getenv("GPU") is None:
with open("requirements-cpu.txt") as f:
requireds += f.read().splitlines()
else:
with open("requirements-gpu.txt") as f:
requireds += f.read().splitlines()
setup(
name="rembg",
description="Remove image background",