Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Optional | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch import Tensor | |
from mmpose.registry import MODELS | |
class KeypointMSELoss(nn.Module):
    """MSE loss for heatmaps.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Note: this flag is only stored; ``target_weights`` passed to
            :meth:`forward` are applied whenever they are given.
            Defaults to ``False``
        skip_empty_channel (bool): If ``True``, heatmap channels with no
            non-zero value (which means no visible ground-truth keypoint
            in the image) will not be used to calculate the loss. Defaults to
            ``False``
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 skip_empty_channel: bool = False,
                 loss_weight: float = 1.):
        super().__init__()
        self.use_target_weight = use_target_weight
        self.skip_empty_channel = skip_empty_channel
        self.loss_weight = loss_weight

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                per_keypoint: bool = False,
                per_pixel: bool = False) -> Tensor:
        """Forward function of loss.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps weight: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W]
            target (Tensor): The target heatmaps with shape [B, K, H, W]
            target_weights (Tensor, optional): The target weights of different
                keypoints, with shape [B, K] (keypoint-wise) or
                [B, K, H, W] (pixel-wise).
            mask (Tensor, optional): The masks of valid heatmap pixels in
                shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
                be applied. Defaults to ``None``
            per_keypoint (bool): If ``True``, average only over the spatial
                dims and return a [B, K] loss. Defaults to ``False``
            per_pixel (bool): If ``True``, return the un-reduced (masked)
                per-pixel loss. Takes precedence over ``per_keypoint``.
                Defaults to ``False``

        Returns:
            Tensor: The calculated loss (a scalar unless ``per_keypoint`` or
            ``per_pixel`` is set), multiplied by ``loss_weight``.
        """
        _mask = self._get_mask(target, target_weights, mask)
        loss = F.mse_loss(output, target, reduction='none')
        # With no weights/masks at all every pixel contributes equally; the
        # loss must still be defined in that case.
        if _mask is not None:
            loss = loss * _mask

        if not per_pixel:
            loss = loss.mean(dim=(2, 3)) if per_keypoint else loss.mean()
        return loss * self.loss_weight

    def _get_mask(self, target: Tensor, target_weights: Optional[Tensor],
                  mask: Optional[Tensor]) -> Optional[Tensor]:
        """Generate the heatmap mask w.r.t. the given mask, target weight and
        `skip_empty_channel` setting.

        Returns:
            Tensor: The mask in shape (B, K, *) or ``None`` if no mask is
            needed.
        """
        # Given spatial mask
        if mask is not None:
            # check mask has matching type with target
            assert (mask.ndim == target.ndim and all(
                d_m == d_t or d_m == 1
                for d_m, d_t in zip(mask.shape, target.shape))), (
                    f'mask and target have mismatched shapes {mask.shape} '
                    f'v.s. {target.shape}')

        # Mask by target weights (keypoint-wise mask)
        if target_weights is not None:
            # check target weight has matching shape with target
            assert (target_weights.ndim in (2, 4) and target_weights.shape
                    == target.shape[:target_weights.ndim]), (
                        'target_weights and target have mismatched shapes '
                        f'{target_weights.shape} v.s. {target.shape}')

            # Append singleton spatial dims so [B, K] broadcasts over H, W.
            # ``reshape`` (unlike ``view``) also accepts non-contiguous input.
            ndim_pad = target.ndim - target_weights.ndim
            _mask = target_weights.reshape(target_weights.shape +
                                           (1, ) * ndim_pad)

            if mask is None:
                mask = _mask
            else:
                mask = mask * _mask

        # Mask by ``skip_empty_channel``
        if self.skip_empty_channel:
            # True for channels that contain at least one non-zero target.
            _mask = (target != 0).flatten(2).any(dim=2)
            ndim_pad = target.ndim - _mask.ndim
            _mask = _mask.reshape(_mask.shape + (1, ) * ndim_pad)

            if mask is None:
                mask = _mask
            else:
                mask = mask * _mask

        return mask
class CombinedTargetMSELoss(nn.Module):
    """Mean-squared-error loss over a stack of "combined target" maps.

    Every keypoint owns three consecutive channels: a response
    (classification) heatmap, then an x-offset map, then a y-offset map
    (regression). See Huang et al., "The Devil is in the Details: Delving
    into Unbiased Data Processing for Human Pose Estimation" (CVPR 2020).

    Args:
        use_target_weight (bool): If ``True``, the predicted and ground-truth
            response maps of each keypoint are scaled by that keypoint's
            target weight. Defaults to ``False``
        loss_weight (float): Multiplier applied to the final loss.
            Defaults to 1.0
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 loss_weight: float = 1.):
        super().__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output: Tensor, target: Tensor,
                target_weights: Tensor) -> Tensor:
        """Compute the combined classification + regression loss.

        Note:
            - batch_size: B
            - num_channels: C (= 3 * K)
            - heatmaps height: H
            - heatmaps weight: W
            - num_keypoints: K

        Args:
            output (Tensor): Predicted maps, shape [B, C, H, W].
            target (Tensor): Ground-truth maps, shape [B, C, H, W].
            target_weights (Tensor): Per-keypoint weights, shape [B, K].

        Returns:
            Tensor: Sum of the three half-weighted MSE terms, averaged over
            keypoints and multiplied by ``loss_weight``.
        """
        n_batch, n_chan = output.size(0), output.size(1)
        # One [B, 1, H*W] chunk per channel.
        pred_maps = output.reshape((n_batch, n_chan, -1)).split(1, 1)
        gt_maps = target.reshape((n_batch, n_chan, -1)).split(1, 1)
        n_kpt = n_chan // 3

        total = 0.
        for k in range(n_kpt):
            first = 3 * k
            hm_p, off_x_p, off_y_p = (
                chunk.squeeze() for chunk in pred_maps[first:first + 3])
            hm_g, off_x_g, off_y_g = (
                chunk.squeeze() for chunk in gt_maps[first:first + 3])

            if self.use_target_weight:
                kpt_w = target_weights[:, k, None]
                hm_p = hm_p * kpt_w
                hm_g = hm_g * kpt_w

            # response-map (classification) term
            total += 0.5 * self.criterion(hm_p, hm_g)
            # offset (regression) terms, gated by the ground-truth response
            total += 0.5 * self.criterion(hm_g * off_x_p, hm_g * off_x_g)
            total += 0.5 * self.criterion(hm_g * off_y_p, hm_g * off_y_g)

        return total / n_kpt * self.loss_weight
