mirror of
https://git.mirrors.martin98.com/https://github.com/xuebinqin/U-2-Net
synced 2025-08-16 07:35:56 +08:00
Merge pull request #8 from szerintedmi/resolve_0-dim_tensor_warnings
Resolve invalid index of a 0-dim tensor warnings
This commit is contained in:
commit
3ef086adfb
@ -84,9 +84,12 @@ def main():
|
|||||||
elif(model_name=='u2netp'):
|
elif(model_name=='u2netp'):
|
||||||
print("...load U2NEP---4.7 MB")
|
print("...load U2NEP---4.7 MB")
|
||||||
net = U2NETP(3,1)
|
net = U2NETP(3,1)
|
||||||
net.load_state_dict(torch.load(model_dir))
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
net.load_state_dict(torch.load(model_dir))
|
||||||
net.cuda()
|
net.cuda()
|
||||||
|
else:
|
||||||
|
net.load_state_dict(torch.load(model_dir, map_location='cpu'))
|
||||||
net.eval()
|
net.eval()
|
||||||
|
|
||||||
# --------- 4. inference for each image ---------
|
# --------- 4. inference for each image ---------
|
||||||
|
@ -38,7 +38,7 @@ def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
|
|||||||
loss6 = bce_loss(d6,labels_v)
|
loss6 = bce_loss(d6,labels_v)
|
||||||
|
|
||||||
loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
|
loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
|
||||||
print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data[0],loss1.data[0],loss2.data[0],loss3.data[0],loss4.data[0],loss5.data[0],loss6.data[0]))
|
print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data.item(),loss1.data.item(),loss2.data.item(),loss3.data.item(),loss4.data.item(),loss5.data.item(),loss6.data.item()))
|
||||||
|
|
||||||
return loss0, loss
|
return loss0, loss
|
||||||
|
|
||||||
@ -144,8 +144,8 @@ for epoch in range(0, epoch_num):
|
|||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
# # print statistics
|
# # print statistics
|
||||||
running_loss += loss.data[0]
|
running_loss += loss.data.item()
|
||||||
running_tar_loss += loss2.data[0]
|
running_tar_loss += loss2.data.item()
|
||||||
|
|
||||||
# del temporary outputs and loss
|
# del temporary outputs and loss
|
||||||
del d0, d1, d2, d3, d4, d5, d6, loss2, loss
|
del d0, d1, d2, d3, d4, d5, d6, loss2, loss
|
||||||
|
Loading…
x
Reference in New Issue
Block a user