# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmpose.registry import MODELS
@MODELS.register_module()
class BCELoss(nn.Module):
"""Binary Cross Entropy loss.
Args:
use_target_weight (bool): Option to use weighted loss.
Different joint types may have different target weights.
reduction (str): Options are "none", "mean" and "sum".
loss_weight (float): Weight of the loss. Default: 1.0.
use_sigmoid (bool, optional): Whether a sigmoid has already been
applied to the prediction. If True, ``F.binary_cross_entropy`` is
used; otherwise the prediction is treated as logits and
``F.binary_cross_entropy_with_logits`` is used. Defaults to False.
"""
def __init__(self,
use_target_weight=False,
loss_weight=1.,
reduction='mean',
use_sigmoid=False):
super().__init__()
assert reduction in ('mean', 'sum', 'none'), f'the argument ' \
f'`reduction` should be either \'mean\', \'sum\' or \'none\', ' \
f'but got {reduction}'
self.reduction = reduction
self.use_sigmoid = use_sigmoid
criterion = F.binary_cross_entropy if use_sigmoid \
else F.binary_cross_entropy_with_logits
self.criterion = partial(criterion, reduction='none')
self.use_target_weight = use_target_weight
self.loss_weight = loss_weight
def forward(self, output, target, target_weight=None):
"""Forward function.
Note:
- batch_size: N
- num_labels: K
Args:
output (torch.Tensor[N, K]): Output classification.
target (torch.Tensor[N, K]): Target classification.
target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
Weights across different labels.
"""
if self.use_target_weight:
assert target_weight is not None
loss = self.criterion(output, target)
if target_weight.dim() == 1:
target_weight = target_weight[:, None]
loss = (loss * target_weight)
else:
loss = self.criterion(output, target)
if self.reduction == 'sum':
loss = loss.sum()
elif self.reduction == 'mean':
loss = loss.mean()
return loss * self.loss_weight
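# A minimal usage sketch for BCELoss (illustrative, not part of the original
# module). The batch size of 2 and the 17 labels below are arbitrary
# assumptions; shapes follow the forward() docstring.
# >>> loss_fn = BCELoss(use_target_weight=True, use_sigmoid=False)
# >>> output = torch.randn(2, 17)        # [N, K] raw logits
# >>> target = torch.rand(2, 17)         # [N, K] soft labels in [0, 1]
# >>> weight = torch.ones(2, 17)         # [N, K] per-label weights
# >>> loss = loss_fn(output, target, weight)  # scalar under 'mean' reduction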
@MODELS.register_module()
class JSDiscretLoss(nn.Module):
"""Discrete JS Divergence loss for DSNT with Gaussian Heatmap.
Modified from `the official implementation
<https://github.com/anibali/dsntnn/blob/master/dsntnn/__init__.py>`_.
Args:
use_target_weight (bool): Option to use weighted loss.
Different joint types may have different target weights.
size_average (bool): Option to average the loss by the batch_size.
"""
def __init__(
self,
use_target_weight=True,
size_average: bool = True,
):
super(JSDiscretLoss, self).__init__()
self.use_target_weight = use_target_weight
self.size_average = size_average
self.kl_loss = nn.KLDivLoss(reduction='none')
def kl(self, p, q):
"""Kullback-Leibler Divergence."""
eps = 1e-24
kl_values = self.kl_loss((q + eps).log(), p)
return kl_values
def js(self, pred_hm, gt_hm):
"""Jensen-Shannon Divergence."""
m = 0.5 * (pred_hm + gt_hm)
js_values = 0.5 * (self.kl(pred_hm, m) + self.kl(gt_hm, m))
return js_values
def forward(self, pred_hm, gt_hm, target_weight=None):
"""Forward function.
Args:
pred_hm (torch.Tensor[N, K, H, W]): Predicted heatmaps.
gt_hm (torch.Tensor[N, K, H, W]): Target heatmaps.
target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
Weights across different labels.
Returns:
torch.Tensor: Loss value.
"""
if self.use_target_weight:
assert target_weight is not None
assert pred_hm.ndim >= target_weight.ndim
for i in range(pred_hm.ndim - target_weight.ndim):
target_weight = target_weight.unsqueeze(-1)
loss = self.js(pred_hm * target_weight, gt_hm * target_weight)
else:
loss = self.js(pred_hm, gt_hm)
if self.size_average:
loss /= len(gt_hm)
return loss.sum()
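# A minimal usage sketch for JSDiscretLoss (illustrative, not part of the
# original module). The heatmaps are normalized to sum to 1 per keypoint,
# since the JS divergence expects probability maps; sizes are arbitrary
# assumptions.
# >>> loss_fn = JSDiscretLoss(use_target_weight=True)
# >>> pred_hm = torch.rand(2, 17, 64, 48)
# >>> pred_hm = pred_hm / pred_hm.sum(dim=(2, 3), keepdim=True)
# >>> gt_hm = torch.rand(2, 17, 64, 48)
# >>> gt_hm = gt_hm / gt_hm.sum(dim=(2, 3), keepdim=True)
# >>> weight = torch.ones(2, 17)         # [N, K] per-keypoint weights
# >>> loss = loss_fn(pred_hm, gt_hm, weight)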
@MODELS.register_module()
class KLDiscretLoss(nn.Module):
"""Discrete KL Divergence loss for SimCC with Gaussian Label Smoothing.
Modified from `the official implementation
<https://github.com/leeyegy/SimCC>`_.
Args:
beta (float): Temperature factor of Softmax. Default: 1.0.
label_softmax (bool): Whether to use Softmax on labels.
Default: False.
label_beta (float): Temperature factor of Softmax on labels.
Default: 10.0.
use_target_weight (bool): Option to use weighted loss.
Different joint types may have different target weights.
mask (list[int], optional): Indices of masked keypoints. Default: None.
mask_weight (float): Weight of masked keypoints. Default: 1.0.
"""
def __init__(self,
beta=1.0,
label_softmax=False,
label_beta=10.0,
use_target_weight=True,
mask=None,
mask_weight=1.0):
super(KLDiscretLoss, self).__init__()
self.beta = beta
self.label_softmax = label_softmax
self.label_beta = label_beta
self.use_target_weight = use_target_weight
self.mask = mask
self.mask_weight = mask_weight
self.log_softmax = nn.LogSoftmax(dim=1)
self.kl_loss = nn.KLDivLoss(reduction='none')
def criterion(self, dec_outs, labels):
"""Criterion function."""
log_pt = self.log_softmax(dec_outs * self.beta)
if self.label_softmax:
labels = F.softmax(labels * self.label_beta, dim=1)
loss = torch.mean(self.kl_loss(log_pt, labels), dim=1)
return loss
def forward(self, pred_simcc, gt_simcc, target_weight):
"""Forward function.
Args:
pred_simcc (Tuple[Tensor, Tensor]): Predicted SimCC vectors of
x-axis and y-axis.
gt_simcc (Tuple[Tensor, Tensor]): Target representations.
target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
Weights across different labels.
"""
N, K, _ = pred_simcc[0].shape
loss = 0
if self.use_target_weight:
weight = target_weight.reshape(-1)
else:
weight = 1.
for pred, target in zip(pred_simcc, gt_simcc):
pred = pred.reshape(-1, pred.size(-1))
target = target.reshape(-1, target.size(-1))
t_loss = self.criterion(pred, target).mul(weight)
if self.mask is not None:
t_loss = t_loss.reshape(N, K)
t_loss[:, self.mask] = t_loss[:, self.mask] * self.mask_weight
loss = loss + t_loss.sum()
return loss / K
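# A minimal usage sketch for KLDiscretLoss (illustrative, not part of the
# original module). The bin counts assume a SimCC split ratio of 2.0 on a
# 192x256 input; all sizes here are arbitrary assumptions.
# >>> loss_fn = KLDiscretLoss(beta=10., label_softmax=True)
# >>> pred_x = torch.randn(2, 17, 192 * 2)   # x-axis SimCC logits
# >>> pred_y = torch.randn(2, 17, 256 * 2)   # y-axis SimCC logits
# >>> gt_x = torch.rand(2, 17, 192 * 2)      # x-axis SimCC labels
# >>> gt_y = torch.rand(2, 17, 256 * 2)      # y-axis SimCC labels
# >>> weight = torch.ones(2, 17)             # [N, K] keypoint weights
# >>> loss = loss_fn((pred_x, pred_y), (gt_x, gt_y), weight)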
@MODELS.register_module()
class InfoNCELoss(nn.Module):
"""InfoNCE loss for training a discriminative representation space with a
contrastive manner.
`Representation Learning with Contrastive Predictive Coding
arXiv: <https://arxiv.org/abs/1611.05424>`_.
Args:
temperature (float, optional): The temperature to use in the softmax
function. Higher temperatures lead to softer probability
distributions. Defaults to 1.0.
loss_weight (float, optional): The weight to apply to the loss.
Defaults to 1.0.
"""
def __init__(self, temperature: float = 1.0, loss_weight=1.0) -> None:
super(InfoNCELoss, self).__init__()
assert temperature > 0, f'the argument `temperature` must be ' \
f'positive, but got {temperature}'
self.temp = temperature
self.loss_weight = loss_weight
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Computes the InfoNCE loss.
Args:
features (Tensor): A tensor containing the feature
representations of different samples.
Returns:
Tensor: A scalar tensor containing the InfoNCE loss.
"""
n = features.size(0)
features_norm = F.normalize(features, dim=1)
logits = features_norm.mm(features_norm.t()) / self.temp
targets = torch.arange(n, dtype=torch.long, device=features.device)
loss = F.cross_entropy(logits, targets, reduction='sum')
return loss * self.loss_weight
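# A minimal usage sketch for InfoNCELoss (illustrative, not part of the
# original module). Each row of `features` is one sample's representation;
# the feature dimension of 256 is an arbitrary assumption.
# >>> loss_fn = InfoNCELoss(temperature=0.1)
# >>> features = torch.randn(8, 256)     # [n, d] sample representations
# >>> loss = loss_fn(features)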
@MODELS.register_module()
class VariFocalLoss(nn.Module):
"""Varifocal loss.
Args:
use_target_weight (bool): Option to use weighted loss.
Different joint types may have different target weights.
reduction (str): Options are "none", "mean" and "sum".
loss_weight (float): Weight of the loss. Default: 1.0.
alpha (float): A balancing factor for the negative part of
Varifocal Loss. Defaults to 0.75.
gamma (float): Gamma parameter for the modulating factor.
Defaults to 2.0.
"""
def __init__(self,
use_target_weight=False,
loss_weight=1.,
reduction='mean',
alpha=0.75,
gamma=2.0):
super().__init__()
assert reduction in ('mean', 'sum', 'none'), f'the argument ' \
f'`reduction` should be either \'mean\', \'sum\' or \'none\', ' \
f'but got {reduction}'
self.reduction = reduction
self.use_target_weight = use_target_weight
self.loss_weight = loss_weight
self.alpha = alpha
self.gamma = gamma
def criterion(self, output, target):
label = (target > 1e-4).to(target)
weight = self.alpha * output.sigmoid().pow(
self.gamma) * (1 - label) + target
output = output.clip(min=-10, max=10)
vfl = (
F.binary_cross_entropy_with_logits(
output, target, reduction='none') * weight)
return vfl
def forward(self, output, target, target_weight=None):
"""Forward function.
Note:
- batch_size: N
- num_labels: K
Args:
output (torch.Tensor[N, K]): Output classification.
target (torch.Tensor[N, K]): Target classification.
target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
Weights across different labels.
"""
if self.use_target_weight:
assert target_weight is not None
loss = self.criterion(output, target)
if target_weight.dim() == 1:
target_weight = target_weight.unsqueeze(1)
loss = (loss * target_weight)
else:
loss = self.criterion(output, target)
loss[torch.isinf(loss)] = 0.0
loss[torch.isnan(loss)] = 0.0
if self.reduction == 'sum':
loss = loss.sum()
elif self.reduction == 'mean':
loss = loss.mean()
return loss * self.loss_weight
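# A minimal usage sketch for VariFocalLoss (illustrative, not part of the
# original module). Targets are soft quality scores in [0, 1]; the shapes
# below are arbitrary assumptions.
# >>> loss_fn = VariFocalLoss(use_target_weight=True)
# >>> output = torch.randn(2, 17)        # [N, K] raw logits
# >>> target = torch.rand(2, 17)         # [N, K] soft target scores
# >>> weight = torch.ones(2)             # [N] per-sample weights
# >>> loss = loss_fn(output, target, weight)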