added cpu support

This commit is contained in:
adakoda 2020-12-24 10:33:17 +09:00
parent b5464242d8
commit 7d141803a7
3 changed files with 12 additions and 3 deletions

View File

@ -151,7 +151,10 @@ def main():
# load u2net_portrait model
net = U2NET(3,1)
net.load_state_dict(torch.load(model_dir))
if torch.cuda.is_available():
net.load_state_dict(torch.load(model_dir))
else:
net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
if torch.cuda.is_available():
net.cuda()
net.eval()

View File

@ -84,7 +84,10 @@ def main():
print("...load U2NET---173.6 MB")
net = U2NET(3,1)
net.load_state_dict(torch.load(model_dir))
if torch.cuda.is_available():
net.load_state_dict(torch.load(model_dir))
else:
net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
if torch.cuda.is_available():
net.cuda()
net.eval()

View File

@ -84,7 +84,10 @@ def main():
elif(model_name=='u2netp'):
print("...load U2NEP---4.7 MB")
net = U2NETP(3,1)
net.load_state_dict(torch.load(model_dir))
if torch.cuda.is_available():
net.load_state_dict(torch.load(model_dir))
else:
net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
if torch.cuda.is_available():
net.cuda()
net.eval()