mirror of https://git.mirrors.martin98.com/https://github.com/danielgatis/rembg
synced 2025-08-20 21:39:14 +08:00

add onnx

This commit is contained in:
parent 1ab362eb38
commit 3b18bad858

Dockerfile (10 lines changed)
@@ -2,16 +2,16 @@ FROM nvidia/cuda:11.4.2-cudnn8-runtime-ubuntu20.04
 RUN apt-get update &&\
     apt-get install -y --no-install-recommends \
     python3 \
    python3-pip \
    python3-dev \
    build-essential

 WORKDIR /rembg

 COPY . .

-RUN pip3 install .
+RUN GPU=1 pip3 install .

 # First run to compile AOT & download model
 RUN rembg pixel.png >/dev/null
README.md (39 lines changed)
@@ -36,38 +36,19 @@ Rembg is a tool to remove images background. That is it.

 #### *** If you want to remove background from videos try this this fork: https://github.com/ecsplendid/rembg-greenscreen ***

-### Requirements
-
-* python 3.8 or newer
-
-* torch and torchvision stable version (https://pytorch.org)
-
-#### How to install torch/torchvision
-
-Go to https://pytorch.org and scrool down to `INSTALL PYTORCH` section and follow the instructions.
-
-For example:
-```
-PyTorch Build: Stable (1.7.1)
-Your OS: Windows
-Package: Pip
-Language: Python
-CUDA: None
-```
-
-The install cmd is:
-```
-pip install torch==1.7.1+cpu torchvision==0.8.2+cpu -f https://download.pytorch.org/whl/torch_stable.html
-```
-
 ### Installation

-Install it from pypi
+CPU support:

 ```bash
 pip install rembg
 ```

+GPU support:
+
+```bash
+GPU=1 pip install rembg
+```
+
 ### Usage as a cli

 Remove the background from a remote image

@@ -85,14 +66,6 @@ Remove the background from all images in a folder
 rembg -p path/to/input path/to/output
 ```

-### Add a custom model
-
-Copy the `custom-model.pth` file to `~/.u2net` and run:
-
-```bash
-curl -s http://input.png | rembg -m custom-model > output.png
-```
-
 ### Usage as a server

 Start the server
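With the new install options above, the only difference between the two flavors is which ONNX runtime gets pulled in (see requirements-cpu.txt / requirements-gpu.txt below). A quick way to confirm which one ended up in the environment; this snippet is illustrative only and not part of the README:

```python
# Sanity check after `pip install rembg` vs `GPU=1 pip install rembg`.
# onnxruntime-gpu exposes the CUDA execution provider; the CPU package does not.
import onnxruntime as ort

print(ort.get_device())               # "CPU", or "GPU" when onnxruntime-gpu finds CUDA
print(ort.get_available_providers())  # e.g. ["CUDAExecutionProvider", "CPUExecutionProvider"]
```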
@@ -1,4 +1,3 @@
-import functools
 import io

 import numpy as np

@@ -68,7 +67,6 @@ def naive_cutout(img, mask):
     return cutout


-@functools.lru_cache(maxsize=None)
 def get_model(model_name):
     if model_name == "u2netp":
         return load_model(model_name="u2netp")
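Dropping @functools.lru_cache here is safe because caching moves into rembg/detect.py: load_model now keeps one onnxruntime session per model checksum in a module-level dict (see the rembg/detect.py diff below). A minimal standalone sketch of that pattern, with hypothetical names:

```python
import onnxruntime as ort

SESSIONS = {}  # model checksum -> loaded inference session


def get_session(md5sum: str, onnx_path: str) -> ort.InferenceSession:
    # Build the session only once; later calls reuse the cached object.
    if SESSIONS.get(md5sum) is None:
        SESSIONS[md5sum] = ort.InferenceSession(onnx_path)
    return SESSIONS[md5sum]
```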
rembg/cli.py (13 lines changed)
@@ -10,17 +10,6 @@ from .bg import remove


 def main():
-    model_path = os.environ.get(
-        "U2NETP_PATH",
-        os.path.expanduser(os.path.join("~", ".u2net")),
-    )
-    model_choices = [
-        os.path.splitext(os.path.basename(x))[0]
-        for x in set(glob.glob(model_path + "/*"))
-    ]
-
-    model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"]))
-
     ap = argparse.ArgumentParser()

     ap.add_argument(

@@ -28,7 +17,7 @@ def main():
         "--model",
         default="u2net",
         type=str,
-        choices=model_choices,
+        choices=["u2net", "u2netp", "u2net_human_seg"],
         help="The model name.",
     )

@@ -1,326 +0,0 @@
Deleted in full: the package's data_loader module (the `.data_loader` import removed from rembg/detect.py below), which held the PyTorch/skimage training-style preprocessing pipeline:
- RescaleT / Rescale: resize image and label arrays to a target size with skimage.transform.resize
- RandomCrop: random crop plus random horizontal flip of image/label pairs
- ToTensor / ToTensorLab: per-channel normalization (RGB and optional Lab color space) and conversion of the sample dict to torch tensors
- SalObjDataset: a torch.utils.data.Dataset reading image/label file lists and applying the transforms above
rembg/detect.py (219 lines changed)
@@ -1,139 +1,103 @@
-import errno
 import os
 import sys
-import urllib.request
-from hashlib import md5

+import gdown
 import numpy as np
-import requests
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision
+import onnxruntime as ort
 from PIL import Image
 from skimage import transform
-from torchvision import transforms
-from tqdm import tqdm

-from .data_loader import RescaleT, ToTensorLab
-from .u2net import U2NET, U2NETP
-
-
-def download_file_from_google_drive(id, fname, destination):
-    head, tail = os.path.split(destination)
-    os.makedirs(head, exist_ok=True)
-
-    URL = "https://docs.google.com/uc?export=download"
-
-    session = requests.Session()
-    response = session.get(URL, params={"id": id}, stream=True)
-
-    token = None
-    for key, value in response.cookies.items():
-        if key.startswith("download_warning"):
-            token = value
-            break
-
-    if token:
-        params = {"id": id, "confirm": token}
-        response = session.get(URL, params=params, stream=True)
-
-    total = int(response.headers.get("content-length", 0))
-
-    with open(destination, "wb") as file, tqdm(
-        desc=f"Downloading {tail} to {head}",
-        total=total,
-        unit="iB",
-        unit_scale=True,
-        unit_divisor=1024,
-    ) as bar:
-        for data in response.iter_content(chunk_size=1024):
-            size = file.write(data)
-            bar.update(size)
+SESSIONS = {}


 def load_model(model_name: str = "u2net"):
-    hashfile = lambda f: md5(open(f, "rb").read()).hexdigest()
+    path = os.environ.get(
+        "U2NETP_PATH",
+        os.path.expanduser(os.path.join("~", ".u2net", model_name + ".onnx")),
+    )

     if model_name == "u2netp":
-        net = U2NETP(3, 1)
-        path = os.environ.get(
-            "U2NETP_PATH",
-            os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
-        )
-        if (
-            not os.path.exists(path)
-            or hashfile(path) != "e4f636406ca4e2af789941e7f139ee2e"
-        ):
-            download_file_from_google_drive(
-                "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
-                "u2netp.pth",
-                path,
-            )
-
+        md5 = "8e83ca70e441ab06c318d82300c84806"
+        url = "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR"
     elif model_name == "u2net":
-        net = U2NET(3, 1)
-        path = os.environ.get(
-            "U2NET_PATH",
-            os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
-        )
-        if (
-            not os.path.exists(path)
-            or hashfile(path) != "347c3d51b01528e5c6c071e3cff1cb55"
-        ):
-            download_file_from_google_drive(
-                "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
-                "pth",
-                path,
-            )
-
+        md5 = "60024c5c889badc19c04ad937298a77b"
+        url = "https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab"
     elif model_name == "u2net_human_seg":
-        net = U2NET(3, 1)
-        path = os.environ.get(
-            "U2NET_PATH",
-            os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
-        )
-        if (
-            not os.path.exists(path)
-            or hashfile(path) != "09fb4e49b7f785c9f855baf94916840a"
-        ):
-            download_file_from_google_drive(
-                "1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
-                "u2net_human_seg.pth",
-                path,
-            )
+        md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
+        url = "https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j"
     else:
         print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr)

-    try:
-        if torch.cuda.is_available():
-            net.load_state_dict(torch.load(path))
-            net.to(torch.device("cuda"))
-        else:
-            net.load_state_dict(
-                torch.load(
-                    path,
-                    map_location="cpu",
-                )
-            )
-    except FileNotFoundError:
-        raise FileNotFoundError(
-            errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"
-        )
-
-    net.eval()
-
-    return net
+    if SESSIONS.get(md5) is None:
+        gdown.cached_download(url, path, md5=md5, quiet=False)
+        SESSIONS[md5] = ort.InferenceSession(path)
+
+    return SESSIONS[md5]


 def norm_pred(d):
-    ma = torch.max(d)
-    mi = torch.min(d)
+    ma = np.max(d)
+    mi = np.min(d)
     dn = (d - mi) / (ma - mi)

     return dn


+def rescale(sample, output_size):
+    imidx, image, label = sample["imidx"], sample["image"], sample["label"]
+
+    h, w = image.shape[:2]
+
+    if isinstance(output_size, int):
+        if h > w:
+            new_h, new_w = output_size * h / w, output_size
+        else:
+            new_h, new_w = output_size, output_size * w / h
+    else:
+        new_h, new_w = output_size
+
+    new_h, new_w = int(new_h), int(new_w)
+
+    img = transform.resize(image, (output_size, output_size), mode="constant")
+    lbl = transform.resize(
+        label,
+        (output_size, output_size),
+        mode="constant",
+        order=0,
+        preserve_range=True,
+    )
+
+    return {"imidx": imidx, "image": img, "label": lbl}
+
+
+def color(sample):
+    imidx, image, label = sample["imidx"], sample["image"], sample["label"]
+
+    tmpLbl = np.zeros(label.shape)
+
+    if np.max(label) < 1e-6:
+        label = label
+    else:
+        label = label / np.max(label)
+
+    tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
+    image = image / np.max(image)
+    if image.shape[2] == 1:
+        tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
+        tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
+        tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
+    else:
+        tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
+        tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
+        tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
+
+    tmpLbl[:, :, 0] = label[:, :, 0]
+    tmpImg = tmpImg.transpose((2, 0, 1))
+    tmpLbl = label.transpose((2, 0, 1))
+
+    return {"imidx": imidx, "image": tmpImg, "label": tmpLbl}
+
+
 def preprocess(image):
     label_3 = np.zeros(image.shape)
     label = np.zeros(label_3.shape[0:2])
@@ -149,34 +113,23 @@ def preprocess(image):
     image = image[:, :, np.newaxis]
     label = label[:, :, np.newaxis]

-    transform = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
-    sample = transform({"imidx": np.array([0]), "image": image, "label": label})
+    sample = {"imidx": np.array([0]), "image": image, "label": label}
+    sample = rescale(sample, 320)
+    sample = color(sample)

     return sample


-def predict(net, item):
+def predict(ort_session, item):

     sample = preprocess(item)
+    inputs_test = np.expand_dims(sample["image"], 0).astype(np.float32)

-    with torch.no_grad():
+    ort_inputs = {ort_session.get_inputs()[0].name: inputs_test}
+    ort_outs = ort_session.run(None, ort_inputs)

-        if torch.cuda.is_available():
-            inputs_test = torch.cuda.FloatTensor(
-                sample["image"].unsqueeze(0).cuda().float()
-            )
-        else:
-            inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())
-
-        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)
-
-        pred = d1[:, 0, :, :]
-        predict = norm_pred(pred)
-
-        predict = predict.squeeze()
-        predict_np = predict.cpu().detach().numpy()
-        img = Image.fromarray(predict_np * 255).convert("RGB")
-
-        del d1, d2, d3, d4, d5, d6, d7, pred, predict, predict_np, inputs_test, sample
-
-        return img
+    d1 = ort_outs[0]
+    pred = d1[:, 0, :, :]
+    predict = np.squeeze(norm_pred(pred))
+    img = Image.fromarray(predict * 255).convert("RGB")
+
+    return img
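Putting the new pieces together: the model is fetched and cached as an .onnx graph with gdown, opened once as an onnxruntime session, fed a normalized 1x3x320x320 float32 tensor, and the first output is rescaled into a mask image. A hedged usage sketch of the functions in this diff (file names are placeholders, not part of the repository):

```python
import numpy as np
from PIL import Image

from rembg.detect import load_model, predict

session = load_model("u2net")   # downloads ~/.u2net/u2net.onnx on first use, then reuses the cached ORT session
image = np.array(Image.open("input.png").convert("RGB"))
mask = predict(session, image)  # PIL image of the predicted saliency map (320x320), as in the old torch path
mask.save("mask.png")
```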
@@ -46,16 +46,7 @@ def index():
     height = request.args.get("height", type=int)

     model = request.values.get("model", type=str, default="u2net")
-    model_path = os.environ.get(
-        "U2NETP_PATH",
-        os.path.expanduser(os.path.join("~", ".u2net")),
-    )
-    model_choices = [
-        os.path.splitext(os.path.basename(x))[0]
-        for x in set(glob.glob(model_path + "/*"))
-    ]
-
-    model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"]))
+    model_choices = ["u2net", "u2netp", "u2net_human_seg"]

     if model not in model_choices:
         return {
rembg/u2net.py (541 lines deleted)
@@ -1,541 +0,0 @@
Deleted in full: rembg/u2net.py, the PyTorch U^2-Net architecture, no longer needed now that inference runs on exported ONNX graphs:
- REBNCONV: Conv2d + BatchNorm2d + ReLU block with configurable dilation
- _upsample_like: bilinear upsampling of a tensor to another tensor's spatial size
- RSU7, RSU6, RSU5, RSU4, RSU4F: the residual U-blocks at decreasing depths, each an encoder-decoder of REBNCONV layers with MaxPool2d downsampling and skip concatenations
- U2NET (full, 32/64-channel stages) and U2NETP (small, 16/64-channel stages): six-stage encoder-decoder stacks of RSU blocks with six side-output convolutions fused by a 1x1 convolution, returning sigmoid-activated maps d0..d6
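The .onnx graphs consumed by rembg/detect.py would have been exported from checkpoints of this now-deleted PyTorch model. The commit itself ships no export script, so the following is only an illustrative sketch of the standard torch.onnx.export call, with the input size taken from the 320x320 preprocessing above and assumed output names:

```python
import torch  # needed only offline for the export, not at inference time


def export_u2net(net: torch.nn.Module, onnx_path: str) -> None:
    # Trace the network once on a dummy 320x320 RGB batch and write the ONNX graph.
    net.eval()
    dummy = torch.randn(1, 3, 320, 320)
    torch.onnx.export(
        net,
        dummy,
        onnx_path,
        input_names=["input"],
        output_names=["d0", "d1", "d2", "d3", "d4", "d5", "d6"],  # assumed names for the seven sigmoid outputs
        opset_version=11,
    )
```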
requirements-cpu.txt (new file, 1 line)
@@ -0,0 +1 @@
+onnxruntime==1.10.0
requirements-gpu.txt (new file, 1 line)
@@ -0,0 +1 @@
+onnxruntime-gpu==1.10.0
requirements.txt

@@ -1,13 +1,10 @@
+filetype==1.0.7
 flask==1.1.2
+gdown==4.2.0
 numpy==1.20.0
 pillow==8.3.2
+pymatting==1.1.5
 scikit-image==0.19.1
-torch==1.9.1
-torchvision==0.10.1
-waitress==1.4.4
-tqdm==4.51.0
-requests==2.24.0
 scipy==1.5.4
-pymatting==1.1.1
-filetype==1.0.7
-matplotlib==3.5.1
+tqdm==4.51.0
+waitress==1.4.4
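The new gdown==4.2.0 pin replaces the hand-rolled download_file_from_google_drive helper removed from rembg/detect.py, while the torch/torchvision/requests pins give way to the onnxruntime packages above. A hedged standalone sketch of the single gdown call the new code relies on, reusing the u2net URL and checksum from the detect.py diff:

```python
import os

import gdown

# Downloads only if the target file is missing or its MD5 does not match,
# mirroring the gdown.cached_download call in rembg/detect.py above.
gdown.cached_download(
    url="https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab",
    path=os.path.expanduser(os.path.join("~", ".u2net", "u2net.onnx")),
    md5="60024c5c889badc19c04ad937298a77b",
    quiet=False,
)
```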
setup.py (7 lines changed)
@@ -14,6 +14,13 @@ long_description = (here / "README.md").read_text(encoding="utf-8")
 with open("requirements.txt") as f:
     requireds = f.read().splitlines()

+if os.getenv("GPU") is None:
+    with open("requirements-cpu.txt") as f:
+        requireds += f.read().splitlines()
+else:
+    with open("requirements-gpu.txt") as f:
+        requireds += f.read().splitlines()
+
 setup(
     name="rembg",
     description="Remove image background",