Mirror of https://github.com/danielgatis/rembg (via https://git.mirrors.martin98.com, synced 2025-08-20 12:49:09 +08:00)

Commit: 3b18bad858 ("add onnx")
Parent: 1ab362eb38
Dockerfile (10 lines changed)

@@ -2,16 +2,16 @@ FROM nvidia/cuda:11.4.2-cudnn8-runtime-ubuntu20.04
RUN apt-get update &&\
    apt-get install -y --no-install-recommends \
    python3 \
    python3-pip \
    python3-dev \
    build-essential
    python3 \
    python3-pip \
    python3-dev \
    build-essential

WORKDIR /rembg

COPY . .

RUN pip3 install .
RUN GPU=1 pip3 install .

# First run to compile AOT & download model
RUN rembg pixel.png >/dev/null
README.md (39 lines changed)

@@ -36,38 +36,19 @@ Rembg is a tool to remove images background. That is it.

#### *** If you want to remove background from videos try this fork: https://github.com/ecsplendid/rembg-greenscreen ***

### Requirements

* python 3.8 or newer

* torch and torchvision stable version (https://pytorch.org)

#### How to install torch/torchvision

Go to https://pytorch.org and scroll down to `INSTALL PYTORCH` section and follow the instructions.

For example:
```
PyTorch Build: Stable (1.7.1)
Your OS: Windows
Package: Pip
Language: Python
CUDA: None
```

The install cmd is:
```
pip install torch==1.7.1+cpu torchvision==0.8.2+cpu -f https://download.pytorch.org/whl/torch_stable.html
```

### Installation

Install it from pypi

CPU support:
```bash
pip install rembg
```

GPU support:
```bash
GPU=1 pip install rembg
```

### Usage as a cli

Remove the background from a remote image

@@ -85,14 +66,6 @@ Remove the background from all images in a folder
rembg -p path/to/input path/to/output
```

### Add a custom model

Copy the `custom-model.pth` file to `~/.u2net` and run:

```bash
curl -s http://input.png | rembg -m custom-model > output.png
```

### Usage as a server

Start the server
rembg/bg.py

@@ -1,4 +1,3 @@
import functools
import io

import numpy as np

@@ -68,7 +67,6 @@ def naive_cutout(img, mask):
    return cutout


@functools.lru_cache(maxsize=None)
def get_model(model_name):
    if model_name == "u2netp":
        return load_model(model_name="u2netp")
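The `@functools.lru_cache` decorator above means each model is loaded only once per process and then reused. A hedged illustration of that behavior, assuming the hunk above is `rembg/bg.py` (the snippet itself is not part of the commit):

```python
# Hypothetical illustration, not part of the diff: repeated calls with the same
# model name return the same cached object instead of re-reading the .onnx file.
from rembg.bg import get_model

first = get_model("u2net")
second = get_model("u2net")
assert first is second  # lru_cache hands back the already-created session
```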
rembg/cli.py (13 lines changed)

@@ -10,17 +10,6 @@ from .bg import remove


def main():
    model_path = os.environ.get(
        "U2NETP_PATH",
        os.path.expanduser(os.path.join("~", ".u2net")),
    )
    model_choices = [
        os.path.splitext(os.path.basename(x))[0]
        for x in set(glob.glob(model_path + "/*"))
    ]

    model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"]))

    ap = argparse.ArgumentParser()

    ap.add_argument(

@@ -28,7 +17,7 @@ def main():
        "--model",
        default="u2net",
        type=str,
        choices=model_choices,
        choices=["u2net", "u2netp", "u2net_human_seg"],
        help="The model name.",
    )
rembg/data_loader.py (deleted)

@@ -1,326 +0,0 @@
# data loader

import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from skimage import color, io, transform
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils


# ==========================dataset load==========================
class RescaleT:
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label = sample["imidx"], sample["image"], sample["label"]

        h, w = image.shape[:2]

        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
        # img = transform.resize(image,(new_h,new_w),mode='constant')
        # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)

        img = transform.resize(
            image, (self.output_size, self.output_size), mode="constant"
        )
        lbl = transform.resize(
            label,
            (self.output_size, self.output_size),
            mode="constant",
            order=0,
            preserve_range=True,
        )

        return {"imidx": imidx, "image": img, "label": lbl}


class Rescale:
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label = sample["imidx"], sample["image"], sample["label"]

        if random.random() >= 0.5:
            image = image[::-1]
            label = label[::-1]

        h, w = image.shape[:2]

        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
        img = transform.resize(image, (new_h, new_w), mode="constant")
        lbl = transform.resize(
            label, (new_h, new_w), mode="constant", order=0, preserve_range=True
        )

        return {"imidx": imidx, "image": img, "label": lbl}


class RandomCrop:
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        imidx, image, label = sample["imidx"], sample["image"], sample["label"]

        if random.random() >= 0.5:
            image = image[::-1]
            label = label[::-1]

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[top : top + new_h, left : left + new_w]
        label = label[top : top + new_h, left : left + new_w]

        return {"imidx": imidx, "image": image, "label": label}


class ToTensor:
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):

        imidx, image, label = sample["imidx"], sample["image"], sample["label"]

        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
        tmpLbl = np.zeros(label.shape)

        image = image / np.max(image)
        if np.max(label) < 1e-6:
            label = label
        else:
            label = label / np.max(label)

        if image.shape[2] == 1:
            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
        else:
            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
            tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225

        tmpLbl[:, :, 0] = label[:, :, 0]

        # change the r,g,b to b,r,g from [0,255] to [0,1]
        # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {
            "imidx": torch.from_numpy(imidx),
            "image": torch.from_numpy(tmpImg),
            "label": torch.from_numpy(tmpLbl),
        }


class ToTensorLab:
    """Convert ndarrays in sample to Tensors."""

    def __init__(self, flag=0):
        self.flag = flag

    def __call__(self, sample):

        imidx, image, label = sample["imidx"], sample["image"], sample["label"]

        tmpLbl = np.zeros(label.shape)

        if np.max(label) < 1e-6:
            label = label
        else:
            label = label / np.max(label)

        # change the color space
        if self.flag == 2:  # with rgb and Lab colors
            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
            tmpImgt = np.zeros((image.shape[0], image.shape[1], 3))
            if image.shape[2] == 1:
                tmpImgt[:, :, 0] = image[:, :, 0]
                tmpImgt[:, :, 1] = image[:, :, 0]
                tmpImgt[:, :, 2] = image[:, :, 0]
            else:
                tmpImgt = image
            tmpImgtl = color.rgb2lab(tmpImgt)

            # nomalize image to range [0,1]
            tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / (
                np.max(tmpImgt[:, :, 0]) - np.min(tmpImgt[:, :, 0])
            )
            tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / (
                np.max(tmpImgt[:, :, 1]) - np.min(tmpImgt[:, :, 1])
            )
            tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / (
                np.max(tmpImgt[:, :, 2]) - np.min(tmpImgt[:, :, 2])
            )
            tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / (
                np.max(tmpImgtl[:, :, 0]) - np.min(tmpImgtl[:, :, 0])
            )
            tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / (
                np.max(tmpImgtl[:, :, 1]) - np.min(tmpImgtl[:, :, 1])
            )
            tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / (
                np.max(tmpImgtl[:, :, 2]) - np.min(tmpImgtl[:, :, 2])
            )

            # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))

            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(
                tmpImg[:, :, 0]
            )
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(
                tmpImg[:, :, 1]
            )
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(
                tmpImg[:, :, 2]
            )
            tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std(
                tmpImg[:, :, 3]
            )
            tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std(
                tmpImg[:, :, 4]
            )
            tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std(
                tmpImg[:, :, 5]
            )

        elif self.flag == 1:  # with Lab color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))

            if image.shape[2] == 1:
                tmpImg[:, :, 0] = image[:, :, 0]
                tmpImg[:, :, 1] = image[:, :, 0]
                tmpImg[:, :, 2] = image[:, :, 0]
            else:
                tmpImg = image

            tmpImg = color.rgb2lab(tmpImg)

            # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))

            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / (
                np.max(tmpImg[:, :, 0]) - np.min(tmpImg[:, :, 0])
            )
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / (
                np.max(tmpImg[:, :, 1]) - np.min(tmpImg[:, :, 1])
            )
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / (
                np.max(tmpImg[:, :, 2]) - np.min(tmpImg[:, :, 2])
            )

            tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(
                tmpImg[:, :, 0]
            )
            tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(
                tmpImg[:, :, 1]
            )
            tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(
                tmpImg[:, :, 2]
            )

        else:  # with rgb color
            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            image = image / np.max(image)
            if image.shape[2] == 1:
                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
            else:
                tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
                tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
                tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225

        tmpLbl[:, :, 0] = label[:, :, 0]

        # change the r,g,b to b,r,g from [0,255] to [0,1]
        # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        return {
            "imidx": torch.from_numpy(imidx),
            "image": torch.from_numpy(tmpImg),
            "label": torch.from_numpy(tmpLbl),
        }


class SalObjDataset(Dataset):
    def __init__(self, img_name_list, lbl_name_list, transform=None):
        # self.root_dir = root_dir
        # self.image_name_list = glob.glob(image_dir+'*.png')
        # self.label_name_list = glob.glob(label_dir+'*.png')
        self.image_name_list = img_name_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self, idx):

        # image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
        # label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])

        image = io.imread(self.image_name_list[idx])
        imname = self.image_name_list[idx]
        imidx = np.array([idx])

        if 0 == len(self.label_name_list):
            label_3 = np.zeros(image.shape)
        else:
            label_3 = io.imread(self.label_name_list[idx])

        label = np.zeros(label_3.shape[0:2])
        if 3 == len(label_3.shape):
            label = label_3[:, :, 0]
        elif 2 == len(label_3.shape):
            label = label_3

        if 3 == len(image.shape) and 2 == len(label.shape):
            label = label[:, :, np.newaxis]
        elif 2 == len(image.shape) and 2 == len(label.shape):
            image = image[:, :, np.newaxis]
            label = label[:, :, np.newaxis]

        sample = {"imidx": imidx, "image": image, "label": label}

        if self.transform:
            sample = self.transform(sample)

        return sample
rembg/detect.py (219 lines changed)

@@ -1,139 +1,103 @@
import errno
import os
import sys
import urllib.request
from hashlib import md5

import gdown
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import onnxruntime as ort
from PIL import Image
from skimage import transform
from torchvision import transforms
from tqdm import tqdm

from .data_loader import RescaleT, ToTensorLab
from .u2net import U2NET, U2NETP


def download_file_from_google_drive(id, fname, destination):
    head, tail = os.path.split(destination)
    os.makedirs(head, exist_ok=True)

    URL = "https://docs.google.com/uc?export=download"

    session = requests.Session()
    response = session.get(URL, params={"id": id}, stream=True)

    token = None
    for key, value in response.cookies.items():
        if key.startswith("download_warning"):
            token = value
            break

    if token:
        params = {"id": id, "confirm": token}
        response = session.get(URL, params=params, stream=True)

    total = int(response.headers.get("content-length", 0))

    with open(destination, "wb") as file, tqdm(
        desc=f"Downloading {tail} to {head}",
        total=total,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in response.iter_content(chunk_size=1024):
            size = file.write(data)
            bar.update(size)
SESSIONS = {}


def load_model(model_name: str = "u2net"):
    hashfile = lambda f: md5(open(f, "rb").read()).hexdigest()
    path = os.environ.get(
        "U2NETP_PATH",
        os.path.expanduser(os.path.join("~", ".u2net", model_name + ".onnx")),
    )

    if model_name == "u2netp":
        net = U2NETP(3, 1)
        path = os.environ.get(
            "U2NETP_PATH",
            os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
        )
        if (
            not os.path.exists(path)
            or hashfile(path) != "e4f636406ca4e2af789941e7f139ee2e"
        ):
            download_file_from_google_drive(
                "1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy",
                "u2netp.pth",
                path,
            )

        md5 = "8e83ca70e441ab06c318d82300c84806"
        url = "https://drive.google.com/uc?id=1tNuFmLv0TSNDjYIkjEdeH1IWKQdUA4HR"
    elif model_name == "u2net":
        net = U2NET(3, 1)
        path = os.environ.get(
            "U2NET_PATH",
            os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
        )
        if (
            not os.path.exists(path)
            or hashfile(path) != "347c3d51b01528e5c6c071e3cff1cb55"
        ):
            download_file_from_google_drive(
                "1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ",
                "pth",
                path,
            )

        md5 = "60024c5c889badc19c04ad937298a77b"
        url = "https://drive.google.com/uc?id=1tCU5MM1LhRgGou5OpmpjBQbSrYIUoYab"
    elif model_name == "u2net_human_seg":
        net = U2NET(3, 1)
        path = os.environ.get(
            "U2NET_PATH",
            os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")),
        )
        if (
            not os.path.exists(path)
            or hashfile(path) != "09fb4e49b7f785c9f855baf94916840a"
        ):
            download_file_from_google_drive(
                "1-Yg0cxgrNhHP-016FPdp902BR-kSsA4P",
                "u2net_human_seg.pth",
                path,
            )
        md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
        url = "https://drive.google.com/uc?id=1ZfqwVxu-1XWC1xU1GHIP-FM_Knd_AX5j"
    else:
        print("Choose between u2net, u2net_human_seg or u2netp", file=sys.stderr)

    try:
        if torch.cuda.is_available():
            net.load_state_dict(torch.load(path))
            net.to(torch.device("cuda"))
        else:
            net.load_state_dict(
                torch.load(
                    path,
                    map_location="cpu",
                )
            )
    except FileNotFoundError:
        raise FileNotFoundError(
            errno.ENOENT, os.strerror(errno.ENOENT), model_name + ".pth"
        )
    if SESSIONS.get(md5) is None:
        gdown.cached_download(url, path, md5=md5, quiet=False)
        SESSIONS[md5] = ort.InferenceSession(path)

    net.eval()

    return net
    return SESSIONS[md5]


def norm_pred(d):
    ma = torch.max(d)
    mi = torch.min(d)
    ma = np.max(d)
    mi = np.min(d)
    dn = (d - mi) / (ma - mi)

    return dn


def rescale(sample, output_size):
    imidx, image, label = sample["imidx"], sample["image"], sample["label"]

    h, w = image.shape[:2]

    if isinstance(output_size, int):
        if h > w:
            new_h, new_w = output_size * h / w, output_size
        else:
            new_h, new_w = output_size, output_size * w / h
    else:
        new_h, new_w = output_size

    new_h, new_w = int(new_h), int(new_w)

    img = transform.resize(image, (output_size, output_size), mode="constant")
    lbl = transform.resize(
        label,
        (output_size, output_size),
        mode="constant",
        order=0,
        preserve_range=True,
    )

    return {"imidx": imidx, "image": img, "label": lbl}


def color(sample):
    imidx, image, label = sample["imidx"], sample["image"], sample["label"]

    tmpLbl = np.zeros(label.shape)

    if np.max(label) < 1e-6:
        label = label
    else:
        label = label / np.max(label)

    tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
    image = image / np.max(image)
    if image.shape[2] == 1:
        tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
        tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
        tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
    else:
        tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
        tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
        tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225

    tmpLbl[:, :, 0] = label[:, :, 0]
    tmpImg = tmpImg.transpose((2, 0, 1))
    tmpLbl = label.transpose((2, 0, 1))

    return {"imidx": imidx, "image": tmpImg, "label": tmpLbl}


def preprocess(image):
    label_3 = np.zeros(image.shape)
    label = np.zeros(label_3.shape[0:2])

@@ -149,34 +113,23 @@ def preprocess(image):
        image = image[:, :, np.newaxis]
        label = label[:, :, np.newaxis]

    transform = transforms.Compose([RescaleT(320), ToTensorLab(flag=0)])
    sample = transform({"imidx": np.array([0]), "image": image, "label": label})
    sample = {"imidx": np.array([0]), "image": image, "label": label}
    sample = rescale(sample, 320)
    sample = color(sample)

    return sample


def predict(net, item):

def predict(ort_session, item):
    sample = preprocess(item)
    inputs_test = np.expand_dims(sample["image"], 0).astype(np.float32)

    with torch.no_grad():
    ort_inputs = {ort_session.get_inputs()[0].name: inputs_test}
    ort_outs = ort_session.run(None, ort_inputs)

        if torch.cuda.is_available():
            inputs_test = torch.cuda.FloatTensor(
                sample["image"].unsqueeze(0).cuda().float()
            )
        else:
            inputs_test = torch.FloatTensor(sample["image"].unsqueeze(0).float())
    d1 = ort_outs[0]
    pred = d1[:, 0, :, :]
    predict = np.squeeze(norm_pred(pred))
    img = Image.fromarray(predict * 255).convert("RGB")

        d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

        pred = d1[:, 0, :, :]
        predict = norm_pred(pred)

        predict = predict.squeeze()
        predict_np = predict.cpu().detach().numpy()
        img = Image.fromarray(predict_np * 255).convert("RGB")

        del d1, d2, d3, d4, d5, d6, d7, pred, predict, predict_np, inputs_test, sample

        return img
    return img
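For orientation, a minimal usage sketch of the new ONNX path defined above. This is hypothetical driver code, not part of the commit; it assumes `load_model` and `predict` behave exactly as shown in this hunk, and the file names are placeholders:

```python
# Hedged sketch: run the ONNX u2net model on a single image and save the mask.
import numpy as np
from PIL import Image

from rembg.detect import load_model, predict

session = load_model("u2net")   # gdown.cached_download + ort.InferenceSession, memoized in SESSIONS
img = np.array(Image.open("input.png").convert("RGB"))
mask = predict(session, img)    # preprocess -> session.run -> norm_pred -> PIL "RGB" mask
mask.save("mask.png")
```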
rembg/server.py

@@ -46,16 +46,7 @@ def index():
    height = request.args.get("height", type=int)

    model = request.values.get("model", type=str, default="u2net")
    model_path = os.environ.get(
        "U2NETP_PATH",
        os.path.expanduser(os.path.join("~", ".u2net")),
    )
    model_choices = [
        os.path.splitext(os.path.basename(x))[0]
        for x in set(glob.glob(model_path + "/*"))
    ]

    model_choices = list(set(model_choices + ["u2net", "u2netp", "u2net_human_seg"]))
    model_choices = ["u2net", "u2netp", "u2net_human_seg"]

    if model not in model_choices:
        return {
rembg/u2net.py (541 lines changed, deleted)

@@ -1,541 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class REBNCONV(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super().__init__()

        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):

        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout


## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):

    src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)

    return src


### RSU-7 ###
class RSU7(nn.Module):  # UNet07DRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):

        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-6 ###
class RSU6(nn.Module):  # UNet06DRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-5 ###
class RSU5(nn.Module):  # UNet05DRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-4 ###
class RSU4(nn.Module):  # UNet04DRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        return hx1d + hxin


### RSU-4F ###
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        return hx1d + hxin


##### U^2-Net ####
class U2NET(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        self.outconv = nn.Conv2d(6, out_ch, 1)

    def forward(self, x):

        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return (
            torch.sigmoid(d0),
            torch.sigmoid(d1),
            torch.sigmoid(d2),
            torch.sigmoid(d3),
            torch.sigmoid(d4),
            torch.sigmoid(d5),
            torch.sigmoid(d6),
        )


### U^2-Net small ###
class U2NETP(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)

        self.outconv = nn.Conv2d(6, out_ch, 1)

    def forward(self, x):

        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # decoder
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        return (
            torch.sigmoid(d0),
            torch.sigmoid(d1),
            torch.sigmoid(d2),
            torch.sigmoid(d3),
            torch.sigmoid(d4),
            torch.sigmoid(d5),
            torch.sigmoid(d6),
        )
requirements-cpu.txt (new file, 1 line)

@@ -0,0 +1 @@
onnxruntime==1.10.0

requirements-gpu.txt (new file, 1 line)

@@ -0,0 +1 @@
onnxruntime-gpu==1.10.0
requirements.txt

@@ -1,13 +1,10 @@
filetype==1.0.7
flask==1.1.2
gdown==4.2.0
numpy==1.20.0
pillow==8.3.2
pymatting==1.1.5
scikit-image==0.19.1
torch==1.9.1
torchvision==0.10.1
waitress==1.4.4
tqdm==4.51.0
requests==2.24.0
scipy==1.5.4
pymatting==1.1.1
filetype==1.0.7
matplotlib==3.5.1
tqdm==4.51.0
waitress==1.4.4
setup.py (7 lines changed)

@@ -14,6 +14,13 @@ long_description = (here / "README.md").read_text(encoding="utf-8")
with open("requirements.txt") as f:
    requireds = f.read().splitlines()

if os.getenv("GPU") is None:
    with open("requirements-cpu.txt") as f:
        requireds += f.read().splitlines()
else:
    with open("requirements-gpu.txt") as f:
        requireds += f.read().splitlines()

setup(
    name="rembg",
    description="Remove image background",
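A minimal sketch of what this hunk does at install time; the helper name `collect_requirements` is hypothetical, while the file names and the `GPU` variable come from the diff:

```python
# Hedged sketch of the dependency switch added to setup.py:
# `pip3 install .` resolves requirements-cpu.txt (onnxruntime==1.10.0),
# `GPU=1 pip3 install .` resolves requirements-gpu.txt (onnxruntime-gpu==1.10.0).
import os

def collect_requirements() -> list:
    with open("requirements.txt") as f:
        requireds = f.read().splitlines()
    extra = "requirements-cpu.txt" if os.getenv("GPU") is None else "requirements-gpu.txt"
    with open(extra) as f:
        requireds += f.read().splitlines()
    return requireds
```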