diff --git a/README.md b/README.md index 7c7b700..7715005 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,18 @@ # U^2-Net The code for our newly accepted paper in Pattern Recognition 2020: -## U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection, [Xuebin Qin](https://webdocs.cs.ualberta.ca/~xuebin/), [Zichen Zhang](https://webdocs.cs.ualberta.ca/~zichen2/), [Chenyang Huang](https://chenyangh.com/), [Masood Dehghan](https://sites.google.com/view/masooddehghan), [Osmar R. Zaiane](http://webdocs.cs.ualberta.ca/~zaiane/) and [Martin Jagersand](https://webdocs.cs.ualberta.ca/~jag/). +## [U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection](https://www.sciencedirect.com/science/article/pii/S0031320320302077?dgcid=author), [Xuebin Qin](https://webdocs.cs.ualberta.ca/~xuebin/), [Zichen Zhang](https://webdocs.cs.ualberta.ca/~zichen2/), [Chenyang Huang](https://chenyangh.com/), [Masood Dehghan](https://sites.google.com/view/masooddehghan), [Osmar R. Zaiane](http://webdocs.cs.ualberta.ca/~zaiane/) and [Martin Jagersand](https://webdocs.cs.ualberta.ca/~jag/). __Contact__: xuebin[at]ualberta[dot]ca +## News + +(2020-May-16) The official paper of U^2-Net (U square net) [PDF in elsevier](https://www.sciencedirect.com/science/article/pii/S0031320320302077?dgcid=author) is now available. If you are not able to access that, please feel free to drop me an email. + +(2020-May-16) We fix the upsampling issue of the network. Now, the model should be able to handle *arbitrary* input size. (Tips: This modification is to facilitate the retraining of U^2-Net on your own datasets. When using our pre-trained model on SOD datasets, please keep the input size as 320x32 to guarantee the performance.) + +(2020-May-16) We highly appreciate Cyril Diagne for building this fantastic AR project: [AR Copy and Paste](https://github.com/cyrildiagne/ar-cutpaste) using our [BASNet](https://github.com/NathanUA/BASNet) and U^2-Net. The [demo video](https://twitter.com/cyrildiagne/status/1256916982764646402) has achieved over 5M views, which is phenomenal and show us more probabilities of SOD. + ## U^2-Net Results (173.6 MB) ![U^2-Net Results](figures/u2netqual.png) diff --git a/data_loader.py b/data_loader.py index 5a54b89..4062072 100644 --- a/data_loader.py +++ b/data_loader.py @@ -10,6 +10,7 @@ import matplotlib.pyplot as plt from torch.utils.data import Dataset, DataLoader from torchvision import transforms, utils from PIL import Image + #==========================dataset load========================== class RescaleT(object): @@ -50,6 +51,10 @@ class Rescale(object): def __call__(self,sample): imidx, image, label = sample['imidx'], sample['image'],sample['label'] + if random.random() >= 0.5: + image = image[::-1] + label = label[::-1] + h, w = image.shape[:2] if isinstance(self.output_size,int): diff --git a/model/u2net.py b/model/u2net.py index dac437f..ece59e0 100644 --- a/model/u2net.py +++ b/model/u2net.py @@ -18,6 +18,14 @@ class REBNCONV(nn.Module): return xout +## upsample tensor 'src' to have the same spatial size with tensor 'tar' +def _upsample_like(src,tar): + + src = F.upsample(src,size=tar.shape[2:],mode='bilinear') + + return src + + ### RSU-7 ### class RSU7(nn.Module):#UNet07DRES(nn.Module): @@ -52,8 +60,6 @@ class RSU7(nn.Module):#UNet07DRES(nn.Module): self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) - self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear') - def forward(self,x): hx = x @@ -79,19 +85,19 @@ class RSU7(nn.Module):#UNet07DRES(nn.Module): hx7 = self.rebnconv7(hx6) hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1)) - hx6up = self.upscore2(hx6d) - # print(hx6up.shape,hx5.shape) - hx5d = self.rebnconv5d(torch.cat((hx6up,hx5),1)) - hx5dup = self.upscore2(hx5d) + hx6dup = _upsample_like(hx6d,hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1)) + hx5dup = _upsample_like(hx5d,hx4) hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) - hx4dup = self.upscore2(hx4d) + hx4dup = _upsample_like(hx4d,hx3) hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) - hx3dup = self.upscore2(hx3d) + hx3dup = _upsample_like(hx3d,hx2) hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) - hx2dup = self.upscore2(hx2d) + hx2dup = _upsample_like(hx2d,hx1) hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) @@ -127,8 +133,6 @@ class RSU6(nn.Module):#UNet06DRES(nn.Module): self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) - self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear') - def forward(self,x): hx = x @@ -153,16 +157,16 @@ class RSU6(nn.Module):#UNet06DRES(nn.Module): hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1)) - hx5dup = self.upscore2(hx5d) + hx5dup = _upsample_like(hx5d,hx4) hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1)) - hx4dup = self.upscore2(hx4d) + hx4dup = _upsample_like(hx4d,hx3) hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) - hx3dup = self.upscore2(hx3d) + hx3dup = _upsample_like(hx3d,hx2) hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) - hx2dup = self.upscore2(hx2d) + hx2dup = _upsample_like(hx2d,hx1) hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) @@ -194,8 +198,6 @@ class RSU5(nn.Module):#UNet05DRES(nn.Module): self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) - self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear') - def forward(self,x): hx = x @@ -216,13 +218,13 @@ class RSU5(nn.Module):#UNet05DRES(nn.Module): hx5 = self.rebnconv5(hx4) hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1)) - hx4dup = self.upscore2(hx4d) + hx4dup = _upsample_like(hx4d,hx3) hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1)) - hx3dup = self.upscore2(hx3d) + hx3dup = _upsample_like(hx3d,hx2) hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) - hx2dup = self.upscore2(hx2d) + hx2dup = _upsample_like(hx2d,hx1) hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) @@ -250,8 +252,6 @@ class RSU4(nn.Module):#UNet04DRES(nn.Module): self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1) self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1) - self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear') - def forward(self,x): hx = x @@ -269,10 +269,10 @@ class RSU4(nn.Module):#UNet04DRES(nn.Module): hx4 = self.rebnconv4(hx3) hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1)) - hx3dup = self.upscore2(hx3d) + hx3dup = _upsample_like(hx3d,hx2) hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1)) - hx2dup = self.upscore2(hx2d) + hx2dup = _upsample_like(hx2d,hx1) hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1)) @@ -345,20 +345,14 @@ class U2NET(nn.Module): self.stage2d = RSU6(256,32,64) self.stage1d = RSU7(128,16,64) - self.side1 = nn.Conv2d(64,1,3,padding=1) - self.side2 = nn.Conv2d(64,1,3,padding=1) - self.side3 = nn.Conv2d(128,1,3,padding=1) - self.side4 = nn.Conv2d(256,1,3,padding=1) - self.side5 = nn.Conv2d(512,1,3,padding=1) - self.side6 = nn.Conv2d(512,1,3,padding=1) + self.side1 = nn.Conv2d(64,out_ch,3,padding=1) + self.side2 = nn.Conv2d(64,out_ch,3,padding=1) + self.side3 = nn.Conv2d(128,out_ch,3,padding=1) + self.side4 = nn.Conv2d(256,out_ch,3,padding=1) + self.side5 = nn.Conv2d(512,out_ch,3,padding=1) + self.side6 = nn.Conv2d(512,out_ch,3,padding=1) - self.upscore6 = nn.Upsample(scale_factor=32,mode='bilinear') - self.upscore5 = nn.Upsample(scale_factor=16,mode='bilinear') - self.upscore4 = nn.Upsample(scale_factor=8,mode='bilinear') - self.upscore3 = nn.Upsample(scale_factor=4,mode='bilinear') - self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear') - - self.outconv = nn.Conv2d(6,1,1) + self.outconv = nn.Conv2d(6,out_ch,1) def forward(self,x): @@ -372,8 +366,6 @@ class U2NET(nn.Module): hx2 = self.stage2(hx) hx = self.pool23(hx2) - - #stage 3 hx3 = self.stage3(hx) hx = self.pool34(hx3) @@ -388,20 +380,20 @@ class U2NET(nn.Module): #stage 6 hx6 = self.stage6(hx) - hx6up = self.upscore2(hx6) + hx6up = _upsample_like(hx6,hx5) #-------------------- decoder -------------------- hx5d = self.stage5d(torch.cat((hx6up,hx5),1)) - hx5dup = self.upscore2(hx5d) + hx5dup = _upsample_like(hx5d,hx4) hx4d = self.stage4d(torch.cat((hx5dup,hx4),1)) - hx4dup = self.upscore2(hx4d) + hx4dup = _upsample_like(hx4d,hx3) hx3d = self.stage3d(torch.cat((hx4dup,hx3),1)) - hx3dup = self.upscore2(hx3d) + hx3dup = _upsample_like(hx3d,hx2) hx2d = self.stage2d(torch.cat((hx3dup,hx2),1)) - hx2dup = self.upscore2(hx2d) + hx2dup = _upsample_like(hx2d,hx1) hx1d = self.stage1d(torch.cat((hx2dup,hx1),1)) @@ -410,19 +402,19 @@ class U2NET(nn.Module): d1 = self.side1(hx1d) d2 = self.side2(hx2d) - d2 = self.upscore2(d2) + d2 = _upsample_like(d2,d1) d3 = self.side3(hx3d) - d3 = self.upscore3(d3) + d3 = _upsample_like(d3,d1) d4 = self.side4(hx4d) - d4 = self.upscore4(d4) + d4 = _upsample_like(d4,d1) d5 = self.side5(hx5d) - d5 = self.upscore5(d5) + d5 = _upsample_like(d5,d1) d6 = self.side6(hx6) - d6 = self.upscore6(d6) + d6 = _upsample_like(d6,d1) d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1)) @@ -458,20 +450,14 @@ class U2NETP(nn.Module): self.stage2d = RSU6(128,16,64) self.stage1d = RSU7(128,16,64) - self.side1 = nn.Conv2d(64,1,3,padding=1) - self.side2 = nn.Conv2d(64,1,3,padding=1) - self.side3 = nn.Conv2d(64,1,3,padding=1) - self.side4 = nn.Conv2d(64,1,3,padding=1) - self.side5 = nn.Conv2d(64,1,3,padding=1) - self.side6 = nn.Conv2d(64,1,3,padding=1) + self.side1 = nn.Conv2d(64,out_ch,3,padding=1) + self.side2 = nn.Conv2d(64,out_ch,3,padding=1) + self.side3 = nn.Conv2d(64,out_ch,3,padding=1) + self.side4 = nn.Conv2d(64,out_ch,3,padding=1) + self.side5 = nn.Conv2d(64,out_ch,3,padding=1) + self.side6 = nn.Conv2d(64,out_ch,3,padding=1) - self.upscore6 = nn.Upsample(scale_factor=32,mode='bilinear') - self.upscore5 = nn.Upsample(scale_factor=16,mode='bilinear') - self.upscore4 = nn.Upsample(scale_factor=8,mode='bilinear') - self.upscore3 = nn.Upsample(scale_factor=4,mode='bilinear') - self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear') - - self.outconv = nn.Conv2d(6,1,1) + self.outconv = nn.Conv2d(6,out_ch,1) def forward(self,x): @@ -499,20 +485,20 @@ class U2NETP(nn.Module): #stage 6 hx6 = self.stage6(hx) - hx6up = self.upscore2(hx6) + hx6up = _upsample_like(hx6,hx5) #decoder hx5d = self.stage5d(torch.cat((hx6up,hx5),1)) - hx5dup = self.upscore2(hx5d) + hx5dup = _upsample_like(hx5d,hx4) hx4d = self.stage4d(torch.cat((hx5dup,hx4),1)) - hx4dup = self.upscore2(hx4d) + hx4dup = _upsample_like(hx4d,hx3) hx3d = self.stage3d(torch.cat((hx4dup,hx3),1)) - hx3dup = self.upscore2(hx3d) + hx3dup = _upsample_like(hx3d,hx2) hx2d = self.stage2d(torch.cat((hx3dup,hx2),1)) - hx2dup = self.upscore2(hx2d) + hx2dup = _upsample_like(hx2d,hx1) hx1d = self.stage1d(torch.cat((hx2dup,hx1),1)) @@ -521,21 +507,20 @@ class U2NETP(nn.Module): d1 = self.side1(hx1d) d2 = self.side2(hx2d) - d2 = self.upscore2(d2) + d2 = _upsample_like(d2,d1) d3 = self.side3(hx3d) - d3 = self.upscore3(d3) + d3 = _upsample_like(d3,d1) d4 = self.side4(hx4d) - d4 = self.upscore4(d4) + d4 = _upsample_like(d4,d1) d5 = self.side5(hx5d) - d5 = self.upscore5(d5) + d5 = _upsample_like(d5,d1) d6 = self.side6(hx6) - d6 = self.upscore6(d6) + d6 = _upsample_like(d6,d1) d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1)) - # d00 = d0 + self.refconv(d0) return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)