input size and readme

This commit is contained in:
Nathan Qin 2020-05-16 13:49:06 -06:00
parent 89d9a2e60d
commit 04bd5f1f63
3 changed files with 72 additions and 74 deletions

View File

@ -1,10 +1,18 @@
# U^2-Net # U^2-Net
The code for our newly accepted paper in Pattern Recognition 2020: The code for our newly accepted paper in Pattern Recognition 2020:
## U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection, [Xuebin Qin](https://webdocs.cs.ualberta.ca/~xuebin/), [Zichen Zhang](https://webdocs.cs.ualberta.ca/~zichen2/), [Chenyang Huang](https://chenyangh.com/), [Masood Dehghan](https://sites.google.com/view/masooddehghan), [Osmar R. Zaiane](http://webdocs.cs.ualberta.ca/~zaiane/) and [Martin Jagersand](https://webdocs.cs.ualberta.ca/~jag/). ## [U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection](https://www.sciencedirect.com/science/article/pii/S0031320320302077?dgcid=author), [Xuebin Qin](https://webdocs.cs.ualberta.ca/~xuebin/), [Zichen Zhang](https://webdocs.cs.ualberta.ca/~zichen2/), [Chenyang Huang](https://chenyangh.com/), [Masood Dehghan](https://sites.google.com/view/masooddehghan), [Osmar R. Zaiane](http://webdocs.cs.ualberta.ca/~zaiane/) and [Martin Jagersand](https://webdocs.cs.ualberta.ca/~jag/).
__Contact__: xuebin[at]ualberta[dot]ca __Contact__: xuebin[at]ualberta[dot]ca
## News
(2020-May-16) The official paper of U^2-Net (U square net) [PDF in elsevier](https://www.sciencedirect.com/science/article/pii/S0031320320302077?dgcid=author) is now available. If you are not able to access that, please feel free to drop me an email.
(2020-May-16) We fix the upsampling issue of the network. Now, the model should be able to handle *arbitrary* input size. (Tips: This modification is to facilitate the retraining of U^2-Net on your own datasets. When using our pre-trained model on SOD datasets, please keep the input size as 320x32 to guarantee the performance.)
(2020-May-16) We highly appreciate Cyril Diagne for building this fantastic AR project: [AR Copy and Paste](https://github.com/cyrildiagne/ar-cutpaste) using our [BASNet](https://github.com/NathanUA/BASNet) and U^2-Net. The [demo video](https://twitter.com/cyrildiagne/status/1256916982764646402) has achieved over 5M views, which is phenomenal and show us more probabilities of SOD.
## U^2-Net Results (173.6 MB) ## U^2-Net Results (173.6 MB)
![U^2-Net Results](figures/u2netqual.png) ![U^2-Net Results](figures/u2netqual.png)

View File

@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils from torchvision import transforms, utils
from PIL import Image from PIL import Image
#==========================dataset load========================== #==========================dataset load==========================
class RescaleT(object): class RescaleT(object):
@ -50,6 +51,10 @@ class Rescale(object):
def __call__(self,sample): def __call__(self,sample):
imidx, image, label = sample['imidx'], sample['image'],sample['label'] imidx, image, label = sample['imidx'], sample['image'],sample['label']
if random.random() >= 0.5:
image = image[::-1]
label = label[::-1]
h, w = image.shape[:2] h, w = image.shape[:2]
if isinstance(self.output_size,int): if isinstance(self.output_size,int):

View File

@ -18,6 +18,14 @@ class REBNCONV(nn.Module):
return xout return xout
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src,tar):
src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
return src
### RSU-7 ### ### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module): class RSU7(nn.Module):#UNet07DRES(nn.Module):
@ -52,8 +60,6 @@ class RSU7(nn.Module):#UNet07DRES(nn.Module):
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
def forward(self,x): def forward(self,x):
hx = x hx = x
@ -79,19 +85,19 @@ class RSU7(nn.Module):#UNet07DRES(nn.Module):
hx7 = self.rebnconv7(hx6) hx7 = self.rebnconv7(hx6)
hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1)) hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
hx6up = self.upscore2(hx6d) hx6dup = _upsample_like(hx6d,hx5)
# print(hx6up.shape,hx5.shape)
hx5d = self.rebnconv5d(torch.cat((hx6up,hx5),1)) hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
hx5dup = self.upscore2(hx5d) hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
hx4dup = self.upscore2(hx4d) hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = self.upscore2(hx3d) hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d) hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
@ -127,8 +133,6 @@ class RSU6(nn.Module):#UNet06DRES(nn.Module):
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
def forward(self,x): def forward(self,x):
hx = x hx = x
@ -153,16 +157,16 @@ class RSU6(nn.Module):#UNet06DRES(nn.Module):
hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1)) hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
hx5dup = self.upscore2(hx5d) hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
hx4dup = self.upscore2(hx4d) hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = self.upscore2(hx3d) hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d) hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
@ -194,8 +198,6 @@ class RSU5(nn.Module):#UNet05DRES(nn.Module):
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
def forward(self,x): def forward(self,x):
hx = x hx = x
@ -216,13 +218,13 @@ class RSU5(nn.Module):#UNet05DRES(nn.Module):
hx5 = self.rebnconv5(hx4) hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1)) hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
hx4dup = self.upscore2(hx4d) hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
hx3dup = self.upscore2(hx3d) hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d) hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
@ -250,8 +252,6 @@ class RSU4(nn.Module):#UNet04DRES(nn.Module):
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
def forward(self,x): def forward(self,x):
hx = x hx = x
@ -269,10 +269,10 @@ class RSU4(nn.Module):#UNet04DRES(nn.Module):
hx4 = self.rebnconv4(hx3) hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1)) hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
hx3dup = self.upscore2(hx3d) hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d) hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
@ -345,20 +345,14 @@ class U2NET(nn.Module):
self.stage2d = RSU6(256,32,64) self.stage2d = RSU6(256,32,64)
self.stage1d = RSU7(128,16,64) self.stage1d = RSU7(128,16,64)
self.side1 = nn.Conv2d(64,1,3,padding=1) self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
self.side2 = nn.Conv2d(64,1,3,padding=1) self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
self.side3 = nn.Conv2d(128,1,3,padding=1) self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
self.side4 = nn.Conv2d(256,1,3,padding=1) self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
self.side5 = nn.Conv2d(512,1,3,padding=1) self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
self.side6 = nn.Conv2d(512,1,3,padding=1) self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
self.upscore6 = nn.Upsample(scale_factor=32,mode='bilinear') self.outconv = nn.Conv2d(6,out_ch,1)
self.upscore5 = nn.Upsample(scale_factor=16,mode='bilinear')
self.upscore4 = nn.Upsample(scale_factor=8,mode='bilinear')
self.upscore3 = nn.Upsample(scale_factor=4,mode='bilinear')
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
self.outconv = nn.Conv2d(6,1,1)
def forward(self,x): def forward(self,x):
@ -372,8 +366,6 @@ class U2NET(nn.Module):
hx2 = self.stage2(hx) hx2 = self.stage2(hx)
hx = self.pool23(hx2) hx = self.pool23(hx2)
#stage 3 #stage 3
hx3 = self.stage3(hx) hx3 = self.stage3(hx)
hx = self.pool34(hx3) hx = self.pool34(hx3)
@ -388,20 +380,20 @@ class U2NET(nn.Module):
#stage 6 #stage 6
hx6 = self.stage6(hx) hx6 = self.stage6(hx)
hx6up = self.upscore2(hx6) hx6up = _upsample_like(hx6,hx5)
#-------------------- decoder -------------------- #-------------------- decoder --------------------
hx5d = self.stage5d(torch.cat((hx6up,hx5),1)) hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
hx5dup = self.upscore2(hx5d) hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1)) hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
hx4dup = self.upscore2(hx4d) hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1)) hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
hx3dup = self.upscore2(hx3d) hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1)) hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d) hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1)) hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
@ -410,19 +402,19 @@ class U2NET(nn.Module):
d1 = self.side1(hx1d) d1 = self.side1(hx1d)
d2 = self.side2(hx2d) d2 = self.side2(hx2d)
d2 = self.upscore2(d2) d2 = _upsample_like(d2,d1)
d3 = self.side3(hx3d) d3 = self.side3(hx3d)
d3 = self.upscore3(d3) d3 = _upsample_like(d3,d1)
d4 = self.side4(hx4d) d4 = self.side4(hx4d)
d4 = self.upscore4(d4) d4 = _upsample_like(d4,d1)
d5 = self.side5(hx5d) d5 = self.side5(hx5d)
d5 = self.upscore5(d5) d5 = _upsample_like(d5,d1)
d6 = self.side6(hx6) d6 = self.side6(hx6)
d6 = self.upscore6(d6) d6 = _upsample_like(d6,d1)
d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1)) d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
@ -458,20 +450,14 @@ class U2NETP(nn.Module):
self.stage2d = RSU6(128,16,64) self.stage2d = RSU6(128,16,64)
self.stage1d = RSU7(128,16,64) self.stage1d = RSU7(128,16,64)
self.side1 = nn.Conv2d(64,1,3,padding=1) self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
self.side2 = nn.Conv2d(64,1,3,padding=1) self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
self.side3 = nn.Conv2d(64,1,3,padding=1) self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
self.side4 = nn.Conv2d(64,1,3,padding=1) self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
self.side5 = nn.Conv2d(64,1,3,padding=1) self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
self.side6 = nn.Conv2d(64,1,3,padding=1) self.side6 = nn.Conv2d(64,out_ch,3,padding=1)
self.upscore6 = nn.Upsample(scale_factor=32,mode='bilinear') self.outconv = nn.Conv2d(6,out_ch,1)
self.upscore5 = nn.Upsample(scale_factor=16,mode='bilinear')
self.upscore4 = nn.Upsample(scale_factor=8,mode='bilinear')
self.upscore3 = nn.Upsample(scale_factor=4,mode='bilinear')
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
self.outconv = nn.Conv2d(6,1,1)
def forward(self,x): def forward(self,x):
@ -499,20 +485,20 @@ class U2NETP(nn.Module):
#stage 6 #stage 6
hx6 = self.stage6(hx) hx6 = self.stage6(hx)
hx6up = self.upscore2(hx6) hx6up = _upsample_like(hx6,hx5)
#decoder #decoder
hx5d = self.stage5d(torch.cat((hx6up,hx5),1)) hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
hx5dup = self.upscore2(hx5d) hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1)) hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
hx4dup = self.upscore2(hx4d) hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1)) hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
hx3dup = self.upscore2(hx3d) hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1)) hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
hx2dup = self.upscore2(hx2d) hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1)) hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
@ -521,21 +507,20 @@ class U2NETP(nn.Module):
d1 = self.side1(hx1d) d1 = self.side1(hx1d)
d2 = self.side2(hx2d) d2 = self.side2(hx2d)
d2 = self.upscore2(d2) d2 = _upsample_like(d2,d1)
d3 = self.side3(hx3d) d3 = self.side3(hx3d)
d3 = self.upscore3(d3) d3 = _upsample_like(d3,d1)
d4 = self.side4(hx4d) d4 = self.side4(hx4d)
d4 = self.upscore4(d4) d4 = _upsample_like(d4,d1)
d5 = self.side5(hx5d) d5 = self.side5(hx5d)
d5 = self.upscore5(d5) d5 = _upsample_like(d5,d1)
d6 = self.side6(hx6) d6 = self.side6(hx6)
d6 = self.upscore6(d6) d6 = _upsample_like(d6,d1)
d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1)) d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
# d00 = d0 + self.refconv(d0)
return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6) return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)