File size: 4,953 Bytes
7ad9dfd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import torch
import torch.nn as nn
from torch import Tensor, einsum
import torch.nn .functional as F
from misc.torchutils import class2one_hot,simplex
from models.darnet_help.loss_help import FocalLoss, dernet_dice_loss
def cross_entropy(input, target, weight=None, reduction='mean', ignore_index=255):
    """Pixel-wise cross-entropy (log-softmax + NLL), resizing logits if needed.

    :param input: torch.Tensor of logits, N*C*H*W.
    :param target: torch.Tensor of class indices, N*1*H*W or N*H*W.
    :param weight: optional torch.Tensor of per-class weights, length C.
    :param reduction: forwarded to ``F.cross_entropy``.
    :param ignore_index: label value excluded from the loss (default 255).
    :return: torch.Tensor; a scalar unless ``reduction='none'``.
    """
    target = target.long()
    if target.dim() == 4:
        # Drop the singleton channel axis: N*1*H*W -> N*H*W.
        target = torch.squeeze(target, dim=1)
    # Compare the full spatial size (H, W). A width-only check would let a
    # height mismatch through, and F.cross_entropy would then raise.
    if input.shape[-2:] != target.shape[-2:]:
        input = F.interpolate(input, size=target.shape[-2:], mode='bilinear',
                              align_corners=True)
    return F.cross_entropy(input=input, target=target, weight=weight,
                           ignore_index=ignore_index, reduction=reduction)
def dice_loss(predicts, target, weight=None, num_classes=7, idc=(0, 1)):
    """Soft Dice loss over a selected subset of classes.

    :param predicts: logits; softmaxed over dim 1. The einsum below requires a
        4-D (B, C, H, W) layout.
    :param target: integer label map, one-hot encoded with ``class2one_hot``.
    :param weight: unused; kept so all losses in this module share a signature.
    :param num_classes: number of classes passed to ``class2one_hot``
        (default 7).
    :param idc: channel indices included in the loss (default ``(0, 1)``).
    :return: scalar tensor, ``1 - Dice`` averaged over batch and selected classes.
    """
    channels = list(idc)
    probs = torch.softmax(predicts, dim=1)
    one_hot = class2one_hot(target, num_classes)
    # Sanity check via misc.torchutils.simplex on both probability maps.
    assert simplex(probs) and simplex(one_hot)
    pc = probs[:, channels, ...].type(torch.float32)
    tc = one_hot[:, channels, ...].type(torch.float32)
    # Per-sample, per-class overlap and total mass, summed over the spatial axes.
    intersection: Tensor = einsum("bcwh,bcwh->bc", pc, tc)
    union: Tensor = einsum("bkwh->bk", pc) + einsum("bkwh->bk", tc)
    # The 1e-10 terms keep the ratio defined when a class is absent from both maps.
    divided: Tensor = torch.ones_like(intersection) - (2 * intersection + 1e-10) / (union + 1e-10)
    return divided.mean()
def ce_dice(input, target, weight=None):
    """Return the equal-weight blend ``0.5 * CE + 0.5 * Dice``.

    ``weight`` is accepted for signature compatibility but is not used.
    """
    ce_term = cross_entropy(input, target)
    dice_term = dice_loss(input, target)
    return 0.5 * ce_term + 0.5 * dice_term
def dice(input, target, weight=None):
    """Return ``dice_loss(input, target)``; ``weight`` is accepted but ignored."""
    return dice_loss(input, target)
def ce2_dice1(input, target, weight=None):
    """Return ``CE + 0.5 * Dice`` (cross-entropy weighted twice as heavily).

    ``weight`` is accepted for signature compatibility but is not used.
    """
    return cross_entropy(input, target) + 0.5 * dice_loss(input, target)
def ce1_dice2(input, target, weight=None):
    """Return ``0.5 * CE + Dice`` (Dice weighted twice as heavily).

    ``weight`` is accepted for signature compatibility but is not used.
    """
    return 0.5 * cross_entropy(input, target) + dice_loss(input, target)
def ce_scl(input, target, weight=None):
    """Return ``0.5 * CE + 0.5 * Dice``; ``weight`` is accepted but ignored.

    NOTE(review): the body is identical to ``ce_dice``; nothing "scl"-specific
    is computed here. Confirm whether that is intended.
    """
    ce_part = cross_entropy(input, target)
    dice_part = dice_loss(input, target)
    return 0.5 * ce_part + 0.5 * dice_part
