diff --git a/u2net_test.py b/u2net_test.py index 02f0300..3460b56 100644 --- a/u2net_test.py +++ b/u2net_test.py @@ -84,9 +84,12 @@ def main(): elif(model_name=='u2netp'): print("...load U2NEP---4.7 MB") net = U2NETP(3,1) - net.load_state_dict(torch.load(model_dir)) + if torch.cuda.is_available(): + net.load_state_dict(torch.load(model_dir)) net.cuda() + else: + net.load_state_dict(torch.load(model_dir, map_location='cpu')) net.eval() # --------- 4. inference for each image --------- diff --git a/u2net_train.py b/u2net_train.py index 24466d4..43d9a12 100644 --- a/u2net_train.py +++ b/u2net_train.py @@ -38,7 +38,7 @@ def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v): loss6 = bce_loss(d6,labels_v) loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 - print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data[0],loss1.data[0],loss2.data[0],loss3.data[0],loss4.data[0],loss5.data[0],loss6.data[0])) + print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data.item(),loss1.data.item(),loss2.data.item(),loss3.data.item(),loss4.data.item(),loss5.data.item(),loss6.data.item())) return loss0, loss @@ -144,8 +144,8 @@ for epoch in range(0, epoch_num): optimizer.step() # # print statistics - running_loss += loss.data[0] - running_tar_loss += loss2.data[0] + running_loss += loss.data.item() + running_tar_loss += loss2.data.item() # del temporary outputs and loss del d0, d1, d2, d3, d4, d5, d6, loss2, loss