# Uploaded by InPeerReview ("Upload 6 files", commit 7ad9dfd, verified).
import torch
import torch.nn as nn
from torch import Tensor, einsum
import torch.nn .functional as F
from misc.torchutils import class2one_hot,simplex
from models.darnet_help.loss_help import FocalLoss, dernet_dice_loss
def cross_entropy(input, target, weight=None, reduction='mean', ignore_index=255):
    """
    Pixel-wise softmax cross-entropy (log-softmax + NLL).

    If the prediction's spatial size differs from the label's, the prediction
    is bilinearly resized to the label size before the loss is computed.

    :param input: torch.Tensor of logits, N*C*H*W
    :param target: torch.Tensor of class indices, N*1*H*W or N*H*W
    :param weight: optional torch.Tensor of per-class weights, shape C
    :param reduction: passed through to ``F.cross_entropy`` ('mean', 'sum', 'none')
    :param ignore_index: label value excluded from the loss (default 255)
    :return: torch.Tensor (a scalar for 'mean'/'sum')
    """
    target = target.long()
    if target.dim() == 4:
        # N*1*H*W -> N*H*W, the layout F.cross_entropy expects for class indices.
        target = torch.squeeze(target, dim=1)
    # Compare the full spatial size (H, W). Checking only the last dimension
    # misses predictions whose height differs from the label's.
    if input.shape[-2:] != target.shape[-2:]:
        input = F.interpolate(input, size=target.shape[-2:], mode='bilinear', align_corners=True)
    return F.cross_entropy(input=input, target=target, weight=weight,
                           ignore_index=ignore_index, reduction=reduction)
def dice_loss(predicts, target, weight=None, num_classes=7, idc=(0, 1)):
    """
    Soft Dice loss, averaged over the batch and the selected classes.

    :param predicts: torch.Tensor of logits, N*C*H*W (softmax is taken over dim 1)
    :param target: label map, one-hot encoded via ``class2one_hot(target, num_classes)``
    :param weight: unused; kept so the signature matches the other losses here
    :param num_classes: number of classes for the one-hot encoding of ``target``
        (default 7, the value this function always used)
    :param idc: class indices included in the loss (default (0, 1), the
        indices this function always used)
    :return: scalar torch.Tensor, mean of (1 - Dice) over batch and selected classes
    """
    # NOTE(review): num_classes presumably should match predicts.shape[1];
    # the default of 7 is kept for backward compatibility — confirm per model.
    channels = list(idc)
    probs = torch.softmax(predicts, dim=1)
    target = class2one_hot(target, num_classes)
    # Sanity check via the project helper `simplex` (presumably: values along
    # the class axis form a probability distribution — confirm in misc.torchutils).
    assert simplex(probs) and simplex(target)
    pc = probs[:, channels, ...].type(torch.float32)
    tc = target[:, channels, ...].type(torch.float32)
    # Per-sample, per-class overlap and total mass, summed over the spatial dims.
    intersection: Tensor = einsum("bchw,bchw->bc", pc, tc)
    union: Tensor = einsum("bchw->bc", pc) + einsum("bchw->bc", tc)
    # The epsilon makes a class absent from both prediction and label score
    # Dice = 1 (loss 0) instead of 0/0.
    divided: Tensor = torch.ones_like(intersection) - (2 * intersection + 1e-10) / (union + 1e-10)
    return divided.mean()
def ce_dice(input, target, weight=None):
    """Equal blend of pixel-wise cross-entropy and Dice loss: 0.5*ce + 0.5*dice.

    ``weight`` is accepted for interface compatibility and is not used.
    """
    return 0.5 * cross_entropy(input, target) + 0.5 * dice_loss(input, target)
def dice(input, target, weight=None):
    """Dice loss alone; a thin alias for ``dice_loss`` (``weight`` is ignored)."""
    return dice_loss(input, target)
def ce2_dice1(input, target, weight=None):
    """Cross-entropy weighted twice as heavily as Dice: ce + 0.5*dice.

    ``weight`` is accepted for interface compatibility and is not used.
    """
    return cross_entropy(input, target) + 0.5 * dice_loss(input, target)
def ce1_dice2(input, target, weight=None):
    """Dice weighted twice as heavily as cross-entropy: 0.5*ce + dice.

    ``weight`` is accepted for interface compatibility and is not used.
    """
    return 0.5 * cross_entropy(input, target) + dice_loss(input, target)
def ce_scl(input, target, weight=None):
    """Return 0.5*cross-entropy + 0.5*Dice, exactly the same value as ``ce_dice``.

    NOTE(review): the name hints at an extra term, but none is computed here.
    ``weight`` is accepted for interface compatibility and is not used.
    """
    return ce_dice(input, target)
