Merge pull request #8 from szerintedmi/resolve_0-dim_tensor_warnings

Resolve invalid index of a 0-dim tensor warnings
This commit is contained in:
Xuebin Qin 2021-01-23 17:24:56 -07:00 committed by GitHub
commit 3ef086adfb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 4 deletions

View File

@ -84,9 +84,12 @@ def main():
elif(model_name=='u2netp'): elif(model_name=='u2netp'):
print("...load U2NEP---4.7 MB") print("...load U2NEP---4.7 MB")
net = U2NETP(3,1) net = U2NETP(3,1)
net.load_state_dict(torch.load(model_dir))
if torch.cuda.is_available(): if torch.cuda.is_available():
net.load_state_dict(torch.load(model_dir))
net.cuda() net.cuda()
else:
net.load_state_dict(torch.load(model_dir, map_location='cpu'))
net.eval() net.eval()
# --------- 4. inference for each image --------- # --------- 4. inference for each image ---------

View File

@ -38,7 +38,7 @@ def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
loss6 = bce_loss(d6,labels_v) loss6 = bce_loss(d6,labels_v)
loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data[0],loss1.data[0],loss2.data[0],loss3.data[0],loss4.data[0],loss5.data[0],loss6.data[0])) print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data.item(),loss1.data.item(),loss2.data.item(),loss3.data.item(),loss4.data.item(),loss5.data.item(),loss6.data.item()))
return loss0, loss return loss0, loss
@ -144,8 +144,8 @@ for epoch in range(0, epoch_num):
optimizer.step() optimizer.step()
# # print statistics # # print statistics
running_loss += loss.data[0] running_loss += loss.data.item()
running_tar_loss += loss2.data[0] running_tar_loss += loss2.data.item()
# del temporary outputs and loss # del temporary outputs and loss
del d0, d1, d2, d3, d4, d5, d6, loss2, loss del d0, d1, d2, d3, d4, d5, d6, loss2, loss