input size and readme

This commit is contained in:
Nathan Qin 2020-05-16 13:49:06 -06:00
parent 89d9a2e60d
commit 04bd5f1f63
3 changed files with 72 additions and 74 deletions

View File

@ -1,10 +1,18 @@
# U^2-Net
The code for our newly accepted paper in Pattern Recognition 2020:
## U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection, [Xuebin Qin](https://webdocs.cs.ualberta.ca/~xuebin/), [Zichen Zhang](https://webdocs.cs.ualberta.ca/~zichen2/), [Chenyang Huang](https://chenyangh.com/), [Masood Dehghan](https://sites.google.com/view/masooddehghan), [Osmar R. Zaiane](http://webdocs.cs.ualberta.ca/~zaiane/) and [Martin Jagersand](https://webdocs.cs.ualberta.ca/~jag/).
## [U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection](https://www.sciencedirect.com/science/article/pii/S0031320320302077?dgcid=author), [Xuebin Qin](https://webdocs.cs.ualberta.ca/~xuebin/), [Zichen Zhang](https://webdocs.cs.ualberta.ca/~zichen2/), [Chenyang Huang](https://chenyangh.com/), [Masood Dehghan](https://sites.google.com/view/masooddehghan), [Osmar R. Zaiane](http://webdocs.cs.ualberta.ca/~zaiane/) and [Martin Jagersand](https://webdocs.cs.ualberta.ca/~jag/).
__Contact__: xuebin[at]ualberta[dot]ca
## News
(2020-May-16) The official paper of U^2-Net (U square net) [PDF in elsevier](https://www.sciencedirect.com/science/article/pii/S0031320320302077?dgcid=author) is now available. If you are not able to access that, please feel free to drop me an email.
(2020-May-16) We fix the upsampling issue of the network. Now, the model should be able to handle *arbitrary* input size. (Tips: This modification is to facilitate the retraining of U^2-Net on your own datasets. When using our pre-trained model on SOD datasets, please keep the input size as 320x32 to guarantee the performance.)
(2020-May-16) We highly appreciate Cyril Diagne for building this fantastic AR project: [AR Copy and Paste](https://github.com/cyrildiagne/ar-cutpaste) using our [BASNet](https://github.com/NathanUA/BASNet) and U^2-Net. The [demo video](https://twitter.com/cyrildiagne/status/1256916982764646402) has achieved over 5M views, which is phenomenal and show us more probabilities of SOD.
## U^2-Net Results (173.6 MB)
![U^2-Net Results](figures/u2netqual.png)

View File

@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
#==========================dataset load==========================
class RescaleT(object):
@ -50,6 +51,10 @@ class Rescale(object):
def __call__(self,sample):
imidx, image, label = sample['imidx'], sample['image'],sample['label']
if random.random() >= 0.5:
image = image[::-1]
label = label[::-1]
h, w = image.shape[:2]
if isinstance(self.output_size,int):

View File

@ -18,6 +18,14 @@ class REBNCONV(nn.Module):
return xout
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src,tar):
src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
return src
### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):
@ -52,8 +60,6 @@ class RSU7(nn.Module):#UNet07DRES(nn.Module):
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
def forward(self,x):
hx = x
@ -79,19 +85,19 @@ class RSU7(nn.Module):#UNet07DRES(nn.Module):
hx7 = self.rebnconv7(hx6)
hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
hx6up = self.upscore2(hx6d)
# print(hx6up.shape,hx5.shape)
hx5d = self.rebnconv5d(torch.cat((hx6up,hx5),1))
hx5dup = self.upscore2(hx5d)
hx6dup = _upsample_like(hx6d,hx5)
hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
hx4dup = self.upscore2(hx4d)
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = self.upscore2(hx3d)
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d)
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
@ -127,8 +133,6 @@ class RSU6(nn.Module):#UNet06DRES(nn.Module):
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
def forward(self,x):
hx = x
@ -153,16 +157,16 @@ class RSU6(nn.Module):#UNet06DRES(nn.Module):
hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
hx5dup = self.upscore2(hx5d)
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
hx4dup = self.upscore2(hx4d)
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = self.upscore2(hx3d)
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d)
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
@ -194,8 +198,6 @@ class RSU5(nn.Module):#UNet05DRES(nn.Module):
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
def forward(self,x):
hx = x
@ -216,13 +218,13 @@ class RSU5(nn.Module):#UNet05DRES(nn.Module):
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
hx4dup = self.upscore2(hx4d)
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = self.upscore2(hx3d)
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d)
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
@ -250,8 +252,6 @@ class RSU4(nn.Module):#UNet04DRES(nn.Module):
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
def forward(self,x):
hx = x
@ -269,10 +269,10 @@ class RSU4(nn.Module):#UNet04DRES(nn.Module):
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
hx3dup = self.upscore2(hx3d)
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d)
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
@ -345,20 +345,14 @@ class U2NET(nn.Module):
self.stage2d = RSU6(256,32,64)
self.stage1d = RSU7(128,16,64)
self.side1 = nn.Conv2d(64,1,3,padding=1)
self.side2 = nn.Conv2d(64,1,3,padding=1)
self.side3 = nn.Conv2d(128,1,3,padding=1)
self.side4 = nn.Conv2d(256,1,3,padding=1)
self.side5 = nn.Conv2d(512,1,3,padding=1)
self.side6 = nn.Conv2d(512,1,3,padding=1)
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
self.upscore6 = nn.Upsample(scale_factor=32,mode='bilinear')
self.upscore5 = nn.Upsample(scale_factor=16,mode='bilinear')
self.upscore4 = nn.Upsample(scale_factor=8,mode='bilinear')
self.upscore3 = nn.Upsample(scale_factor=4,mode='bilinear')
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
self.outconv = nn.Conv2d(6,1,1)
self.outconv = nn.Conv2d(6,out_ch,1)
def forward(self,x):
@ -372,8 +366,6 @@ class U2NET(nn.Module):
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
#stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
@ -388,20 +380,20 @@ class U2NET(nn.Module):
#stage 6
hx6 = self.stage6(hx)
hx6up = self.upscore2(hx6)
hx6up = _upsample_like(hx6,hx5)
#-------------------- decoder --------------------
hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
hx5dup = self.upscore2(hx5d)
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
hx4dup = self.upscore2(hx4d)
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
hx3dup = self.upscore2(hx3d)
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d)
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
@ -410,19 +402,19 @@ class U2NET(nn.Module):
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = self.upscore2(d2)
d2 = _upsample_like(d2,d1)
d3 = self.side3(hx3d)
d3 = self.upscore3(d3)
d3 = _upsample_like(d3,d1)
d4 = self.side4(hx4d)
d4 = self.upscore4(d4)
d4 = _upsample_like(d4,d1)
d5 = self.side5(hx5d)
d5 = self.upscore5(d5)
d5 = _upsample_like(d5,d1)
d6 = self.side6(hx6)
d6 = self.upscore6(d6)
d6 = _upsample_like(d6,d1)
d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
@ -458,20 +450,14 @@ class U2NETP(nn.Module):
self.stage2d = RSU6(128,16,64)
self.stage1d = RSU7(128,16,64)
self.side1 = nn.Conv2d(64,1,3,padding=1)
self.side2 = nn.Conv2d(64,1,3,padding=1)
self.side3 = nn.Conv2d(64,1,3,padding=1)
self.side4 = nn.Conv2d(64,1,3,padding=1)
self.side5 = nn.Conv2d(64,1,3,padding=1)
self.side6 = nn.Conv2d(64,1,3,padding=1)
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
self.side6 = nn.Conv2d(64,out_ch,3,padding=1)
self.upscore6 = nn.Upsample(scale_factor=32,mode='bilinear')
self.upscore5 = nn.Upsample(scale_factor=16,mode='bilinear')
self.upscore4 = nn.Upsample(scale_factor=8,mode='bilinear')
self.upscore3 = nn.Upsample(scale_factor=4,mode='bilinear')
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
self.outconv = nn.Conv2d(6,1,1)
self.outconv = nn.Conv2d(6,out_ch,1)
def forward(self,x):
@ -499,20 +485,20 @@ class U2NETP(nn.Module):
#stage 6
hx6 = self.stage6(hx)
hx6up = self.upscore2(hx6)
hx6up = _upsample_like(hx6,hx5)
#decoder
hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
hx5dup = self.upscore2(hx5d)
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
hx4dup = self.upscore2(hx4d)
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
hx3dup = self.upscore2(hx3d)
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d)
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
@ -521,21 +507,20 @@ class U2NETP(nn.Module):
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = self.upscore2(d2)
d2 = _upsample_like(d2,d1)
d3 = self.side3(hx3d)
d3 = self.upscore3(d3)
d3 = _upsample_like(d3,d1)
d4 = self.side4(hx4d)
d4 = self.upscore4(d4)
d4 = _upsample_like(d4,d1)
d5 = self.side5(hx5d)
d5 = self.upscore5(d5)
d5 = _upsample_like(d5,d1)
d6 = self.side6(hx6)
d6 = self.upscore6(d6)
d6 = _upsample_like(d6,d1)
d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
# d00 = d0 + self.refconv(d0)
return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)