def weighted_BCE_logits(logit_pixel, truth_pixel, weight_pos=0.25, weight_neg=0.75):
    """Class-balanced binary cross-entropy on logits.

    Positive (truth > 0.5) and negative (truth < 0.5) elements are averaged
    separately and combined as ``weight_pos * mean_pos + weight_neg * mean_neg``.
    Elements whose truth is exactly 0.5 belong to neither group.

    :param logit_pixel: raw logits of any shape.
    :param truth_pixel: targets with the same number of elements.
    :param weight_pos: weight of the positive-element mean.
    :param weight_neg: weight of the negative-element mean.
    :return: scalar loss tensor.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed or sliced
    # tensors, are accepted.
    logit = logit_pixel.reshape(-1)
    truth = truth_pixel.reshape(-1)
    assert logit.shape == truth.shape, "logits and targets must have the same number of elements"
    loss = F.binary_cross_entropy_with_logits(logit.float(), truth.float(), reduction='none')
    pos = (truth > 0.5).float()
    neg = (truth < 0.5).float()
    # Keep the counts as tensors (no .item()) to avoid a device->host sync.
    # The epsilon keeps an empty group from dividing by zero.
    pos_num = pos.sum() + 1e-12
    neg_num = neg.sum() + 1e-12
    return (weight_pos * pos * loss / pos_num + weight_neg * neg * loss / neg_num).sum()
class ChangeSimilarity(nn.Module):
    """Cosine-embedding loss between the per-pixel class distributions of two maps.

    ``forward(x1, x2, label_change)`` takes two (B, C, H, W) logit maps.
    Where ``label_change`` is zero, the target is +1 and the distributions are
    pulled together. Elsewhere the target is ``-label_change``, which is -1 for
    a 0/1 change mask, and the distributions are pushed apart.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.loss_f = nn.CosineEmbeddingLoss(margin=0., reduction=reduction)

    def forward(self, x1, x2, label_change):
        batch, channels, height, width = x1.size()
        n_pixels = batch * height * width
        # (B, C, H, W) -> (B*H*W, C): one class-probability vector per pixel.
        probs_a = F.softmax(x1, dim=1).permute(0, 2, 3, 1).reshape(n_pixels, channels)
        probs_b = F.softmax(x2, dim=1).permute(0, 2, 3, 1).reshape(n_pixels, channels)
        unchanged = (~label_change.bool()).float()
        embed_target = (unchanged - label_change.float()).reshape(n_pixels)
        return self.loss_f(probs_a, probs_b, embed_target)
def hybrid_loss(predictions, target, weight=(0, 2, 0.2, 0.2, 0.2, 0.2)):
    """Deep-supervision loss: weighted sum of (cross-entropy + Dice) per output.

    :param predictions: iterable of logit tensors, each scored against the
        same ``target``.
    :param target: label tensor shared by all outputs.
    :param weight: per-output weights; ``weight[i]`` scales the i-th
        prediction's loss, so it needs at least as many entries as there are
        predictions (otherwise IndexError). The default is an immutable tuple,
        not a list, so it cannot be shared-and-mutated across calls.
    :return: the accumulated loss (the int ``0`` if ``predictions`` is empty).
    """
    # NOTE(review): the default "0, 2, 0.2, ..." gives the first output zero
    # weight and the second weight 2. It may be a typo for "0.2, 0.2, ...".
    # The values are kept unchanged here; confirm the intended weights.
    total = 0
    for idx, prediction in enumerate(predictions):
        ce_term = cross_entropy(prediction, target)
        dice_term = dice_loss(prediction, target)
        total += weight[idx] * (ce_term + dice_term)
    return total
class BCL(nn.Module):
    """Batch-balanced contrastive loss.

    Input label convention: 0 = no change, 1 = change, 255 = ignore (any
    other value is also ignored). No-change pixels are pulled toward zero
    distance, and changed pixels are pushed beyond ``margin``. Each term is
    normalised by its own pixel count, so both classes contribute equally
    regardless of class imbalance.
    """

    def __init__(self, margin=2.0):
        super(BCL, self).__init__()
        self.margin = margin

    def forward(self, distance, label):
        """Return the contrastive loss without modifying ``label``.

        :param distance: per-pixel distance tensor, broadcast-compatible with ``label``.
        :param label: per-pixel labels (0 no change, 1 change, 255 ignore).
        :return: scalar loss tensor.
        """
        # Build masks instead of remapping ``label`` in place, so the caller's
        # tensor is left untouched and ignored pixels contribute to neither term.
        unchanged = (label == 0).float()
        changed = (label == 1).float()
        # The 1e-4 offset avoids division by zero when a class is absent.
        pos_num = torch.sum(unchanged) + 0.0001
        neg_num = torch.sum(changed) + 0.0001
        loss_1 = torch.sum(unchanged * torch.pow(distance, 2)) / pos_num
        loss_2 = torch.sum(
            changed * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
        ) / neg_num
        return loss_1 + loss_2