mirror of
https://git.mirrors.martin98.com/https://github.com/xuebinqin/U-2-Net
synced 2025-08-16 03:06:02 +08:00
Merge pull request #8 from szerintedmi/resolve_0-dim_tensor_warnings
Resolve invalid index of a 0-dim tensor warnings
This commit is contained in:
commit
3ef086adfb
@ -84,9 +84,12 @@ def main():
|
||||
elif(model_name=='u2netp'):
|
||||
print("...load U2NEP---4.7 MB")
|
||||
net = U2NETP(3,1)
|
||||
net.load_state_dict(torch.load(model_dir))
|
||||
|
||||
if torch.cuda.is_available():
|
||||
net.load_state_dict(torch.load(model_dir))
|
||||
net.cuda()
|
||||
else:
|
||||
net.load_state_dict(torch.load(model_dir, map_location='cpu'))
|
||||
net.eval()
|
||||
|
||||
# --------- 4. inference for each image ---------
|
||||
|
@ -38,7 +38,7 @@ def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
|
||||
loss6 = bce_loss(d6,labels_v)
|
||||
|
||||
loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
|
||||
print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data[0],loss1.data[0],loss2.data[0],loss3.data[0],loss4.data[0],loss5.data[0],loss6.data[0]))
|
||||
print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data.item(),loss1.data.item(),loss2.data.item(),loss3.data.item(),loss4.data.item(),loss5.data.item(),loss6.data.item()))
|
||||
|
||||
return loss0, loss
|
||||
|
||||
@ -144,8 +144,8 @@ for epoch in range(0, epoch_num):
|
||||
optimizer.step()
|
||||
|
||||
# # print statistics
|
||||
running_loss += loss.data[0]
|
||||
running_tar_loss += loss2.data[0]
|
||||
running_loss += loss.data.item()
|
||||
running_tar_loss += loss2.data.item()
|
||||
|
||||
# del temporary outputs and loss
|
||||
del d0, d1, d2, d3, d4, d5, d6, loss2, loss
|
||||
|
Loading…
x
Reference in New Issue
Block a user