|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch import Tensor, einsum |
|
|
import torch.nn .functional as F |
|
|
from misc.torchutils import class2one_hot,simplex |
|
|
from models.darnet_help.loss_help import FocalLoss, dernet_dice_loss |
|
|
|
|
|
def cross_entropy(input, target, weight=None, reduction='mean',ignore_index=255):
    """Pixel-wise cross-entropy (log-softmax + NLL) for dense prediction.

    :param input: torch.Tensor, logits of shape N*C*H*W
    :param target: torch.Tensor, labels of shape N*1*H*W or N*H*W
    :param weight: torch.Tensor of shape C, optional per-class weights
    :param reduction: passed through to F.cross_entropy
    :param ignore_index: label value excluded from the loss (default 255)
    :return: scalar torch.Tensor
    """
    labels = target.long()
    # Collapse an explicit channel dimension (N*1*H*W -> N*H*W).
    if labels.dim() == 4:
        labels = labels.squeeze(dim=1)

    logits = input
    # Upsample logits when their spatial size differs from the labels'.
    if logits.shape[-1] != labels.shape[-1]:
        logits = F.interpolate(logits, size=labels.shape[1:],
                               mode='bilinear', align_corners=True)

    return F.cross_entropy(input=logits, target=labels, weight=weight,
                           ignore_index=ignore_index, reduction=reduction)
|
|
|
|
|
|
|
|
def dice_loss(predicts, target, weight=None):
    """Soft Dice loss over classes 0 and 1 of a 7-class one-hot target.

    :param predicts: torch.Tensor, logits of shape N*C*H*W
    :param target: torch.Tensor, integer labels (one-hot encoded internally
        with 7 classes via class2one_hot)
    :param weight: unused, kept for signature compatibility with the other
        loss functions in this module
    :return: scalar torch.Tensor, mean of (1 - Dice) over batch and classes
    """
    # Only the first two classes participate in the Dice term.
    idc = [0, 1]
    probs = predicts.softmax(dim=1)

    target = class2one_hot(target, 7)
    assert simplex(probs) and simplex(target)

    pc = probs[:, idc, ...].float()
    tc = target[:, idc, ...].float()
    # Per-sample, per-class overlap and mass sums.
    inter: Tensor = einsum("bcwh,bcwh->bc", pc, tc)
    union: Tensor = einsum("bkwh->bk", pc) + einsum("bkwh->bk", tc)

    # 1 - Dice coefficient, smoothed to avoid division by zero.
    per_class: Tensor = 1.0 - (2 * inter + 1e-10) / (union + 1e-10)

    return per_class.mean()
|
|
|
|
|
def ce_dice(input, target, weight=None):
    """Equal-weight (0.5/0.5) combination of cross-entropy and Dice loss.

    :param weight: unused, kept for signature compatibility
    :return: scalar torch.Tensor
    """
    ce_term = cross_entropy(input, target)
    dice_term = dice_loss(input, target)
    return 0.5 * ce_term + 0.5 * dice_term
|
|
|
|
|
def dice(input, target, weight=None):
    """Thin wrapper around dice_loss.

    :param weight: unused, kept for signature compatibility
    :return: scalar torch.Tensor
    """
    return dice_loss(input, target)
|
|
|
|
|
def ce2_dice1(input, target, weight=None):
    """Cross-entropy plus half-weighted Dice loss (ratio 1 : 0.5).

    :param weight: unused, kept for signature compatibility
    :return: scalar torch.Tensor
    """
    ce_term = cross_entropy(input, target)
    dice_term = dice_loss(input, target)
    return ce_term + 0.5 * dice_term
|
|
|
|
|
def ce1_dice2(input, target, weight=None):
    """Half-weighted cross-entropy plus Dice loss (ratio 0.5 : 1).

    :param weight: unused, kept for signature compatibility
    :return: scalar torch.Tensor
    """
    ce_term = cross_entropy(input, target)
    dice_term = dice_loss(input, target)
    return 0.5 * ce_term + dice_term
|
|
|
|
|
def ce_scl(input, target, weight=None):
    """Equal-weight cross-entropy + Dice combination.

    NOTE(review): despite the name, this is currently byte-identical in
    behavior to ce_dice — no separate similarity/contrastive term is used.

    :param weight: unused, kept for signature compatibility
    :return: scalar torch.Tensor
    """
    ce_term = cross_entropy(input, target)
    dice_term = dice_loss(input, target)
    return 0.5 * ce_term + 0.5 * dice_term
