diff --git a/u2net_human_seg_test.py b/u2net_human_seg_test.py index 53c5e35..ade5fc0 100644 --- a/u2net_human_seg_test.py +++ b/u2net_human_seg_test.py @@ -56,7 +56,7 @@ def main(): model_name='u2net' - image_dir = os.path.join(os.getcwd(), 'test_data', 'test_human_images')#'test_human_images')#'test_portrait_images', 'your_portrait_im') + image_dir = os.path.join(os.getcwd(), 'test_data', 'test_human_images') prediction_dir = os.path.join(os.getcwd(), 'test_data', 'test_human_images' + '_results' + os.sep) model_dir = os.path.join(os.getcwd(), 'saved_models', model_name+'_human_seg', model_name + '_human_seg.pth')