def weighted_BCE_logits(logit_pixel, truth_pixel, weight_pos=0.25, weight_neg=0.75):
    """Class-balanced binary cross-entropy on logits.

    Positive pixels (truth > 0.5) and negative pixels (truth < 0.5) are averaged
    separately, then combined as
    ``weight_pos * mean_BCE(pos) + weight_neg * mean_BCE(neg)``.
    Pixels with truth exactly 0.5 contribute nothing.

    :param logit_pixel: tensor of logits, any shape
    :param truth_pixel: tensor of targets with the same number of elements
    :param weight_pos: weight of the positive-pixel term
    :param weight_neg: weight of the negative-pixel term
    :return: scalar torch.Tensor
    :raises ValueError: if the flattened inputs differ in length
    """
    # reshape (not view) so non-contiguous inputs, e.g. permuted maps, are accepted.
    logit = logit_pixel.reshape(-1).float()
    truth = truth_pixel.reshape(-1).float()
    if logit.shape != truth.shape:
        raise ValueError(
            f"logit and truth sizes differ: {tuple(logit.shape)} vs {tuple(truth.shape)}")
    loss = F.binary_cross_entropy_with_logits(logit, truth, reduction='none')
    pos = (truth > 0.5).float()
    neg = (truth < 0.5).float()
    # Counts stay on-device (no .item()) to avoid a host sync; the epsilon
    # keeps an empty class from dividing by zero (its numerator is 0 anyway).
    pos_num = pos.sum() + 1e-12
    neg_num = neg.sum() + 1e-12
    return (weight_pos * pos * loss / pos_num + weight_neg * neg * loss / neg_num).sum()
class ChangeSimilarity(nn.Module):
    """Cosine-embedding loss between two multi-class prediction maps.

    ``x1`` and ``x2`` are (b, c, h, w) logits. Each pixel's softmax class
    distribution is compared across the two maps: unchanged pixels are pulled
    towards cosine similarity 1, changed pixels pushed to similarity <= 0.
    ``label_change`` marks the changed pixels and must hold b*h*w elements.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.loss_f = nn.CosineEmbeddingLoss(margin=0., reduction=reduction)

    def forward(self, x1, x2, label_change):
        batch, channels, height, width = x1.size()
        n_pixels = batch * height * width

        def _pixel_rows(logits):
            # (b, c, h, w) logits -> (b*h*w, c) per-pixel class distributions.
            probs = F.softmax(logits, dim=1).permute(0, 2, 3, 1)
            return torch.reshape(probs, [n_pixels, channels])

        rows1, rows2 = _pixel_rows(x1), _pixel_rows(x2)
        # For 0/1 labels: +1 where unchanged, -1 where changed.
        # NOTE(review): any other label value (e.g. 255) yields a target that is
        # neither +1 nor -1 — confirm labels are strictly 0/1.
        unchanged = (~label_change.bool()).float()
        target = torch.reshape(unchanged - label_change.float(), [n_pixels])
        return self.loss_f(rows1, rows2, target)
def hybrid_loss(predictions, target, weight=(0, 2, 0.2, 0.2, 0.2, 0.2)):
    """Weighted sum of (cross-entropy + Dice) over several prediction maps.

    :param predictions: iterable of logit tensors, each scored against ``target``
    :param target: label map shared by all predictions
    :param weight: sequence of per-prediction weights; ``weight[i]`` scales the
        i-th prediction's loss. It must be at least as long as ``predictions``,
        otherwise an IndexError is raised.
    :return: the accumulated loss (a tensor, or 0 if ``predictions`` is empty)
    """
    # NOTE(review): the default "0, 2" (two elements: weight 0 for the first map,
    # weight 2 for the second) looks like a typo for "0.2". Kept unchanged
    # because fixing it alters both the values and the length — confirm intent.
    # A tuple default avoids sharing one mutable list across calls.
    loss = 0
    for i, prediction in enumerate(predictions):
        ce_term = cross_entropy(prediction, target)
        dice_term = dice_loss(prediction, target)
        loss += weight[i] * (ce_term + dice_term)
    return loss
class BCL(nn.Module):
    """
    Batch-balanced contrastive loss.

    Input label convention: 0 = no-change, 1 = change, 255 = ignore.
    No-change pixels are penalised by distance**2. Change pixels are penalised
    by max(margin - distance, 0)**2. Each term is normalised by its own pixel
    count, so the two classes are balanced within the batch. Pixels with any
    other label value (e.g. 255) are excluded from both terms.
    """

    def __init__(self, margin=2.0):
        super(BCL, self).__init__()
        self.margin = margin

    def forward(self, distance, label):
        # Masks are built without writing into ``label``, so the caller's
        # tensor is left untouched.
        no_change = (label == 0).float()
        change = (label == 1).float()
        # Small epsilon guards against a batch with no pixels of one class.
        pos_num = torch.sum(no_change) + 0.0001
        neg_num = torch.sum(change) + 0.0001
        loss_1 = torch.sum(no_change * torch.pow(distance, 2)) / pos_num
        loss_2 = torch.sum(
            change * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
        ) / neg_num
        return loss_1 + loss_2