mirror of
https://git.mirrors.martin98.com/https://github.com/xuebinqin/U-2-Net
synced 2025-08-14 03:06:02 +08:00
input size and readme
This commit is contained in:
parent
89d9a2e60d
commit
04bd5f1f63
10
README.md
10
README.md
@ -1,10 +1,18 @@
|
||||
# U^2-Net
|
||||
|
||||
The code for our newly accepted paper in Pattern Recognition 2020:
|
||||
## U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection, [Xuebin Qin](https://webdocs.cs.ualberta.ca/~xuebin/), [Zichen Zhang](https://webdocs.cs.ualberta.ca/~zichen2/), [Chenyang Huang](https://chenyangh.com/), [Masood Dehghan](https://sites.google.com/view/masooddehghan), [Osmar R. Zaiane](http://webdocs.cs.ualberta.ca/~zaiane/) and [Martin Jagersand](https://webdocs.cs.ualberta.ca/~jag/).
|
||||
## [U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection](https://www.sciencedirect.com/science/article/pii/S0031320320302077?dgcid=author), [Xuebin Qin](https://webdocs.cs.ualberta.ca/~xuebin/), [Zichen Zhang](https://webdocs.cs.ualberta.ca/~zichen2/), [Chenyang Huang](https://chenyangh.com/), [Masood Dehghan](https://sites.google.com/view/masooddehghan), [Osmar R. Zaiane](http://webdocs.cs.ualberta.ca/~zaiane/) and [Martin Jagersand](https://webdocs.cs.ualberta.ca/~jag/).
|
||||
|
||||
__Contact__: xuebin[at]ualberta[dot]ca
|
||||
|
||||
## News
|
||||
|
||||
(2020-May-16) The official paper of U^2-Net (U square net) ([PDF on Elsevier](https://www.sciencedirect.com/science/article/pii/S0031320320302077?dgcid=author)) is now available. If you are not able to access it, please feel free to drop me an email.
|
||||
|
||||
(2020-May-16) We fixed the upsampling issue of the network. The model should now be able to handle *arbitrary* input sizes. (Tip: this modification is meant to facilitate retraining U^2-Net on your own datasets. When using our pre-trained model on SOD datasets, please keep the input size at 320x320 to guarantee the performance.)
|
||||
|
||||
(2020-May-16) We highly appreciate Cyril Diagne for building this fantastic AR project: [AR Copy and Paste](https://github.com/cyrildiagne/ar-cutpaste) using our [BASNet](https://github.com/NathanUA/BASNet) and U^2-Net. The [demo video](https://twitter.com/cyrildiagne/status/1256916982764646402) has achieved over 5M views, which is phenomenal and shows us more possibilities for SOD.
|
||||
|
||||
## U^2-Net Results (173.6 MB)
|
||||
|
||||

|
||||
|
@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms, utils
|
||||
from PIL import Image
|
||||
|
||||
#==========================dataset load==========================
|
||||
class RescaleT(object):
|
||||
|
||||
@ -50,6 +51,10 @@ class Rescale(object):
|
||||
def __call__(self,sample):
|
||||
imidx, image, label = sample['imidx'], sample['image'],sample['label']
|
||||
|
||||
if random.random() >= 0.5:
|
||||
image = image[::-1]
|
||||
label = label[::-1]
|
||||
|
||||
h, w = image.shape[:2]
|
||||
|
||||
if isinstance(self.output_size,int):
|
||||
|
131
model/u2net.py
131
model/u2net.py
@ -18,6 +18,14 @@ class REBNCONV(nn.Module):
|
||||
|
||||
return xout
|
||||
|
||||
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
||||
def _upsample_like(src,tar):
|
||||
|
||||
src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
|
||||
|
||||
return src
|
||||
|
||||
|
||||
### RSU-7 ###
|
||||
class RSU7(nn.Module):#UNet07DRES(nn.Module):
|
||||
|
||||
@ -52,8 +60,6 @@ class RSU7(nn.Module):#UNet07DRES(nn.Module):
|
||||
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
||||
|
||||
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
|
||||
|
||||
def forward(self,x):
|
||||
|
||||
hx = x
|
||||
@ -79,19 +85,19 @@ class RSU7(nn.Module):#UNet07DRES(nn.Module):
|
||||
hx7 = self.rebnconv7(hx6)
|
||||
|
||||
hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
|
||||
hx6up = self.upscore2(hx6d)
|
||||
# print(hx6up.shape,hx5.shape)
|
||||
hx5d = self.rebnconv5d(torch.cat((hx6up,hx5),1))
|
||||
hx5dup = self.upscore2(hx5d)
|
||||
hx6dup = _upsample_like(hx6d,hx5)
|
||||
|
||||
hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
|
||||
hx5dup = _upsample_like(hx5d,hx4)
|
||||
|
||||
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
|
||||
hx4dup = self.upscore2(hx4d)
|
||||
hx4dup = _upsample_like(hx4d,hx3)
|
||||
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
|
||||
hx3dup = self.upscore2(hx3d)
|
||||
hx3dup = _upsample_like(hx3d,hx2)
|
||||
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
||||
hx2dup = self.upscore2(hx2d)
|
||||
hx2dup = _upsample_like(hx2d,hx1)
|
||||
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
||||
|
||||
@ -127,8 +133,6 @@ class RSU6(nn.Module):#UNet06DRES(nn.Module):
|
||||
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
||||
|
||||
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
|
||||
|
||||
def forward(self,x):
|
||||
|
||||
hx = x
|
||||
@ -153,16 +157,16 @@ class RSU6(nn.Module):#UNet06DRES(nn.Module):
|
||||
|
||||
|
||||
hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
|
||||
hx5dup = self.upscore2(hx5d)
|
||||
hx5dup = _upsample_like(hx5d,hx4)
|
||||
|
||||
hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
|
||||
hx4dup = self.upscore2(hx4d)
|
||||
hx4dup = _upsample_like(hx4d,hx3)
|
||||
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
|
||||
hx3dup = self.upscore2(hx3d)
|
||||
hx3dup = _upsample_like(hx3d,hx2)
|
||||
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
||||
hx2dup = self.upscore2(hx2d)
|
||||
hx2dup = _upsample_like(hx2d,hx1)
|
||||
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
||||
|
||||
@ -194,8 +198,6 @@ class RSU5(nn.Module):#UNet05DRES(nn.Module):
|
||||
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
||||
|
||||
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
|
||||
|
||||
def forward(self,x):
|
||||
|
||||
hx = x
|
||||
@ -216,13 +218,13 @@ class RSU5(nn.Module):#UNet05DRES(nn.Module):
|
||||
hx5 = self.rebnconv5(hx4)
|
||||
|
||||
hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
|
||||
hx4dup = self.upscore2(hx4d)
|
||||
hx4dup = _upsample_like(hx4d,hx3)
|
||||
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
|
||||
hx3dup = self.upscore2(hx3d)
|
||||
hx3dup = _upsample_like(hx3d,hx2)
|
||||
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
||||
hx2dup = self.upscore2(hx2d)
|
||||
hx2dup = _upsample_like(hx2d,hx1)
|
||||
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
||||
|
||||
@ -250,8 +252,6 @@ class RSU4(nn.Module):#UNet04DRES(nn.Module):
|
||||
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
|
||||
|
||||
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
|
||||
|
||||
def forward(self,x):
|
||||
|
||||
hx = x
|
||||
@ -269,10 +269,10 @@ class RSU4(nn.Module):#UNet04DRES(nn.Module):
|
||||
hx4 = self.rebnconv4(hx3)
|
||||
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
|
||||
hx3dup = self.upscore2(hx3d)
|
||||
hx3dup = _upsample_like(hx3d,hx2)
|
||||
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
|
||||
hx2dup = self.upscore2(hx2d)
|
||||
hx2dup = _upsample_like(hx2d,hx1)
|
||||
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
|
||||
|
||||
@ -345,20 +345,14 @@ class U2NET(nn.Module):
|
||||
self.stage2d = RSU6(256,32,64)
|
||||
self.stage1d = RSU7(128,16,64)
|
||||
|
||||
self.side1 = nn.Conv2d(64,1,3,padding=1)
|
||||
self.side2 = nn.Conv2d(64,1,3,padding=1)
|
||||
self.side3 = nn.Conv2d(128,1,3,padding=1)
|
||||
self.side4 = nn.Conv2d(256,1,3,padding=1)
|
||||
self.side5 = nn.Conv2d(512,1,3,padding=1)
|
||||
self.side6 = nn.Conv2d(512,1,3,padding=1)
|
||||
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||
self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
|
||||
self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
|
||||
self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
|
||||
self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
|
||||
|
||||
self.upscore6 = nn.Upsample(scale_factor=32,mode='bilinear')
|
||||
self.upscore5 = nn.Upsample(scale_factor=16,mode='bilinear')
|
||||
self.upscore4 = nn.Upsample(scale_factor=8,mode='bilinear')
|
||||
self.upscore3 = nn.Upsample(scale_factor=4,mode='bilinear')
|
||||
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
|
||||
|
||||
self.outconv = nn.Conv2d(6,1,1)
|
||||
self.outconv = nn.Conv2d(6,out_ch,1)
|
||||
|
||||
def forward(self,x):
|
||||
|
||||
@ -372,8 +366,6 @@ class U2NET(nn.Module):
|
||||
hx2 = self.stage2(hx)
|
||||
hx = self.pool23(hx2)
|
||||
|
||||
|
||||
|
||||
#stage 3
|
||||
hx3 = self.stage3(hx)
|
||||
hx = self.pool34(hx3)
|
||||
@ -388,20 +380,20 @@ class U2NET(nn.Module):
|
||||
|
||||
#stage 6
|
||||
hx6 = self.stage6(hx)
|
||||
hx6up = self.upscore2(hx6)
|
||||
hx6up = _upsample_like(hx6,hx5)
|
||||
|
||||
#-------------------- decoder --------------------
|
||||
hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
|
||||
hx5dup = self.upscore2(hx5d)
|
||||
hx5dup = _upsample_like(hx5d,hx4)
|
||||
|
||||
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
|
||||
hx4dup = self.upscore2(hx4d)
|
||||
hx4dup = _upsample_like(hx4d,hx3)
|
||||
|
||||
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
|
||||
hx3dup = self.upscore2(hx3d)
|
||||
hx3dup = _upsample_like(hx3d,hx2)
|
||||
|
||||
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
|
||||
hx2dup = self.upscore2(hx2d)
|
||||
hx2dup = _upsample_like(hx2d,hx1)
|
||||
|
||||
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
|
||||
|
||||
@ -410,19 +402,19 @@ class U2NET(nn.Module):
|
||||
d1 = self.side1(hx1d)
|
||||
|
||||
d2 = self.side2(hx2d)
|
||||
d2 = self.upscore2(d2)
|
||||
d2 = _upsample_like(d2,d1)
|
||||
|
||||
d3 = self.side3(hx3d)
|
||||
d3 = self.upscore3(d3)
|
||||
d3 = _upsample_like(d3,d1)
|
||||
|
||||
d4 = self.side4(hx4d)
|
||||
d4 = self.upscore4(d4)
|
||||
d4 = _upsample_like(d4,d1)
|
||||
|
||||
d5 = self.side5(hx5d)
|
||||
d5 = self.upscore5(d5)
|
||||
d5 = _upsample_like(d5,d1)
|
||||
|
||||
d6 = self.side6(hx6)
|
||||
d6 = self.upscore6(d6)
|
||||
d6 = _upsample_like(d6,d1)
|
||||
|
||||
d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
|
||||
|
||||
@ -458,20 +450,14 @@ class U2NETP(nn.Module):
|
||||
self.stage2d = RSU6(128,16,64)
|
||||
self.stage1d = RSU7(128,16,64)
|
||||
|
||||
self.side1 = nn.Conv2d(64,1,3,padding=1)
|
||||
self.side2 = nn.Conv2d(64,1,3,padding=1)
|
||||
self.side3 = nn.Conv2d(64,1,3,padding=1)
|
||||
self.side4 = nn.Conv2d(64,1,3,padding=1)
|
||||
self.side5 = nn.Conv2d(64,1,3,padding=1)
|
||||
self.side6 = nn.Conv2d(64,1,3,padding=1)
|
||||
self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||
self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||
self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||
self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||
self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||
self.side6 = nn.Conv2d(64,out_ch,3,padding=1)
|
||||
|
||||
self.upscore6 = nn.Upsample(scale_factor=32,mode='bilinear')
|
||||
self.upscore5 = nn.Upsample(scale_factor=16,mode='bilinear')
|
||||
self.upscore4 = nn.Upsample(scale_factor=8,mode='bilinear')
|
||||
self.upscore3 = nn.Upsample(scale_factor=4,mode='bilinear')
|
||||
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')
|
||||
|
||||
self.outconv = nn.Conv2d(6,1,1)
|
||||
self.outconv = nn.Conv2d(6,out_ch,1)
|
||||
|
||||
def forward(self,x):
|
||||
|
||||
@ -499,20 +485,20 @@ class U2NETP(nn.Module):
|
||||
|
||||
#stage 6
|
||||
hx6 = self.stage6(hx)
|
||||
hx6up = self.upscore2(hx6)
|
||||
hx6up = _upsample_like(hx6,hx5)
|
||||
|
||||
#decoder
|
||||
hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
|
||||
hx5dup = self.upscore2(hx5d)
|
||||
hx5dup = _upsample_like(hx5d,hx4)
|
||||
|
||||
hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
|
||||
hx4dup = self.upscore2(hx4d)
|
||||
hx4dup = _upsample_like(hx4d,hx3)
|
||||
|
||||
hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
|
||||
hx3dup = self.upscore2(hx3d)
|
||||
hx3dup = _upsample_like(hx3d,hx2)
|
||||
|
||||
hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
|
||||
hx2dup = self.upscore2(hx2d)
|
||||
hx2dup = _upsample_like(hx2d,hx1)
|
||||
|
||||
hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
|
||||
|
||||
@ -521,21 +507,20 @@ class U2NETP(nn.Module):
|
||||
d1 = self.side1(hx1d)
|
||||
|
||||
d2 = self.side2(hx2d)
|
||||
d2 = self.upscore2(d2)
|
||||
d2 = _upsample_like(d2,d1)
|
||||
|
||||
d3 = self.side3(hx3d)
|
||||
d3 = self.upscore3(d3)
|
||||
d3 = _upsample_like(d3,d1)
|
||||
|
||||
d4 = self.side4(hx4d)
|
||||
d4 = self.upscore4(d4)
|
||||
d4 = _upsample_like(d4,d1)
|
||||
|
||||
d5 = self.side5(hx5d)
|
||||
d5 = self.upscore5(d5)
|
||||
d5 = _upsample_like(d5,d1)
|
||||
|
||||
d6 = self.side6(hx6)
|
||||
d6 = self.upscore6(d6)
|
||||
d6 = _upsample_like(d6,d1)
|
||||
|
||||
d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
|
||||
# d00 = d0 + self.refconv(d0)
|
||||
|
||||
return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
|
||||
|
Loading…
x
Reference in New Issue
Block a user