mirror of
https://git.mirrors.martin98.com/https://github.com/xuebinqin/U-2-Net
synced 2025-08-05 20:16:10 +08:00
added cpu support
This commit is contained in:
parent
b5464242d8
commit
7d141803a7
@ -151,7 +151,10 @@ def main():
|
||||
|
||||
# load u2net_portrait model
|
||||
net = U2NET(3,1)
|
||||
net.load_state_dict(torch.load(model_dir))
|
||||
if torch.cuda.is_available():
|
||||
net.load_state_dict(torch.load(model_dir))
|
||||
else:
|
||||
net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
|
||||
if torch.cuda.is_available():
|
||||
net.cuda()
|
||||
net.eval()
|
||||
|
@ -84,7 +84,10 @@ def main():
|
||||
print("...load U2NET---173.6 MB")
|
||||
net = U2NET(3,1)
|
||||
|
||||
net.load_state_dict(torch.load(model_dir))
|
||||
if torch.cuda.is_available():
|
||||
net.load_state_dict(torch.load(model_dir))
|
||||
else:
|
||||
net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
|
||||
if torch.cuda.is_available():
|
||||
net.cuda()
|
||||
net.eval()
|
||||
|
@ -84,7 +84,10 @@ def main():
|
||||
elif(model_name=='u2netp'):
|
||||
print("...load U2NEP---4.7 MB")
|
||||
net = U2NETP(3,1)
|
||||
net.load_state_dict(torch.load(model_dir))
|
||||
if torch.cuda.is_available():
|
||||
net.load_state_dict(torch.load(model_dir))
|
||||
else:
|
||||
net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
|
||||
if torch.cuda.is_available():
|
||||
net.cuda()
|
||||
net.eval()
|
||||
|
Loading…
x
Reference in New Issue
Block a user