class KeypointOHKMMSELoss(nn.Module):
    """MSE loss with online hard keypoint mining.

    For each sample the per-keypoint MSE (averaged over the spatial dims) is
    computed; only the ``topk`` largest keypoint losses are averaged, and the
    result is averaged over the batch.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        topk (int): Only top k joint losses are kept. Defaults to 8
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 topk: int = 8,
                 loss_weight: float = 1.):
        super().__init__()
        assert topk > 0
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk
        self.loss_weight = loss_weight

    def _ohkm(self, losses: Tensor) -> Tensor:
        """Online hard keypoint mining.

        Note:
            - batch_size: B
            - num_keypoints: K

        Args:
            losses (Tensor): The losses with shape [B, K]

        Returns:
            Tensor: Batch mean of the per-sample average of the ``topk``
            largest keypoint losses.
        """
        # ``topk`` already returns the selected values, so no ``gather`` is
        # needed, and the whole batch is processed in one call.
        hardest = torch.topk(losses, k=self.topk, dim=1, sorted=False).values
        return (hardest.sum(dim=1) / self.topk).mean()

    def forward(self, output: Tensor, target: Tensor,
                target_weights: Tensor) -> Tensor:
        """Forward function of loss.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps weight: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W].
            target (Tensor): The target heatmaps with shape [B, K, H, W].
            target_weights (Tensor): The target weights of different
                keypoints, with shape [B, K]. Only read when
                ``use_target_weight`` is ``True``.

        Returns:
            Tensor: The calculated loss.

        Raises:
            ValueError: If ``topk`` exceeds the number of keypoints.
        """
        num_keypoints = output.size(1)
        if num_keypoints < self.topk:
            raise ValueError(f'topk ({self.topk}) should not be '
                             f'larger than num_keypoints ({num_keypoints}).')

        if self.use_target_weight:
            # Broadcast the [B, K] weights over the spatial dims.
            weights = target_weights[:, :, None, None]
            output = output * weights
            target = target * weights

        # Per-keypoint MSE averaged over H, W -> [B, K].
        losses = self.criterion(output, target).mean(dim=(2, 3))
        return self._ohkm(losses) * self.loss_weight
class AdaptiveWingLoss(nn.Module):
    """Adaptive wing loss. paper ref: 'Adaptive Wing Loss for Robust Face
    Alignment via Heatmap Regression' Wang et al. ICCV'2019.

    Args:
        alpha (float), omega (float), epsilon (float), theta (float)
            are hyper-parameters.
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 alpha=2.1,
                 omega=14,
                 epsilon=1,
                 theta=0.5,
                 use_target_weight=False,
                 loss_weight=1.):
        super().__init__()
        self.alpha = float(alpha)
        self.omega = float(omega)
        self.epsilon = float(epsilon)
        self.theta = float(theta)
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def criterion(self, pred, target):
        """Criterion of wingloss.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            pred (torch.Tensor[NxKxHxW]): Predicted heatmaps.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.

        Returns:
            torch.Tensor: Mean of the element-wise adaptive wing loss.
        """
        delta = (target - pred).abs()
        # The exponent depends on the target intensity (alpha - y).
        power = self.alpha - target
        ratio = self.theta / self.epsilon

        # A and C make the linear branch match the log branch in value and
        # slope at delta == theta.
        A = self.omega * (
            1 / (1 + torch.pow(ratio, power))
        ) * power * (torch.pow(ratio, power - 1)) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(
            1 + torch.pow(ratio, power))

        losses = torch.where(
            delta < self.theta,
            self.omega *
            torch.log(1 + torch.pow(delta / self.epsilon, power)),
            A * delta - C)

        return torch.mean(losses)

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None):
        """Forward function.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            output (torch.Tensor[N, K, H, W]): Output heatmaps.
            target (torch.Tensor[N, K, H, W]): Target heatmaps.
            target_weights (torch.Tensor[N, K] or [N, K, H, W], optional):
                Weights across different joint types. Required when
                ``use_target_weight`` is ``True``.

        Returns:
            torch.Tensor: The loss multiplied by ``loss_weight``.
        """
        if self.use_target_weight:
            assert target_weights is not None, (
                '`target_weights` is required when use_target_weight=True')
            assert (target_weights.ndim in (2, 4) and target_weights.shape
                    == target.shape[:target_weights.ndim]), (
                        'target_weights and target have mismatched shapes '
                        f'{target_weights.shape} v.s. {target.shape}')

            # Append singleton spatial dims so [N, K] broadcasts over H, W.
            ndim_pad = target.ndim - target_weights.ndim
            target_weights = target_weights.reshape(target_weights.shape +
                                                    (1, ) * ndim_pad)
            loss = self.criterion(output * target_weights,
                                  target * target_weights)
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight
class FocalHeatmapLoss(KeypointMSELoss):
    """A class for calculating the modified focal loss for heatmap prediction.

    This loss function is exactly the same as the one used in CornerNet. It
    runs faster and costs a little bit more memory.

    `CornerNet: Detecting Objects as Paired Keypoints
    arXiv: <https://arxiv.org/abs/1808.01244>`_.

    Args:
        alpha (int): The alpha parameter in the focal loss equation.
        beta (int): The beta parameter in the focal loss equation.
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
            Defaults to ``False``
        skip_empty_channel (bool): If ``True``, heatmap channels with no
            non-zero value (which means no visible ground-truth keypoint
            in the image) will not be used to calculate the loss. Defaults to
            ``False``
        loss_weight (float): Weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 alpha: int = 2,
                 beta: int = 4,
                 use_target_weight: bool = False,
                 skip_empty_channel: bool = False,
                 loss_weight: float = 1.0):
        super().__init__(use_target_weight, skip_empty_channel, loss_weight)
        self.alpha = alpha
        self.beta = beta

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None,
                mask: Optional[Tensor] = None) -> Tensor:
        """Calculate the modified focal loss for heatmap prediction.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps weight: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W]
            target (Tensor): The target heatmaps with shape [B, K, H, W]
            target_weights (Tensor, optional): The target weights of different
                keypoints, with shape [B, K] (keypoint-wise) or
                [B, K, H, W] (pixel-wise).
            mask (Tensor, optional): The masks of valid heatmap pixels in
                shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
                be applied. Defaults to ``None``

        Returns:
            Tensor: The calculated loss. It is normalized by the (masked)
            number of positive pixels, or left un-normalized when there are
            none.
        """
        _mask = self._get_mask(target, target_weights, mask)

        # Positives are pixels whose target is exactly 1; all others are
        # negatives.
        pos_inds = target.eq(1).float()
        neg_inds = target.lt(1).float()

        if _mask is not None:
            pos_inds = pos_inds * _mask
            neg_inds = neg_inds * _mask

        neg_weights = torch.pow(1 - target, self.beta)

        # Evaluate log() only where the term is actually used. Elsewhere a
        # neutral input is substituted (1 for log(output), 0 for
        # log(1 - output)), so an output of exactly 0 or 1 at an unused
        # pixel cannot produce ``-inf * 0 = nan`` in the loss or gradient.
        pos_output = torch.where(pos_inds != 0, output,
                                 torch.ones_like(output))
        neg_output = torch.where(neg_inds != 0, output,
                                 torch.zeros_like(output))

        pos_loss = torch.log(pos_output) * torch.pow(1 - output,
                                                     self.alpha) * pos_inds
        neg_loss = torch.log(1 - neg_output) * torch.pow(
            output, self.alpha) * neg_weights * neg_inds

        num_pos = pos_inds.sum()
        if num_pos == 0:
            loss = -neg_loss.sum()
        else:
            loss = -(pos_loss.sum() + neg_loss.sum()) / num_pos
        return loss * self.loss_weight
class MLECCLoss(nn.Module):
    """Maximum-likelihood loss for coordinate-classification outputs.

    For every axis, the probability a prediction puts on its target is
    ``(output * target).sum(-1)``. The per-axis values are multiplied into a
    joint likelihood ``p``, which ``mode`` turns into a loss.

    Args:
        reduction (str): How the loss is reduced over all non-batch
            dimensions: ``'none'`` | ``'mean'`` | ``'sum'``. The batch
            dimension is always kept. Default: ``'mean'``.
        mode (str): Likelihood-to-loss mapping: ``'linear'`` (``1 - p``),
            ``'square'`` (``1 - p ** 2``) or ``'log'`` (``-log(p + 1e-4)``).
            Default: ``'log'``.
        use_target_weight (bool): If ``True``, the element-wise loss is
            multiplied by ``target_weight`` (broadcast over trailing dims).
            Defaults to False.
        loss_weight (float): Multiplier applied to the final loss.
            Defaults to 1.0.

    Raises:
        AssertionError: If ``reduction`` or ``mode`` is not one of the
            supported choices.
    """

    def __init__(self,
                 reduction: str = 'mean',
                 mode: str = 'log',
                 use_target_weight: bool = False,
                 loss_weight: float = 1.0):
        super().__init__()

        assert reduction in ('mean', 'sum', 'none'), (
            "`reduction` should be either 'mean', 'sum', or 'none', "
            f'but got {reduction}')
        assert mode in ('linear', 'square', 'log'), (
            "`mode` should be either 'linear', 'square', or 'log', "
            f'but got {mode}')

        self.reduction = reduction
        self.mode = mode
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, outputs, targets, target_weight=None):
        """Compute the loss.

        Args:
            outputs (Sequence[torch.Tensor]): Per-axis predicted
                distributions (the class dimension is last).
            targets (Sequence[torch.Tensor]): Per-axis target distributions,
                same length as ``outputs``.
            target_weight (torch.Tensor, optional): Element weights; required
                when ``use_target_weight`` is ``True``.

        Returns:
            torch.Tensor: The loss after ``reduction``, times ``loss_weight``.
        """
        assert len(outputs) == len(targets), \
            'Outputs and targets must have the same length'

        # Joint likelihood = product over axes of the mass on the target.
        prob = 1.0
        for pred_axis, gt_axis in zip(outputs, targets):
            prob *= (pred_axis * gt_axis).sum(dim=-1)

        likelihood_to_loss = {
            'linear': lambda p: 1.0 - p,
            'square': lambda p: 1.0 - p.pow(2),
            'log': lambda p: -torch.log(p + 1e-4),
        }
        loss = likelihood_to_loss[self.mode](prob)
        loss = loss.masked_fill(torch.isnan(loss), 0.0)

        if self.use_target_weight:
            assert target_weight is not None
            # Append trailing singleton dims until the weight lines up.
            extra_dims = loss.ndim - target_weight.ndim
            target_weight = target_weight[(Ellipsis, ) + (None, ) * extra_dims]
            loss = loss * target_weight

        if self.reduction in ('sum', 'mean'):
            per_sample = loss.flatten(1)
            if self.reduction == 'sum':
                loss = per_sample.sum(dim=1)
            else:
                loss = per_sample.mean(dim=1)

        return loss * self.loss_weight
class OKSHeatmapLoss(nn.Module):
    """OKS-inspired loss for heatmaps.

    Three per-pixel terms are combined (each multiplied by the mask, if any):

    - an OKS term: ``output * (1 - target)`` (``'minus'``),
      ``(1 - output) * target`` (``'plus'``) or their average (``'both'``),
      weighted by ``1 - smoothing_weight - gaussian_weight``;
    - a smoothness term: the squared Sobel gradient magnitude of ``output``,
      weighted by ``smoothing_weight``;
    - an MSE term between ``output`` and ``target``, weighted by
      ``gaussian_weight``.

    Args:
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
            Note: this flag is only stored; ``target_weights`` passed to
            :meth:`forward` are applied whenever they are given.
            Defaults to ``False``
        skip_empty_channel (bool): If ``True``, heatmap channels with no
            non-zero value (which means no visible ground-truth keypoint
            in the image) will not be used to calculate the loss. Defaults to
            ``False``
        smoothing_weight (float): Weight of the smoothness term.
            Defaults to 0.2
        gaussian_weight (float): Weight of the MSE term. Defaults to 0.0
        loss_weight (float): Weight of the loss. Defaults to 1.0
        oks_type (str): Which OKS term to use: ``'minus'``, ``'plus'`` or
            ``'both'`` (case-insensitive). Defaults to ``'minus'``
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 skip_empty_channel: bool = False,
                 smoothing_weight: float = 0.2,
                 gaussian_weight: float = 0.0,
                 loss_weight: float = 1.,
                 oks_type: str = "minus"):
        super().__init__()
        self.use_target_weight = use_target_weight
        self.skip_empty_channel = skip_empty_channel
        self.loss_weight = loss_weight
        self.smoothing_weight = smoothing_weight
        self.gaussian_weight = gaussian_weight
        self.oks_type = oks_type.lower()
        assert self.oks_type in ["minus", "plus", "both"]

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                per_pixel: bool = False,
                per_keypoint: bool = False) -> Tensor:
        """Forward function of loss.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps weight: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W]
            target (Tensor): The target heatmaps with shape [B, K, H, W],
                with values in [0, 1].
            target_weights (Tensor, optional): The target weights of different
                keypoints, with shape [B, K] (keypoint-wise) or
                [B, K, H, W] (pixel-wise).
            mask (Tensor, optional): The masks of valid heatmap pixels in
                shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
                be applied. Defaults to ``None``
            per_pixel (bool): Return the un-reduced [B, K, H, W] loss.
                Takes precedence over ``per_keypoint``.
            per_keypoint (bool): Return a [B, K] loss. The OKS term is summed
                over pixels, the smoothness term is the max over pixels and
                the MSE term is the mean over pixels.

        Returns:
            Tensor: The calculated loss (scalar by default: the mean of the
            per-keypoint loss), multiplied by ``loss_weight``.
        """
        assert target.max() <= 1, 'target should be normalized'
        assert target.min() >= 0, 'target should be normalized'

        B, K, H, W = output.shape

        _mask = self._get_mask(target, target_weights, mask)

        oks_minus = output * (1 - target)
        oks_plus = (1 - output) * (target)
        if self.oks_type == "both":
            oks = (oks_minus + oks_plus) / 2
        elif self.oks_type == "minus":
            oks = oks_minus
        elif self.oks_type == "plus":
            oks = oks_plus
        else:
            raise ValueError(f"oks_type {self.oks_type} not recognized")

        mse = F.mse_loss(output, target, reduction='none')

        # Smoothness: squared Sobel gradient magnitude of every heatmap. The
        # x/y kernels are stacked as two output channels of one conv, and are
        # created in the output's dtype and device so fp16/bf16/fp64 inputs
        # do not hit a conv2d dtype mismatch.
        sobel = torch.tensor(
            [[[[1, 0, -1], [2, 0, -2], [1, 0, -1]]],
             [[[1, 2, 1], [0, 0, 0], [-1, -2, -1]]]],
            dtype=output.dtype,
            device=output.device)
        grads = F.conv2d(
            output.reshape(B * K, 1, H, W), sobel, padding='same')
        gradient = grads.pow(2).sum(dim=1).reshape(B, K, H, W)

        if _mask is not None:
            oks = oks * _mask
            mse = mse * _mask
            gradient = gradient * _mask

        # The OKS term gets whatever weight the other two terms leave over.
        oks_weight = 1 - self.smoothing_weight - self.gaussian_weight

        if per_pixel:
            loss = (
                self.smoothing_weight * gradient +
                oks_weight * oks +
                self.gaussian_weight * mse
            )
        else:
            max_gradient, _ = gradient.reshape((B, K, H * W)).max(dim=-1)
            loss = (
                oks_weight * oks.sum(dim=(2, 3)) +
                self.smoothing_weight * max_gradient +
                self.gaussian_weight * mse.mean(dim=(2, 3))
            )
            if not per_keypoint:
                loss = loss.mean()

        return loss * self.loss_weight

    def _get_mask(self, target: Tensor, target_weights: Optional[Tensor],
                  mask: Optional[Tensor]) -> Optional[Tensor]:
        """Generate the heatmap mask w.r.t. the given mask, target weight and
        `skip_empty_channel` setting.

        Returns:
            Tensor: The mask in shape (B, K, *) or ``None`` if no mask is
            needed.
        """
        # Given spatial mask
        if mask is not None:
            # check mask has matching type with target
            assert (mask.ndim == target.ndim and all(
                d_m == d_t or d_m == 1
                for d_m, d_t in zip(mask.shape, target.shape))), (
                    f'mask and target have mismatched shapes {mask.shape} '
                    f'v.s. {target.shape}')

        # Mask by target weights (keypoint-wise mask)
        if target_weights is not None:
            # check target weight has matching shape with target
            assert (target_weights.ndim in (2, 4) and target_weights.shape
                    == target.shape[:target_weights.ndim]), (
                        'target_weights and target have mismatched shapes '
                        f'{target_weights.shape} v.s. {target.shape}')

            # Append singleton spatial dims so [B, K] broadcasts over H, W.
            ndim_pad = target.ndim - target_weights.ndim
            _mask = target_weights.reshape(target_weights.shape +
                                           (1, ) * ndim_pad)

            if mask is None:
                mask = _mask
            else:
                mask = mask * _mask

        # Mask by ``skip_empty_channel``
        if self.skip_empty_channel:
            # True for channels that contain at least one non-zero target.
            _mask = (target != 0).flatten(2).any(dim=2)
            ndim_pad = target.ndim - _mask.ndim
            _mask = _mask.reshape(_mask.shape + (1, ) * ndim_pad)

            if mask is None:
                mask = _mask
            else:
                mask = mask * _mask

        return mask
