mirror of
https://git.mirrors.martin98.com/https://github.com/xuebinqin/U-2-Net
synced 2025-08-14 04:45:55 +08:00
Update u2net_test.py
Make code OS-independent.
This commit is contained in:
parent
58a69954b5
commit
9fe1d94ff2
@ -37,7 +37,7 @@ def save_output(image_name,pred,d_dir):
|
||||
predict_np = predict.cpu().data.numpy()
|
||||
|
||||
im = Image.fromarray(predict_np*255).convert('RGB')
|
||||
img_name = image_name.split("/")[-1]
|
||||
img_name = image_name.split(os.sep)[-1]
|
||||
image = io.imread(image_name)
|
||||
imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
|
||||
|
||||
@ -57,11 +57,12 @@ def main():
|
||||
model_name='u2net'#u2netp
|
||||
|
||||
|
||||
image_dir = './test_data/test_images/'
|
||||
prediction_dir = './test_data/' + model_name + '_results/'
|
||||
model_dir = './saved_models/'+ model_name + '/' + model_name + '.pth'
|
||||
|
||||
img_name_list = glob.glob(image_dir + '*')
|
||||
image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
|
||||
prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results' + os.sep)
|
||||
model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')
|
||||
|
||||
img_name_list = glob.glob(image_dir + os.sep + '*')
|
||||
print(img_name_list)
|
||||
|
||||
# --------- 2. dataloader ---------
|
||||
@ -91,7 +92,7 @@ def main():
|
||||
# --------- 4. inference for each image ---------
|
||||
for i_test, data_test in enumerate(test_salobj_dataloader):
|
||||
|
||||
print("inferencing:",img_name_list[i_test].split("/")[-1])
|
||||
print("inferencing:",img_name_list[i_test].split(os.sep)[-1])
|
||||
|
||||
inputs_test = data_test['image']
|
||||
inputs_test = inputs_test.type(torch.FloatTensor)
|
||||
@ -108,6 +109,8 @@ def main():
|
||||
pred = normPRED(pred)
|
||||
|
||||
# save results to test_results folder
|
||||
if not os.path.exists(prediction_dir):
|
||||
os.makedirs(prediction_dir, exist_ok=True)
|
||||
save_output(img_name_list[i_test],pred,prediction_dir)
|
||||
|
||||
del d1,d2,d3,d4,d5,d6,d7
|
||||
|
Loading…
x
Reference in New Issue
Block a user