From 1b247c5419c27cac6fd2595e4a4287e0888b4503 Mon Sep 17 00:00:00 2001
From: szerintedmi
Date: Mon, 11 May 2020 12:01:26 +0100
Subject: [PATCH 1/3] load model to cpu when no cuda

---
 u2net_test.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/u2net_test.py b/u2net_test.py
index 8b501c6..504ca55 100644
--- a/u2net_test.py
+++ b/u2net_test.py
@@ -83,9 +83,12 @@ def main():
     elif(model_name=='u2netp'):
         print("...load U2NEP---4.7 MB")
         net = U2NETP(3,1)
-    net.load_state_dict(torch.load(model_dir))
+
     if torch.cuda.is_available():
+        net.load_state_dict(torch.load(model_dir))
         net.cuda()
+    else:
+        net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
     net.eval()
 
     # --------- 4. inference for each image ---------

From 4031385d809406d54779d6d934d99fd2c5d988b8 Mon Sep 17 00:00:00 2001
From: szerintedmi
Date: Mon, 11 May 2020 18:10:21 +0100
Subject: [PATCH 2/3] map location to work on torch < 0.4.1

---
 u2net_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/u2net_test.py b/u2net_test.py
index 504ca55..8c33886 100644
--- a/u2net_test.py
+++ b/u2net_test.py
@@ -88,7 +88,7 @@ def main():
         net.load_state_dict(torch.load(model_dir))
         net.cuda()
     else:
-        net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
+        net.load_state_dict(torch.load(model_dir, map_location='cpu'))
     net.eval()
 
     # --------- 4. inference for each image ---------

From 7c6262333774147738d6466202ccee47a46c13a2 Mon Sep 17 00:00:00 2001
From: szerintedmi
Date: Mon, 11 May 2020 18:15:14 +0100
Subject: [PATCH 3/3] replace tensor[idx] ref with tensor.item()

---
 u2net_train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/u2net_train.py b/u2net_train.py
index 43d35a9..7ab61a0 100644
--- a/u2net_train.py
+++ b/u2net_train.py
@@ -37,7 +37,7 @@ def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
     loss6 = bce_loss(d6,labels_v)
 
     loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
-    print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data[0],loss1.data[0],loss2.data[0],loss3.data[0],loss4.data[0],loss5.data[0],loss6.data[0]))
+    print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data.item(),loss1.data.item(),loss2.data.item(),loss3.data.item(),loss4.data.item(),loss5.data.item(),loss6.data.item()))
 
     return loss0, loss

@@ -143,8 +143,8 @@ for epoch in range(0, epoch_num):
         optimizer.step()
 
         # # print statistics
-        running_loss += loss.data[0]
-        running_tar_loss += loss2.data[0]
+        running_loss += loss.data.item()
+        running_tar_loss += loss2.data.item()
 
         # del temporary outputs and loss
        del d0, d1, d2, d3, d4, d5, d6, loss2, loss
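
For reference, the loading behaviour the three patches converge on is sketched below. This is only an illustrative sketch, not part of the series: it reuses the names from the diffs above (U2NETP, model_dir, loss, running_loss), and the checkpoint path assigned to model_dir is an assumption based on the u2netp layout used by u2net_test.py.

    # Illustrative sketch of the post-patch behaviour; not part of the patch series.
    # The model_dir value below is an assumed path, following u2net_test.py's layout.
    import torch
    from model import U2NETP

    model_dir = './saved_models/u2netp/u2netp.pth'

    net = U2NETP(3, 1)
    if torch.cuda.is_available():
        # GPU present: load the checkpoint as saved and move the network to CUDA
        net.load_state_dict(torch.load(model_dir))
        net.cuda()
    else:
        # CPU-only machine: remap the CUDA-saved tensors onto the CPU.
        # Patch 2/3 switches to the plain 'cpu' string, which per its commit
        # message avoids the torch.device constructor on older torch releases.
        net.load_state_dict(torch.load(model_dir, map_location='cpu')))
    net.eval()

    # Patch 3/3: indexing a 0-dim loss tensor (loss.data[0]) is rejected by newer
    # PyTorch releases; .item() returns the Python scalar instead, e.g.:
    # running_loss += loss.data.item()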