This commit is contained in:
Nathan Qin 2020-05-09 15:32:41 -06:00
parent 4f62ad454c
commit e660dc6c5a
67 changed files with 1111 additions and 6 deletions

View File

@ -1,29 +1,52 @@
## U^2-Net ## U^2-Net
The code for our newly accepted paper in Pattern Recognition 2020: "U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection, [Xuebin Qin](https://webdocs.cs.ualberta.ca/~xuebin/), [Zichen Zhang](https://webdocs.cs.ualberta.ca/~zichen2/), [Chenyang Huang](https://chenyangh.com/), [Masood Dehghan](https://sites.google.com/view/masooddehghan), [Osmar R. Zaiane](http://webdocs.cs.ualberta.ca/~zaiane/) and [Martin Jagersand](https://webdocs.cs.ualberta.ca/~jag/)." The code for our newly accepted paper in Pattern Recognition 2020: "U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection, [Xuebin Qin](https://webdocs.cs.ualberta.ca/~xuebin/), [Zichen Zhang](https://webdocs.cs.ualberta.ca/~zichen2/), [Chenyang Huang](https://chenyangh.com/), [Masood Dehghan](https://sites.google.com/view/masooddehghan), [Osmar R. Zaiane](http://webdocs.cs.ualberta.ca/~zaiane/) and [Martin Jagersand](https://webdocs.cs.ualberta.ca/~jag/)."
## (2020-May-04) The code will be released soon (before 2020-May-10)! ## U^2-Net Results
![U^2-Net Results](figures/u2netqual.png)
__Contact__: xuebin[at]ualberta[dot]ca __Contact__: xuebin[at]ualberta[dot]ca
## Our previous work: [BASNet (CVPR 2019)](https://github.com/NathanUA/BASNet) ## Our previous work: [BASNet (CVPR 2019)](https://github.com/NathanUA/BASNet)
## Required libraries
Python 3.6
numpy 1.15.2
scikit-image 0.14.0
PIL 5.2.0
PyTorch 0.4.0
torchvision 0.2.1
glob
## Usage
1. Clone this repo
```
git clone https://github.com/NathanUA/U-2-Net.git
```
2. Download the pre-trained model [u2net.pth 173.6 MB](https://drive.google.com/file/d/1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ/view?usp=sharing) or [u2netp.pth 4.7 MB](https://drive.google.com/file/d/1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy/view?usp=sharing) and put it into the dirctory './saved_models/u2net/' and './saved_models/u2netp/'
3. Cd to the directory 'U-2-Net', run the train or inference process by command: ```python u2net_train.py```
or ```python u2net_test.py``` respectively. The 'model_name' in both files can be changed to 'u2net' or 'u2netp' for using different models.
We also provide the predicted saliency maps ([u2net results](https://drive.google.com/file/d/1mZFWlS4WygWh1eVI8vK2Ad9LrPq4Hp5v/view?usp=sharing),[u2netp results](https://drive.google.com/file/d/1j2pU7vyhOO30C2S_FJuRdmAmMt3-xmjD/view?usp=sharing)) for datasets SOD, ECSSD, DUT-OMRON, PASCAL-S, HKU-IS and DUTS-TE.
## U^2-Net Architecture ## U^2-Net Architecture
![U^2-Net architecture](U2NETPR.png) ![U^2-Net architecture](figures/U2NETPR.png)
## Quantitative Comparison ## Quantitative Comparison
![Quantitative Comparison](quan_1.png) ![Quantitative Comparison](figures/quan_1.png)
![Quantitative Comparison](quan_2.png) ![Quantitative Comparison](figures/quan_2.png)
## Qualitative Comparison ## Qualitative Comparison
![Qualitative Comparison](qual.png?raw=true) ![Qualitative Comparison](figures/qual.png?raw=true)
## Citation ## Citation

259
data_loader.py Normal file
View File

@ -0,0 +1,259 @@
# data loader
from __future__ import print_function, division
import glob
import torch
from skimage import io, transform, color
import numpy as np
import random
import math
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
#==========================dataset load==========================
class RescaleT(object):
def __init__(self,output_size):
assert isinstance(output_size,(int,tuple))
self.output_size = output_size
def __call__(self,sample):
imidx, image, label = sample['imidx'], sample['image'],sample['label']
h, w = image.shape[:2]
if isinstance(self.output_size,int):
if h > w:
new_h, new_w = self.output_size*h/w,self.output_size
else:
new_h, new_w = self.output_size,self.output_size*w/h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
# #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
# img = transform.resize(image,(new_h,new_w),mode='constant')
# lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
img = transform.resize(image,(self.output_size,self.output_size),mode='constant')
lbl = transform.resize(label,(self.output_size,self.output_size),mode='constant', order=0, preserve_range=True)
return {'imidx':imidx, 'image':img,'label':lbl}
class Rescale(object):
def __init__(self,output_size):
assert isinstance(output_size,(int,tuple))
self.output_size = output_size
def __call__(self,sample):
imidx, image, label = sample['imidx'], sample['image'],sample['label']
h, w = image.shape[:2]
if isinstance(self.output_size,int):
if h > w:
new_h, new_w = self.output_size*h/w,self.output_size
else:
new_h, new_w = self.output_size,self.output_size*w/h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
# #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
img = transform.resize(image,(new_h,new_w),mode='constant')
lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
return {'imidx':imidx, 'image':img,'label':lbl}
class RandomCrop(object):
def __init__(self,output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size
def __call__(self,sample):
imidx, image, label = sample['imidx'], sample['image'], sample['label']
h, w = image.shape[:2]
new_h, new_w = self.output_size
top = np.random.randint(0, h - new_h)
left = np.random.randint(0, w - new_w)
image = image[top: top + new_h, left: left + new_w]
label = label[top: top + new_h, left: left + new_w]
return {'imidx':imidx,'image':image, 'label':label}
class ToTensor(object):
"""Convert ndarrays in sample to Tensors."""
def __call__(self, sample):
imidx, image, label = sample['imidx'], sample['image'], sample['label']
tmpImg = np.zeros((image.shape[0],image.shape[1],3))
tmpLbl = np.zeros(label.shape)
image = image/np.max(image)
if(np.max(label)<1e-6):
label = label
else:
label = label/np.max(label)
if image.shape[2]==1:
tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
else:
tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
tmpLbl[:,:,0] = label[:,:,0]
# change the r,g,b to b,r,g from [0,255] to [0,1]
#transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
tmpImg = tmpImg.transpose((2, 0, 1))
tmpLbl = label.transpose((2, 0, 1))
return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
class ToTensorLab(object):
"""Convert ndarrays in sample to Tensors."""
def __init__(self,flag=0):
self.flag = flag
def __call__(self, sample):
imidx, image, label =sample['imidx'], sample['image'], sample['label']
tmpLbl = np.zeros(label.shape)
if(np.max(label)<1e-6):
label = label
else:
label = label/np.max(label)
# change the color space
if self.flag == 2: # with rgb and Lab colors
tmpImg = np.zeros((image.shape[0],image.shape[1],6))
tmpImgt = np.zeros((image.shape[0],image.shape[1],3))
if image.shape[2]==1:
tmpImgt[:,:,0] = image[:,:,0]
tmpImgt[:,:,1] = image[:,:,0]
tmpImgt[:,:,2] = image[:,:,0]
else:
tmpImgt = image
tmpImgtl = color.rgb2lab(tmpImgt)
# nomalize image to range [0,1]
tmpImg[:,:,0] = (tmpImgt[:,:,0]-np.min(tmpImgt[:,:,0]))/(np.max(tmpImgt[:,:,0])-np.min(tmpImgt[:,:,0]))
tmpImg[:,:,1] = (tmpImgt[:,:,1]-np.min(tmpImgt[:,:,1]))/(np.max(tmpImgt[:,:,1])-np.min(tmpImgt[:,:,1]))
tmpImg[:,:,2] = (tmpImgt[:,:,2]-np.min(tmpImgt[:,:,2]))/(np.max(tmpImgt[:,:,2])-np.min(tmpImgt[:,:,2]))
tmpImg[:,:,3] = (tmpImgtl[:,:,0]-np.min(tmpImgtl[:,:,0]))/(np.max(tmpImgtl[:,:,0])-np.min(tmpImgtl[:,:,0]))
tmpImg[:,:,4] = (tmpImgtl[:,:,1]-np.min(tmpImgtl[:,:,1]))/(np.max(tmpImgtl[:,:,1])-np.min(tmpImgtl[:,:,1]))
tmpImg[:,:,5] = (tmpImgtl[:,:,2]-np.min(tmpImgtl[:,:,2]))/(np.max(tmpImgtl[:,:,2])-np.min(tmpImgtl[:,:,2]))
# tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
tmpImg[:,:,3] = (tmpImg[:,:,3]-np.mean(tmpImg[:,:,3]))/np.std(tmpImg[:,:,3])
tmpImg[:,:,4] = (tmpImg[:,:,4]-np.mean(tmpImg[:,:,4]))/np.std(tmpImg[:,:,4])
tmpImg[:,:,5] = (tmpImg[:,:,5]-np.mean(tmpImg[:,:,5]))/np.std(tmpImg[:,:,5])
elif self.flag == 1: #with Lab color
tmpImg = np.zeros((image.shape[0],image.shape[1],3))
if image.shape[2]==1:
tmpImg[:,:,0] = image[:,:,0]
tmpImg[:,:,1] = image[:,:,0]
tmpImg[:,:,2] = image[:,:,0]
else:
tmpImg = image
tmpImg = color.rgb2lab(tmpImg)
# tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
tmpImg[:,:,0] = (tmpImg[:,:,0]-np.min(tmpImg[:,:,0]))/(np.max(tmpImg[:,:,0])-np.min(tmpImg[:,:,0]))
tmpImg[:,:,1] = (tmpImg[:,:,1]-np.min(tmpImg[:,:,1]))/(np.max(tmpImg[:,:,1])-np.min(tmpImg[:,:,1]))
tmpImg[:,:,2] = (tmpImg[:,:,2]-np.min(tmpImg[:,:,2]))/(np.max(tmpImg[:,:,2])-np.min(tmpImg[:,:,2]))
tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
else: # with rgb color
tmpImg = np.zeros((image.shape[0],image.shape[1],3))
image = image/np.max(image)
if image.shape[2]==1:
tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
else:
tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
tmpLbl[:,:,0] = label[:,:,0]
# change the r,g,b to b,r,g from [0,255] to [0,1]
#transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
tmpImg = tmpImg.transpose((2, 0, 1))
tmpLbl = label.transpose((2, 0, 1))
return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
class SalObjDataset(Dataset):
def __init__(self,img_name_list,lbl_name_list,transform=None):
# self.root_dir = root_dir
# self.image_name_list = glob.glob(image_dir+'*.png')
# self.label_name_list = glob.glob(label_dir+'*.png')
self.image_name_list = img_name_list
self.label_name_list = lbl_name_list
self.transform = transform
def __len__(self):
return len(self.image_name_list)
def __getitem__(self,idx):
# image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
# label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])
image = io.imread(self.image_name_list[idx])
imname = self.image_name_list[idx]
imidx = np.array([idx])
if(0==len(self.label_name_list)):
label_3 = np.zeros(image.shape)
else:
label_3 = io.imread(self.label_name_list[idx])
label = np.zeros(label_3.shape[0:2])
if(3==len(label_3.shape)):
label = label_3[:,:,0]
elif(2==len(label_3.shape)):
label = label_3
if(3==len(image.shape) and 2==len(label.shape)):
label = label[:,:,np.newaxis]
elif(2==len(image.shape) and 2==len(label.shape)):
image = image[:,:,np.newaxis]
label = label[:,:,np.newaxis]
sample = {'imidx':imidx, 'image':image, 'label':label}
if self.transform:
sample = self.transform(sample)
return sample

View File

Before

Width:  |  Height:  |  Size: 826 KiB

After

Width:  |  Height:  |  Size: 826 KiB

View File

Before

Width:  |  Height:  |  Size: 474 KiB

After

Width:  |  Height:  |  Size: 474 KiB

View File

Before

Width:  |  Height:  |  Size: 316 KiB

After

Width:  |  Height:  |  Size: 316 KiB

View File

Before

Width:  |  Height:  |  Size: 320 KiB

After

Width:  |  Height:  |  Size: 320 KiB

BIN
figures/u2netqual.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 697 KiB

2
model/__init__.py Normal file
View File

@ -0,0 +1,2 @@
from .u2net import U2NET
from .u2net import U2NETP

Binary file not shown.

Binary file not shown.

541
model/u2net.py Normal file
View File

@ -0,0 +1,541 @@
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
class REBNCONV(nn.Module):
def __init__(self,in_ch=3,out_ch=3,dirate=1):
super(REBNCONV,self).__init__()
self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
self.bn_s1 = nn.BatchNorm2d(out_ch)
self.relu_s1 = nn.ReLU(inplace=True)
def forward(self,x):
hx = x
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
return xout
### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU7,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx = self.pool5(hx5)
hx6 = self.rebnconv6(hx)
hx7 = self.rebnconv7(hx6)
hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
hx6up = self.upscore2(hx6d)
# print(hx6up.shape,hx5.shape)
hx5d = self.rebnconv5d(torch.cat((hx6up,hx5),1))
hx5dup = self.upscore2(hx5d)
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
hx4dup = self.upscore2(hx4d)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = self.upscore2(hx3d)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-6 ###
class RSU6(nn.Module):#UNet06DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU6,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx6 = self.rebnconv6(hx5)
hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
hx5dup = self.upscore2(hx5d)
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
hx4dup = self.upscore2(hx4d)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = self.upscore2(hx3d)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-5 ###
class RSU5(nn.Module):#UNet05DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU5,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
hx4dup = self.upscore2(hx4d)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = self.upscore2(hx3d)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-4 ###
class RSU4(nn.Module):#UNet04DRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
hx3dup = self.upscore2(hx3d)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4F,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx2 = self.rebnconv2(hx1)
hx3 = self.rebnconv3(hx2)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
return hx1d + hxin
##### U^2-Net ####
class U2NET(nn.Module):
def __init__(self,in_ch=3,out_ch=1):
super(U2NET,self).__init__()
self.stage1 = RSU7(in_ch,32,64)
self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage2 = RSU6(64,32,128)
self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage3 = RSU5(128,64,256)
self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage4 = RSU4(256,128,512)
self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage5 = RSU4F(512,256,512)
self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage6 = RSU4F(512,256,512)
# decoder
self.stage5d = RSU4F(1024,256,512)
self.stage4d = RSU4(1024,128,256)
self.stage3d = RSU5(512,64,128)
self.stage2d = RSU6(256,32,64)
self.stage1d = RSU7(128,16,64)
self.side1 = nn.Conv2d(64,1,3,padding=1)
self.side2 = nn.Conv2d(64,1,3,padding=1)
self.side3 = nn.Conv2d(128,1,3,padding=1)
self.side4 = nn.Conv2d(256,1,3,padding=1)
self.side5 = nn.Conv2d(512,1,3,padding=1)
self.side6 = nn.Conv2d(512,1,3,padding=1)
self.upscore6 = nn.Upsample(scale_factor=32,mode='bilinear')
self.upscore5 = nn.Upsample(scale_factor=16,mode='bilinear')
self.upscore4 = nn.Upsample(scale_factor=8,mode='bilinear')
self.upscore3 = nn.Upsample(scale_factor=4,mode='bilinear')
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
self.outconv = nn.Conv2d(6,1,1)
def forward(self,x):
hx = x
#stage 1
hx1 = self.stage1(hx)
hx = self.pool12(hx1)
#stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
#stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
#stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
#stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
#stage 6
hx6 = self.stage6(hx)
hx6up = self.upscore2(hx6)
#-------------------- decoder --------------------
hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
hx5dup = self.upscore2(hx5d)
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
hx4dup = self.upscore2(hx4d)
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
hx3dup = self.upscore2(hx3d)
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d)
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
#side output
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = self.upscore2(d2)
d3 = self.side3(hx3d)
d3 = self.upscore3(d3)
d4 = self.side4(hx4d)
d4 = self.upscore4(d4)
d5 = self.side5(hx5d)
d5 = self.upscore5(d5)
d6 = self.side6(hx6)
d6 = self.upscore6(d6)
d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
### U^2-Net small ###
class U2NETP(nn.Module):
def __init__(self,in_ch=3,out_ch=1):
super(U2NETP,self).__init__()
self.stage1 = RSU7(in_ch,16,64)
self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage2 = RSU6(64,16,64)
self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage3 = RSU5(64,16,64)
self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage4 = RSU4(64,16,64)
self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage5 = RSU4F(64,16,64)
self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
self.stage6 = RSU4F(64,16,64)
# decoder
self.stage5d = RSU4F(128,16,64)
self.stage4d = RSU4(128,16,64)
self.stage3d = RSU5(128,16,64)
self.stage2d = RSU6(128,16,64)
self.stage1d = RSU7(128,16,64)
self.side1 = nn.Conv2d(64,1,3,padding=1)
self.side2 = nn.Conv2d(64,1,3,padding=1)
self.side3 = nn.Conv2d(64,1,3,padding=1)
self.side4 = nn.Conv2d(64,1,3,padding=1)
self.side5 = nn.Conv2d(64,1,3,padding=1)
self.side6 = nn.Conv2d(64,1,3,padding=1)
self.upscore6 = nn.Upsample(scale_factor=32,mode='bilinear')
self.upscore5 = nn.Upsample(scale_factor=16,mode='bilinear')
self.upscore4 = nn.Upsample(scale_factor=8,mode='bilinear')
self.upscore3 = nn.Upsample(scale_factor=4,mode='bilinear')
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
self.outconv = nn.Conv2d(6,1,1)
def forward(self,x):
hx = x
#stage 1
hx1 = self.stage1(hx)
hx = self.pool12(hx1)
#stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
#stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
#stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
#stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
#stage 6
hx6 = self.stage6(hx)
hx6up = self.upscore2(hx6)
#decoder
hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
hx5dup = self.upscore2(hx5d)
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
hx4dup = self.upscore2(hx4d)
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
hx3dup = self.upscore2(hx3d)
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d)
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
#side output
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = self.upscore2(d2)
d3 = self.side3(hx3d)
d3 = self.upscore3(d3)
d4 = self.side4(hx4d)
d4 = self.upscore4(d4)
d5 = self.side5(hx5d)
d5 = self.upscore5(d5)
d6 = self.side6(hx6)
d6 = self.upscore6(d6)
d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
# d00 = d0 + self.refconv(d0)
return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)

Binary file not shown.

After

Width:  |  Height:  |  Size: 325 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 59 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 832 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 152 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 229 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 99 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 58 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 83 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 199 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 168 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 135 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 134 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 267 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 140 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 219 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 62 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 456 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 88 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 56 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 61 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 31 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 572 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 131 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 31 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 56 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 74 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 35 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

116
u2net_test.py Normal file
View File

@ -0,0 +1,116 @@
import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim
import numpy as np
from PIL import Image
import glob
from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import U2NET # full size version 173.6 MB
from model import U2NETP # small version u2net 4.7 MB
# normalize the predicted SOD probability map
def normPRED(d):
ma = torch.max(d)
mi = torch.min(d)
dn = (d-mi)/(ma-mi)
return dn
def save_output(image_name,pred,d_dir):
predict = pred
predict = predict.squeeze()
predict_np = predict.cpu().data.numpy()
im = Image.fromarray(predict_np*255).convert('RGB')
img_name = image_name.split("/")[-1]
image = io.imread(image_name)
imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
pb_np = np.array(imo)
aaa = img_name.split(".")
bbb = aaa[0:-1]
imidx = bbb[0]
for i in range(1,len(bbb)):
imidx = imidx + "." + bbb[i]
imo.save(d_dir+imidx+'.png')
def main():
# --------- 1. get image path and name ---------
model_name='u2net'#u2netp
image_dir = './test_data/test_images/'
prediction_dir = './test_data/' + model_name + '_results/'
model_dir = './saved_models/'+ model_name + '/' + model_name + '.pth'
img_name_list = glob.glob(image_dir + '*')
print(img_name_list)
# --------- 2. dataloader ---------
#1. dataloader
test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
lbl_name_list = [],
transform=transforms.Compose([RescaleT(320),
ToTensorLab(flag=0)])
)
test_salobj_dataloader = DataLoader(test_salobj_dataset,
batch_size=1,
shuffle=False,
num_workers=1)
# --------- 3. model define ---------
if(model_name=='u2net'):
print("...load U2NET---173.6 MB")
net = U2NET(3,1)
elif(model_name=='u2netp'):
print("...load U2NEP---4.7 MB")
net = U2NETP(3,1)
net.load_state_dict(torch.load(model_dir))
if torch.cuda.is_available():
net.cuda()
net.eval()
# --------- 4. inference for each image ---------
for i_test, data_test in enumerate(test_salobj_dataloader):
print("inferencing:",img_name_list[i_test].split("/")[-1])
inputs_test = data_test['image']
inputs_test = inputs_test.type(torch.FloatTensor)
if torch.cuda.is_available():
inputs_test = Variable(inputs_test.cuda())
else:
inputs_test = Variable(inputs_test)
d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
# normalization
pred = d1[:,0,:,:]
pred = normPRED(pred)
# save results to test_results folder
save_output(img_name_list[i_test],pred,prediction_dir)
del d1,d2,d3,d4,d5,d6,d7
if __name__ == "__main__":
main()

