diff --git a/u2net_portrait_test.py b/u2net_portrait_test.py index bc4c665..7e103dd 100644 --- a/u2net_portrait_test.py +++ b/u2net_portrait_test.py @@ -84,10 +84,7 @@ def main(): print("...load U2NET---173.6 MB") net = U2NET(3,1) - if torch.cuda.is_available(): - net.load_state_dict(torch.load(model_dir)) - else: - net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu'))) + net.load_state_dict(torch.load(model_dir)) if torch.cuda.is_available(): net.cuda() net.eval()