Update u2net_portrait_test.py

This commit is contained in:
adakoda 2020-12-24 11:12:40 +09:00 committed by GitHub
parent c92b64ba0e
commit 6b2d7be4fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -84,10 +84,7 @@ def main():
print("...load U2NET---173.6 MB")
net = U2NET(3,1)
if torch.cuda.is_available():
net.load_state_dict(torch.load(model_dir))
else:
net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
net.load_state_dict(torch.load(model_dir))
if torch.cuda.is_available():
net.cuda()
net.eval()