|
|
|
|
|
|
|
|
def weighted_BCE_logits(logit_pixel, truth_pixel, weight_pos=0.25, weight_neg=0.75):
    """Class-balanced binary cross-entropy on raw logits.

    Each class's per-pixel losses are averaged within that class and then
    weighted, so the result is insensitive to the positive/negative pixel
    ratio of the batch.

    :param logit_pixel: torch.Tensor of logits (any shape)
    :param truth_pixel: torch.Tensor of {0, 1} targets, same shape
    :param weight_pos: weight of the positive-class mean loss
    :param weight_neg: weight of the negative-class mean loss
    :return: scalar torch.Tensor
    """
    logit = logit_pixel.view(-1)
    truth = truth_pixel.view(-1)
    assert logit.shape == truth.shape

    per_pixel = F.binary_cross_entropy_with_logits(
        logit.float(), truth.float(), reduction='none')

    pos_mask = (truth > 0.5).float()
    neg_mask = (truth < 0.5).float()
    # Epsilon keeps the divisions defined when a class is absent.
    pos_count = pos_mask.sum().item() + 1e-12
    neg_count = neg_mask.sum().item() + 1e-12

    pos_term = weight_pos * (pos_mask * per_pixel).sum() / pos_count
    neg_term = weight_neg * (neg_mask * per_pixel).sum() / neg_count
    return pos_term + neg_term
|
|
|
|
|
class ChangeSimilarity(nn.Module):
    """Cosine-embedding loss between two soft multi-class predictions.

    Input: x1, x2 are N*C*H*W logits (C = class count). label_change marks
    the changed pixels; unchanged pixels (target +1) are pulled together,
    changed pixels (target -1) are pushed apart.
    """

    def __init__(self, reduction='mean'):
        super(ChangeSimilarity, self).__init__()
        self.loss_f = nn.CosineEmbeddingLoss(margin=0., reduction=reduction)

    def forward(self, x1, x2, label_change):
        b, c, h, w = x1.size()
        # Softmax over classes, then flatten each pixel to a C-vector.
        p1 = torch.reshape(F.softmax(x1, dim=1).permute(0, 2, 3, 1), [b * h * w, c])
        p2 = torch.reshape(F.softmax(x2, dim=1).permute(0, 2, 3, 1), [b * h * w, c])

        # unchanged -> +1, changed -> -1
        unchanged = (~label_change.bool()).float()
        target = torch.reshape(unchanged - label_change.float(), [b * h * w])

        return self.loss_f(p1, p2, target)
|
|
|
|
|
def hybrid_loss(predictions, target, weight=None):
    """Weighted sum of (cross-entropy + Dice) over multi-scale predictions.

    :param predictions: iterable of torch.Tensor logits (N*C*H*W), e.g. one
        per decoder scale
    :param target: torch.Tensor label map, shared by all scales
    :param weight: per-prediction weights; defaults to five 0.2 weights.
        The original default ``[0,2,0.2,0.2,0.2,0.2]`` was a typo (a comma
        where a decimal point was intended), yielding six weights with the
        first 0 and the second 2 — fixed here, and the mutable default
        list is replaced by the None sentinel.
    :return: scalar torch.Tensor
    """
    if weight is None:
        weight = [0.2, 0.2, 0.2, 0.2, 0.2]

    loss = 0
    for i, prediction in enumerate(predictions):
        ce_term = cross_entropy(prediction, target)
        # renamed from `dice` to avoid shadowing the module-level dice()
        dice_term = dice_loss(prediction, target)
        loss += weight[i] * (ce_term + dice_term)

    return loss
|
|
|
|
|
class BCL(nn.Module):
    """Batch-balanced contrastive loss.

    Label convention: 0 = no-change (remapped to +1, pulled to distance 0),
    1 = change (remapped to -1, pushed beyond ``margin``), 255 = ignore.

    Fixes vs. the original implementation:
    * ``label`` is cloned before the in-place remapping, so the caller's
      tensor is no longer mutated as a side effect.
    * ignore pixels (255) are masked out of both loss terms; previously only
      ``distance`` was masked, so each 255 pixel contributed a spurious
      ``(1 - 255) / 2 * margin**2`` term to ``loss_2``.
    """

    def __init__(self, margin=2.0):
        super(BCL, self).__init__()
        self.margin = margin

    def forward(self, distance, label):
        # Remap on a copy: 1 -> -1 (change), 0 -> +1 (no-change).
        label = label.clone()
        label[label == 1] = -1
        label[label == 0] = 1

        # Valid-pixel mask; 255 marks ignored pixels.
        mask = (label != 255).float()
        distance = distance * mask

        # Balance the two terms by their respective pixel counts.
        pos_num = torch.sum((label == 1).float()) + 0.0001
        neg_num = torch.sum((label == -1).float()) + 0.0001

        # Pull term: unchanged (+1) pixels should have small distance.
        loss_1 = torch.sum(mask * (1 + label) / 2 * torch.pow(distance, 2)) / pos_num
        # Push term: changed (-1) pixels should exceed the margin.
        loss_2 = torch.sum(mask * (1 - label) / 2 *
                           torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
                           ) / neg_num
        return loss_1 + loss_2