diff --git a/u2net_test.py b/u2net_test.py index 8b501c6..504ca55 100644 --- a/u2net_test.py +++ b/u2net_test.py @@ -83,9 +83,12 @@ def main(): elif(model_name=='u2netp'): print("...load U2NEP---4.7 MB") net = U2NETP(3,1) - net.load_state_dict(torch.load(model_dir)) + if torch.cuda.is_available(): + net.load_state_dict(torch.load(model_dir)) net.cuda() + else: + net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu'))) net.eval() # --------- 4. inference for each image ---------