protrait
21
README.md
@ -14,6 +14,26 @@ __Contact__: xuebin[at]ualberta[dot]ca
|
|||||||
|
|
||||||
## Updates !!!
|
## Updates !!!
|
||||||
|
|
||||||
|
**(2020-Nov-21)** We found a interesting application of U^2-Net for [**human protrait drawing**](https://www.pythonf.cn/read/141098). Therefore, we trained another model for this task based on the [**APDrawingGAN dataset**](https://github.com/yiranran/APDrawingGAN).
|
||||||
|
|
||||||
|
[!Sample Results: Kids](figures/portrait_kids.png)
|
||||||
|
|
||||||
|
[!Sample Results: Ladies](figures/portrait_ladies.png)
|
||||||
|
|
||||||
|
[!Sample Results: Men](figures/portrait_men.png)
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
1. Clone this repo to local
|
||||||
|
```
|
||||||
|
git clone https://github.com/NathanUA/U-2-Net.git
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Download the [**U2net_portrait.pth**](https://drive.google.com/file/d/1IG3HdpcRiDoWNookbncQjeaPN28t90yW/view?usp=sharing) model and put it into the directory: ```./saved_models/u2net_portrait/```.
|
||||||
|
|
||||||
|
3. Run on the testing set. Download the train and test set from [**APDrawingGAN**](https://github.com/yiranran/APDrawingGAN). These images and their ground truth are stitched side-by-side (512x1024). You need to split each of these images into two 512x512 images and put them into ```./test_data/test_portrait_images/portrait_im/``. You can also download the split testing set [**here**](https://drive.google.com/file/d/1NkTsDDN8VO-JVik6VxXyV-3l2eo29KCk/view?usp=sharing). Running the inference with command ```python u2net_portrait_test.py``` will ouptut the results into ```./test_data/test_portrait_images/portrait_results```.
|
||||||
|
|
||||||
|
4. Run on your own dataset. Prepare your images and put them into ```./test_data/test_portrait_images/your_portrait_im/```. Run the prediction by command ```python u2net_portrait_demo.py``` The difference of the code for runing the test set and your own dataset is that the testing set are normalized and cropped to 512x512 for including only head of human, while your own dataset may varies with different resolution and contents. To achieve stable results, we added a simple [**face detection**](https://opencv-python-tutroals.readthedocs.io/en/latest/py_tutorials/py_objdetect/py_face_detection/py_face_detection.html) step before the portrait generation in ```u2net_portrait_demo.py```. Therefore, the code will detect the biggest face from the given image and then crop, pad and resize the ROI to 512x512 for feeding to the network.
|
||||||
|
|
||||||
**(2020-Sep-13)** Our U^2-Net based model is the **6th** in [**MICCAI 2020 Thyroid Nodule Segmentation Challenge**](https://tn-scui2020.grand-challenge.org/Resultannouncement/).
|
**(2020-Sep-13)** Our U^2-Net based model is the **6th** in [**MICCAI 2020 Thyroid Nodule Segmentation Challenge**](https://tn-scui2020.grand-challenge.org/Resultannouncement/).
|
||||||
|
|
||||||
**(2020-May-18)** The official paper of our **U^2-Net (U square net)** ([**PDF in elsevier**(free until July 5 2020)](https://www.sciencedirect.com/science/article/pii/S0031320320302077?dgcid=author), [**PDF in arxiv**](http://arxiv.org/abs/2005.09007)) is now available. If you are not able to access that, please feel free to drop me an email.
|
**(2020-May-18)** The official paper of our **U^2-Net (U square net)** ([**PDF in elsevier**(free until July 5 2020)](https://www.sciencedirect.com/science/article/pii/S0031320320302077?dgcid=author), [**PDF in arxiv**](http://arxiv.org/abs/2005.09007)) is now available. If you are not able to access that, please feel free to drop me an email.
|
||||||
@ -34,6 +54,7 @@ __Contact__: xuebin[at]ualberta[dot]ca
|
|||||||
Python 3.6
|
Python 3.6
|
||||||
numpy 1.15.2
|
numpy 1.15.2
|
||||||
scikit-image 0.14.0
|
scikit-image 0.14.0
|
||||||
|
python-opencv
|
||||||
PIL 5.2.0
|
PIL 5.2.0
|
||||||
PyTorch 0.4.0
|
PyTorch 0.4.0
|
||||||
torchvision 0.2.1
|
torchvision 0.2.1
|
||||||
|
BIN
__pycache__/data_loader.cpython-37.pyc
Normal file
BIN
figures/portrait_kids.png
Normal file
After Width: | Height: | Size: 802 KiB |
BIN
figures/portrait_ladies.png
Normal file
After Width: | Height: | Size: 803 KiB |
BIN
figures/portrait_men.png
Normal file
After Width: | Height: | Size: 666 KiB |
BIN
model/__pycache__/__init__.cpython-37.pyc
Normal file
BIN
model/__pycache__/u2net.cpython-37.pyc
Normal file
33314
saved_models/face_detection_cv2/haarcascade_frontalface_default.xml
Normal file
BIN
test_data/test_portrait_images/portrait_im/img_1585.png
Normal file
After Width: | Height: | Size: 325 KiB |
BIN
test_data/test_portrait_images/portrait_im/img_1588.png
Normal file
After Width: | Height: | Size: 327 KiB |
BIN
test_data/test_portrait_images/portrait_im/img_1594.png
Normal file
After Width: | Height: | Size: 365 KiB |
BIN
test_data/test_portrait_images/portrait_im/img_1616.png
Normal file
After Width: | Height: | Size: 340 KiB |
BIN
test_data/test_portrait_images/portrait_im/img_1695.png
Normal file
After Width: | Height: | Size: 420 KiB |
BIN
test_data/test_portrait_images/portrait_im/img_1696.png
Normal file
After Width: | Height: | Size: 459 KiB |
BIN
test_data/test_portrait_images/portrait_im/img_1771.png
Normal file
After Width: | Height: | Size: 271 KiB |
BIN
test_data/test_portrait_images/portrait_im/img_1859.png
Normal file
After Width: | Height: | Size: 337 KiB |
BIN
test_data/test_portrait_images/portrait_results/img_1585.png
Normal file
After Width: | Height: | Size: 200 KiB |
BIN
test_data/test_portrait_images/portrait_results/img_1588.png
Normal file
After Width: | Height: | Size: 168 KiB |
BIN
test_data/test_portrait_images/portrait_results/img_1594.png
Normal file
After Width: | Height: | Size: 154 KiB |
BIN
test_data/test_portrait_images/portrait_results/img_1616.png
Normal file
After Width: | Height: | Size: 183 KiB |
BIN
test_data/test_portrait_images/portrait_results/img_1695.png
Normal file
After Width: | Height: | Size: 248 KiB |
BIN
test_data/test_portrait_images/portrait_results/img_1696.png
Normal file
After Width: | Height: | Size: 170 KiB |
BIN
test_data/test_portrait_images/portrait_results/img_1771.png
Normal file
After Width: | Height: | Size: 113 KiB |
BIN
test_data/test_portrait_images/portrait_results/img_1859.png
Normal file
After Width: | Height: | Size: 187 KiB |
BIN
test_data/test_portrait_images/your_portrait_im/GalGadot.jpg
Normal file
After Width: | Height: | Size: 67 KiB |
BIN
test_data/test_portrait_images/your_portrait_im/guliNazha3.jpg
Normal file
After Width: | Height: | Size: 238 KiB |
BIN
test_data/test_portrait_images/your_portrait_im/kid1.jpg
Normal file
After Width: | Height: | Size: 88 KiB |
BIN
test_data/test_portrait_images/your_portrait_im/kid2.jpg
Normal file
After Width: | Height: | Size: 46 KiB |
BIN
test_data/test_portrait_images/your_portrait_im/kid3.jpg
Normal file
After Width: | Height: | Size: 62 KiB |
BIN
test_data/test_portrait_images/your_portrait_im/man.jpg
Normal file
After Width: | Height: | Size: 27 KiB |
BIN
test_data/test_portrait_images/your_portrait_im/man2.jpg
Normal file
After Width: | Height: | Size: 64 KiB |
BIN
test_data/test_portrait_images/your_portrait_im/man4.jpg
Normal file
After Width: | Height: | Size: 43 KiB |
BIN
test_data/test_portrait_images/your_portrait_im/man5.jpg
Normal file
After Width: | Height: | Size: 119 KiB |
BIN
test_data/test_portrait_images/your_portrait_im/smile.jpg
Normal file
After Width: | Height: | Size: 29 KiB |
After Width: | Height: | Size: 92 KiB |
After Width: | Height: | Size: 79 KiB |
BIN
test_data/test_portrait_images/your_portrait_results/kid1.png
Normal file
After Width: | Height: | Size: 98 KiB |
BIN
test_data/test_portrait_images/your_portrait_results/kid2.png
Normal file
After Width: | Height: | Size: 82 KiB |
BIN
test_data/test_portrait_images/your_portrait_results/kid3.png
Normal file
After Width: | Height: | Size: 81 KiB |
BIN
test_data/test_portrait_images/your_portrait_results/man.png
Normal file
After Width: | Height: | Size: 63 KiB |
BIN
test_data/test_portrait_images/your_portrait_results/man2.png
Normal file
After Width: | Height: | Size: 81 KiB |
BIN
test_data/test_portrait_images/your_portrait_results/man4.png
Normal file
After Width: | Height: | Size: 85 KiB |
BIN
test_data/test_portrait_images/your_portrait_results/man5.png
Normal file
After Width: | Height: | Size: 99 KiB |
BIN
test_data/test_portrait_images/your_portrait_results/smile.png
Normal file
After Width: | Height: | Size: 62 KiB |
167
u2net_portrait_demo.py
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
from model import U2NET
|
||||||
|
from torch.autograd import Variable
|
||||||
|
import numpy as np
|
||||||
|
from glob import glob
|
||||||
|
import os
|
||||||
|
|
||||||
|
def detect_single_face(face_cascade,img):
|
||||||
|
# Convert into grayscale
|
||||||
|
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||||
|
|
||||||
|
# Detect faces
|
||||||
|
faces = face_cascade.detectMultiScale(gray, 1.1, 4)
|
||||||
|
if(len(faces)==0):
|
||||||
|
print("Warming: no face detection, the portrait u2net will run on the whole image!")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# filter to keep the largest face
|
||||||
|
wh = 0
|
||||||
|
idx = 0
|
||||||
|
for i in range(0,len(faces)):
|
||||||
|
(x,y,w,h) = faces[i]
|
||||||
|
if(wh<w*h):
|
||||||
|
idx = i
|
||||||
|
wh = w*h
|
||||||
|
|
||||||
|
return faces[idx]
|
||||||
|
|
||||||
|
# crop, pad and resize face region to 512x512 resolution
|
||||||
|
def crop_face(img, face):
|
||||||
|
|
||||||
|
# no face detected, return the whole image and the inference will run on the whole image
|
||||||
|
if(face is None):
|
||||||
|
return img
|
||||||
|
(x, y, w, h) = face
|
||||||
|
|
||||||
|
height,width = img.shape[0:2]
|
||||||
|
|
||||||
|
# crop the face with a bigger bbox
|
||||||
|
hmw = h - w
|
||||||
|
hpad = int(h/2)+1
|
||||||
|
wpad = int(w/2)+1
|
||||||
|
|
||||||
|
l,r,t,b = 0,0,0,0
|
||||||
|
left = x-wpad
|
||||||
|
if(left<0):
|
||||||
|
left = 0
|
||||||
|
l = wpad-x
|
||||||
|
right = x+w+wpad
|
||||||
|
if(right>width):
|
||||||
|
right = width
|
||||||
|
r = right-width
|
||||||
|
top = y - hpad
|
||||||
|
if(top<0):
|
||||||
|
top = 0
|
||||||
|
t = hpad-y
|
||||||
|
bottom = y+h+int(hpad*0.5)
|
||||||
|
if(bottom>height):
|
||||||
|
bottom = height
|
||||||
|
b = bottom-height
|
||||||
|
|
||||||
|
im_face = img[top:bottom,left:right]
|
||||||
|
if(len(im_face.shape)==2):
|
||||||
|
im_face = np.repeat(im_face[:,:,np.newaxis],(1,1,3))
|
||||||
|
|
||||||
|
im_face = np.pad(im_face,((t,b),(l,r),(0,0)),mode='constant',constant_values=((255,255),(255,255),(255,255)))
|
||||||
|
|
||||||
|
# pad to achieve image with square shape for avoding face deformation after resizing
|
||||||
|
hf,wf = im_face.shape[0:2]
|
||||||
|
if(hf-2>wf):
|
||||||
|
wfp = int((hf-wf)/2)
|
||||||
|
im_face = np.pad(im_face,((0,0),(wfp,wfp),(0,0)),mode='constant',constant_values=((255,255),(255,255),(255,255)))
|
||||||
|
elif(wf-2>hf):
|
||||||
|
hfp = int((wf-hf)/2)
|
||||||
|
im_face = np.pad(im_face,((hfp,hfp),(0,0),(0,0)),mode='constant',constant_values=((255,255),(255,255),(255,255)))
|
||||||
|
|
||||||
|
# resize to have 512x512 resolution
|
||||||
|
im_face = cv2.resize(im_face, (512,512), interpolation = cv2.INTER_AREA)
|
||||||
|
|
||||||
|
return im_face
|
||||||
|
|
||||||
|
def normPRED(d):
|
||||||
|
ma = torch.max(d)
|
||||||
|
mi = torch.min(d)
|
||||||
|
|
||||||
|
dn = (d-mi)/(ma-mi)
|
||||||
|
|
||||||
|
return dn
|
||||||
|
|
||||||
|
def inference(net,input):
|
||||||
|
|
||||||
|
# normalize the input
|
||||||
|
tmpImg = np.zeros((input.shape[0],input.shape[1],3))
|
||||||
|
input = input/np.max(input)
|
||||||
|
|
||||||
|
tmpImg[:,:,0] = (input[:,:,2]-0.406)/0.225
|
||||||
|
tmpImg[:,:,1] = (input[:,:,1]-0.456)/0.224
|
||||||
|
tmpImg[:,:,2] = (input[:,:,0]-0.485)/0.229
|
||||||
|
|
||||||
|
# convert BGR to RGB
|
||||||
|
tmpImg = tmpImg.transpose((2, 0, 1))
|
||||||
|
tmpImg = tmpImg[np.newaxis,:,:,:]
|
||||||
|
tmpImg = torch.from_numpy(tmpImg)
|
||||||
|
|
||||||
|
# convert numpy array to torch tensor
|
||||||
|
tmpImg = tmpImg.type(torch.FloatTensor)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
tmpImg = Variable(tmpImg.cuda())
|
||||||
|
else:
|
||||||
|
tmpImg = Variable(tmpImg)
|
||||||
|
|
||||||
|
# inference
|
||||||
|
d1,d2,d3,d4,d5,d6,d7= net(tmpImg)
|
||||||
|
|
||||||
|
# normalization
|
||||||
|
pred = 1.0 - d1[:,0,:,:]
|
||||||
|
pred = normPRED(pred)
|
||||||
|
|
||||||
|
# convert torch tensor to numpy array
|
||||||
|
pred = pred.squeeze()
|
||||||
|
pred = pred.cpu().data.numpy()
|
||||||
|
|
||||||
|
del d1,d2,d3,d4,d5,d6,d7
|
||||||
|
|
||||||
|
return pred
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
# get the image path list for inference
|
||||||
|
im_list = glob('./test_data/test_portrait_images/your_portrait_im/*')
|
||||||
|
print("Number of images: ",len(im_list))
|
||||||
|
# indicate the output directory
|
||||||
|
out_dir = './test_data/test_portrait_images/your_portrait_results'
|
||||||
|
if(not os.path.exists(out_dir)):
|
||||||
|
os.mkdir(out_dir)
|
||||||
|
|
||||||
|
# Load the cascade face detection model
|
||||||
|
face_cascade = cv2.CascadeClassifier('./saved_models/face_detection_cv2/haarcascade_frontalface_default.xml')
|
||||||
|
# u2net_portrait path
|
||||||
|
model_dir = './saved_models/u2net_portrait/u2net_portrait.pth'
|
||||||
|
|
||||||
|
# load u2net_portrait model
|
||||||
|
net = U2NET(3,1)
|
||||||
|
net.load_state_dict(torch.load(model_dir))
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
net.cuda()
|
||||||
|
net.eval()
|
||||||
|
|
||||||
|
# do the inference one-by-one
|
||||||
|
for i in range(0,len(im_list)):
|
||||||
|
print("--------------------------")
|
||||||
|
print("inferencing ", i, "/", len(im_list), im_list[i])
|
||||||
|
|
||||||
|
# load each image
|
||||||
|
img = cv2.imread(im_list[i])
|
||||||
|
height,width = img.shape[0:2]
|
||||||
|
face = detect_single_face(face_cascade,img)
|
||||||
|
im_face = crop_face(img, face)
|
||||||
|
im_portrait = inference(net,im_face)
|
||||||
|
|
||||||
|
# save the output
|
||||||
|
cv2.imwrite(out_dir+"/"+im_list[i].split('/')[-1][0:-4]+'.png',(im_portrait*255).astype(np.uint8))
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
117
u2net_portrait_test.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
import os
|
||||||
|
from skimage import io, transform
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from torch.autograd import Variable
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torchvision import transforms#, utils
|
||||||
|
# import torch.optim as optim
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import glob
|
||||||
|
|
||||||
|
from data_loader import RescaleT
|
||||||
|
from data_loader import ToTensor
|
||||||
|
from data_loader import ToTensorLab
|
||||||
|
from data_loader import SalObjDataset
|
||||||
|
|
||||||
|
from model import U2NET # full size version 173.6 MB
|
||||||
|
from model import U2NETP # small version u2net 4.7 MB
|
||||||
|
|
||||||
|
# normalize the predicted SOD probability map
|
||||||
|
def normPRED(d):
|
||||||
|
ma = torch.max(d)
|
||||||
|
mi = torch.min(d)
|
||||||
|
|
||||||
|
dn = (d-mi)/(ma-mi)
|
||||||
|
|
||||||
|
return dn
|
||||||
|
|
||||||
|
def save_output(image_name,pred,d_dir):
|
||||||
|
|
||||||
|
predict = pred
|
||||||
|
predict = predict.squeeze()
|
||||||
|
predict_np = predict.cpu().data.numpy()
|
||||||
|
|
||||||
|
im = Image.fromarray(predict_np*255).convert('RGB')
|
||||||
|
img_name = image_name.split(os.sep)[-1]
|
||||||
|
image = io.imread(image_name)
|
||||||
|
imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
|
||||||
|
|
||||||
|
pb_np = np.array(imo)
|
||||||
|
|
||||||
|
aaa = img_name.split(".")
|
||||||
|
bbb = aaa[0:-1]
|
||||||
|
imidx = bbb[0]
|
||||||
|
for i in range(1,len(bbb)):
|
||||||
|
imidx = imidx + "." + bbb[i]
|
||||||
|
|
||||||
|
imo.save(d_dir+'/'+imidx+'.png')
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
# --------- 1. get image path and name ---------
|
||||||
|
model_name='u2net_portrait'#u2netp
|
||||||
|
|
||||||
|
|
||||||
|
image_dir = './test_data/test_portrait_images/portrait_im'
|
||||||
|
prediction_dir = './test_data/test_portrait_images/portrait_results'
|
||||||
|
if(not os.path.exists(prediction_dir)):
|
||||||
|
os.mkdir(prediction_dir)
|
||||||
|
|
||||||
|
model_dir = './saved_models/u2net_portrait/u2net_portrait.pth'
|
||||||
|
|
||||||
|
img_name_list = glob.glob(image_dir+'/*')
|
||||||
|
print("Number of images: ", len(img_name_list))
|
||||||
|
|
||||||
|
# --------- 2. dataloader ---------
|
||||||
|
#1. dataloader
|
||||||
|
test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
|
||||||
|
lbl_name_list = [],
|
||||||
|
transform=transforms.Compose([RescaleT(512),
|
||||||
|
ToTensorLab(flag=0)])
|
||||||
|
)
|
||||||
|
test_salobj_dataloader = DataLoader(test_salobj_dataset,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=1)
|
||||||
|
|
||||||
|
# --------- 3. model define ---------
|
||||||
|
|
||||||
|
print("...load U2NET---173.6 MB")
|
||||||
|
net = U2NET(3,1)
|
||||||
|
|
||||||
|
net.load_state_dict(torch.load(model_dir))
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
net.cuda()
|
||||||
|
net.eval()
|
||||||
|
|
||||||
|
# --------- 4. inference for each image ---------
|
||||||
|
for i_test, data_test in enumerate(test_salobj_dataloader):
|
||||||
|
|
||||||
|
print("inferencing:",img_name_list[i_test].split(os.sep)[-1])
|
||||||
|
|
||||||
|
inputs_test = data_test['image']
|
||||||
|
inputs_test = inputs_test.type(torch.FloatTensor)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
inputs_test = Variable(inputs_test.cuda())
|
||||||
|
else:
|
||||||
|
inputs_test = Variable(inputs_test)
|
||||||
|
|
||||||
|
d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
|
||||||
|
|
||||||
|
# normalization
|
||||||
|
pred = 1.0 - d1[:,0,:,:]
|
||||||
|
pred = normPRED(pred)
|
||||||
|
|
||||||
|
# save results to test_results folder
|
||||||
|
save_output(img_name_list[i_test],pred,prediction_dir)
|
||||||
|
|
||||||
|
del d1,d2,d3,d4,d5,d6,d7
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|