diff --git a/u2net_portrait_demo.py b/u2net_portrait_demo.py index 36ac841..d3e2e0c 100644 --- a/u2net_portrait_demo.py +++ b/u2net_portrait_demo.py @@ -151,10 +151,7 @@ def main(): # load u2net_portrait model net = U2NET(3,1) - if torch.cuda.is_available(): - net.load_state_dict(torch.load(model_dir)) - else: - net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu'))) + net.load_state_dict(torch.load(model_dir)) if torch.cuda.is_available(): net.cuda() net.eval()