mirror of
https://git.mirrors.martin98.com/https://github.com/xuebinqin/U-2-Net
synced 2025-08-15 15:45:59 +08:00
load model to cpu when no cuda
This commit is contained in:
parent
89d9a2e60d
commit
1b247c5419
@ -83,9 +83,12 @@ def main():
|
|||||||
elif(model_name=='u2netp'):
|
elif(model_name=='u2netp'):
|
||||||
print("...load U2NEP---4.7 MB")
|
print("...load U2NEP---4.7 MB")
|
||||||
net = U2NETP(3,1)
|
net = U2NETP(3,1)
|
||||||
net.load_state_dict(torch.load(model_dir))
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
net.load_state_dict(torch.load(model_dir))
|
||||||
net.cuda()
|
net.cuda()
|
||||||
|
else:
|
||||||
|
net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
|
||||||
net.eval()
|
net.eval()
|
||||||
|
|
||||||
# --------- 4. inference for each image ---------
|
# --------- 4. inference for each image ---------
|
||||||
|
Loading…
x
Reference in New Issue
Block a user