164
u2net_train.py Normal file
View File

@ -0,0 +1,164 @@
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import torchvision.transforms as standard_transforms
import numpy as np
import glob
from data_loader import Rescale
from data_loader import RescaleT
from data_loader import RandomCrop
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import U2NET
from model import U2NETP
# ------- 1. define loss function --------
bce_loss = nn.BCELoss(size_average=True)
def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
loss0 = bce_loss(d0,labels_v)
loss1 = bce_loss(d1,labels_v)
loss2 = bce_loss(d2,labels_v)
loss3 = bce_loss(d3,labels_v)
loss4 = bce_loss(d4,labels_v)
loss5 = bce_loss(d5,labels_v)
loss6 = bce_loss(d6,labels_v)
loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data[0],loss1.data[0],loss2.data[0],loss3.data[0],loss4.data[0],loss5.data[0],loss6.data[0]))
return loss0, loss
# ------- 2. set the directory of training dataset --------
model_name = 'u2net' #'u2netp'
data_dir = './train_data/'
tra_image_dir = 'DUTS/DUTS-TR/DUTS-TR/im_aug/'
tra_label_dir = 'DUTS/DUTS-TR/DUTS-TR/gt_aug/'
image_ext = '.jpg'
label_ext = '.png'
model_dir = './saved_models/' + model_name +'/'
epoch_num = 100000
batch_size_train = 12
batch_size_val = 1
train_num = 0
val_num = 0
tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext)
tra_lbl_name_list = []
for img_path in tra_img_name_list:
img_name = img_path.split("/")[-1]
aaa = img_name.split(".")
bbb = aaa[0:-1]
imidx = bbb[0]
for i in range(1,len(bbb)):
imidx = imidx + "." + bbb[i]
tra_lbl_name_list.append(data_dir + tra_label_dir + imidx + label_ext)
print("---")
print("train images: ", len(tra_img_name_list))
print("train labels: ", len(tra_lbl_name_list))
print("---")
train_num = len(tra_img_name_list)
salobj_dataset = SalObjDataset(
img_name_list=tra_img_name_list,
lbl_name_list=tra_lbl_name_list,
transform=transforms.Compose([
RescaleT(320),
RandomCrop(288),
ToTensorLab(flag=0)]))
salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1)
# ------- 3. define model --------
# define the net
if(model_name=='u2net'):
net = U2NET(3, 1)
elif(model_name=='u2netp'):
net = U2NETP(3,1)
if torch.cuda.is_available():
net.cuda()
# ------- 4. define optimizer --------
print("---define optimizer...")
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# ------- 5. training process --------
print("---start training...")
ite_num = 0
running_loss = 0.0
running_tar_loss = 0.0
ite_num4val = 0
save_frq = 10 # save the model every 2000 iterations
for epoch in range(0, epoch_num):
net.train()
for i, data in enumerate(salobj_dataloader):
ite_num = ite_num + 1
ite_num4val = ite_num4val + 1
inputs, labels = data['image'], data['label']
inputs = inputs.type(torch.FloatTensor)
labels = labels.type(torch.FloatTensor)
# wrap them in Variable
if torch.cuda.is_available():
inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(),
requires_grad=False)
else:
inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
# y zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)
loss.backward()
optimizer.step()
# # print statistics
running_loss += loss.data[0]
running_tar_loss += loss2.data[0]
# del temporary outputs and loss
del d0, d1, d2, d3, d4, d5, d6, loss2, loss
print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
if ite_num % save_frq == 0:
torch.save(net.state_dict(), model_dir + model_name+"_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
running_loss = 0.0
running_tar_loss = 0.0
net.train() # resume train
ite_num4val = 0
if __name__ == "__main__":
main()