class CalibrationLoss(nn.Module):
    """Negative log-likelihood of the heatmap mass on the ground truth.

    For each keypoint, the probability a predicted heatmap assigns to the
    ground truth is taken as ``sum(output * target)`` over the spatial dims.
    The loss is ``-log`` of that value (``1e-10`` is added for stability).

    Args:
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
            Note: this flag is only stored; ``target_weights`` passed to
            :meth:`forward` are applied whenever they are given.
            Defaults to ``False``
        skip_empty_channel (bool): If ``True``, heatmap channels with no
            non-zero value (which means no visible ground-truth keypoint
            in the image) will not be used to calculate the loss. Defaults to
            ``False``
        loss_weight (float): Weight of the loss. Defaults to 1.0
        ignore_bottom_percentile (float): Stored for configuration
            compatibility; it is not used by :meth:`forward`.
            Defaults to 0.7
    """

    def __init__(self,
                 use_target_weight: bool = False,
                 skip_empty_channel: bool = False,
                 loss_weight: float = 1.,
                 ignore_bottom_percentile: float = 0.7):
        super().__init__()
        self.use_target_weight = use_target_weight
        self.skip_empty_channel = skip_empty_channel
        self.loss_weight = loss_weight
        self.ignore_bottom_percentile = ignore_bottom_percentile

    def forward(self,
                output: Tensor,
                target: Tensor,
                target_weights: Optional[Tensor] = None,
                mask: Optional[Tensor] = None,
                per_pixel: bool = False,
                per_keypoint: bool = False) -> Tensor:
        """Forward function of loss.

        Note:
            - batch_size: B
            - num_keypoints: K
            - heatmaps height: H
            - heatmaps weight: W

        Args:
            output (Tensor): The output heatmaps with shape [B, K, H, W]
            target (Tensor): The target heatmaps with shape [B, K, H, W],
                with values in [0, 1].
            target_weights (Tensor, optional): The target weights of different
                keypoints, with shape [B, K] (keypoint-wise) or
                [B, K, H, W] (pixel-wise).
            mask (Tensor, optional): The masks of valid heatmap pixels in
                shape [B, K, H, W] or [B, 1, H, W]. If ``None``, no mask will
                be applied. Defaults to ``None``
            per_pixel (bool): Return the per-pixel ``-log(output * target)``
                of shape [B, K, H, W]. Takes precedence over
                ``per_keypoint``.
            per_keypoint (bool): Return the per-keypoint loss, shape [B, K].

        Returns:
            Tensor: The calculated loss (scalar mean by default), multiplied
            by ``loss_weight``.
        """
        assert target.max() <= 1, 'target should be normalized'
        assert target.min() >= 0, 'target should be normalized'

        _mask = self._get_mask(target, target_weights, mask)

        pred_probs = output * target

        if per_pixel:
            loss = -torch.log(pred_probs + 1e-10)
            if _mask is not None:
                loss = loss * _mask
        else:
            # [B, K] probability mass assigned to the ground truth.
            loss = -torch.log(pred_probs.sum(dim=(2, 3)) + 1e-10)
            if _mask is not None:
                # Collapse the (B, K|1, H|1, W|1) mask to keypoint level by
                # taking its max over the spatial dims. Multiplying the
                # [B, K] loss by the 4-D mask directly would broadcast into
                # a [B, K, B, K] outer product. A keypoint-wise weight of
                # shape (B, K, 1, 1) passes through unchanged.
                kpt_mask = _mask.to(loss.dtype).flatten(2).amax(dim=2)
                loss = loss * kpt_mask
            if not per_keypoint:
                loss = loss.mean()

        return loss * self.loss_weight

    def _get_mask(self, target: Tensor, target_weights: Optional[Tensor],
                  mask: Optional[Tensor]) -> Optional[Tensor]:
        """Generate the heatmap mask w.r.t. the given mask, target weight and
        `skip_empty_channel` setting.

        Returns:
            Tensor: The mask in shape (B, K, *) or ``None`` if no mask is
            needed.
        """
        # Given spatial mask
        if mask is not None:
            # check mask has matching type with target
            assert (mask.ndim == target.ndim and all(
                d_m == d_t or d_m == 1
                for d_m, d_t in zip(mask.shape, target.shape))), (
                    f'mask and target have mismatched shapes {mask.shape} '
                    f'v.s. {target.shape}')

        # Mask by target weights (keypoint-wise mask)
        if target_weights is not None:
            # check target weight has matching shape with target
            assert (target_weights.ndim in (2, 4) and target_weights.shape
                    == target.shape[:target_weights.ndim]), (
                        'target_weights and target have mismatched shapes '
                        f'{target_weights.shape} v.s. {target.shape}')

            # Append singleton spatial dims so [B, K] broadcasts over H, W.
            ndim_pad = target.ndim - target_weights.ndim
            _mask = target_weights.reshape(target_weights.shape +
                                           (1, ) * ndim_pad)

            if mask is None:
                mask = _mask
            else:
                mask = mask * _mask

        # Mask by ``skip_empty_channel``
        if self.skip_empty_channel:
            # True for channels that contain at least one non-zero target.
            _mask = (target != 0).flatten(2).any(dim=2)
            ndim_pad = target.ndim - _mask.ndim
            _mask = _mask.reshape(_mask.shape + (1, ) * ndim_pad)

            if mask is None:
                mask = _mask
            else:
                mask = mask * _mask

        return mask