diff --git a/u2net_test.py b/u2net_test.py index 8b501c6..02f0300 100644 --- a/u2net_test.py +++ b/u2net_test.py @@ -37,7 +37,7 @@ def save_output(image_name,pred,d_dir): predict_np = predict.cpu().data.numpy() im = Image.fromarray(predict_np*255).convert('RGB') - img_name = image_name.split("/")[-1] + img_name = image_name.split(os.sep)[-1] image = io.imread(image_name) imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR) @@ -57,11 +57,12 @@ def main(): model_name='u2net'#u2netp - image_dir = './test_data/test_images/' - prediction_dir = './test_data/' + model_name + '_results/' - model_dir = './saved_models/'+ model_name + '/' + model_name + '.pth' - img_name_list = glob.glob(image_dir + '*') + image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images') + prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results' + os.sep) + model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth') + + img_name_list = glob.glob(image_dir + os.sep + '*') print(img_name_list) # --------- 2. dataloader --------- @@ -91,7 +92,7 @@ def main(): # --------- 4. inference for each image --------- for i_test, data_test in enumerate(test_salobj_dataloader): - print("inferencing:",img_name_list[i_test].split("/")[-1]) + print("inferencing:",img_name_list[i_test].split(os.sep)[-1]) inputs_test = data_test['image'] inputs_test = inputs_test.type(torch.FloatTensor) @@ -108,6 +109,8 @@ def main(): pred = normPRED(pred) # save results to test_results folder + if not os.path.exists(prediction_dir): + os.makedirs(prediction_dir, exist_ok=True) save_output(img_name_list[i_test],pred,prediction_dir) del d1,d2,d3,d4,d5,d6,d7