From 21bc4de1f3b6f08c1c52c7298dba33aa3fad2541 Mon Sep 17 00:00:00 2001 From: adakoda Date: Thu, 24 Dec 2020 11:03:29 +0900 Subject: [PATCH] Update u2net_portrait_demo.py --- u2net_portrait_demo.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/u2net_portrait_demo.py b/u2net_portrait_demo.py index 36ac841..d3e2e0c 100644 --- a/u2net_portrait_demo.py +++ b/u2net_portrait_demo.py @@ -151,10 +151,7 @@ def main(): # load u2net_portrait model net = U2NET(3,1) - if torch.cuda.is_available(): - net.load_state_dict(torch.load(model_dir)) - else: - net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu'))) + net.load_state_dict(torch.load(model_dir)) if torch.cuda.is_available(): net.cuda() net.eval()