Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) OpenMMLab. All rights reserved. | |
from functools import partial | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmpose.registry import MODELS | |
from mmpose.structures.bbox import bbox_overlaps | |
class IoULoss(nn.Module):
    """IoU-based loss between predicted and target bounding boxes.

    The IoU of each aligned (output, target) box pair is computed with
    ``bbox_overlaps(..., is_aligned=True)``, clamped from below by ``eps``,
    and then converted into a loss according to ``mode``.

    Args:
        reduction (str): Options are "none", "mean" and "sum".
            Default: 'mean'.
        mode (str): Loss scaling mode, including "linear" (``1 - IoU``),
            "square" (``1 - IoU**2``) and "log" (``-log(IoU)``).
            Default: 'log'.
        eps (float): Lower bound applied to the IoU; avoids ``log(0)`` in
            "log" mode. Default: 1e-16.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 reduction='mean',
                 mode='log',
                 eps: float = 1e-16,
                 loss_weight=1.):
        super().__init__()
        assert reduction in ('mean', 'sum', 'none'), f'the argument ' \
            f'`reduction` should be either \'mean\', \'sum\' or \'none\', ' \
            f'but got {reduction}'
        assert mode in ('linear', 'square', 'log'), f'the argument ' \
            f'`mode` should be either \'linear\', \'square\' or ' \
            f'\'log\', but got {mode}'

        self.reduction = reduction
        # Not used by ``forward``; kept so code that reads this attribute
        # keeps working.
        self.criterion = partial(F.cross_entropy, reduction='none')
        self.loss_weight = loss_weight
        self.mode = mode
        self.eps = eps

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N

        Args:
            output (torch.Tensor): Predicted boxes, passed as the first
                argument of ``bbox_overlaps``. Presumably of shape [N, 4]
                in (x1, y1, x2, y2) format -- TODO confirm against
                ``bbox_overlaps``.
            target (torch.Tensor): Target boxes, aligned one-to-one with
                ``output``.
            target_weight (torch.Tensor, optional): Per-sample weights.
                Trailing singleton dims are appended until it has as many
                dims as the loss, then it is multiplied in. Default: None.

        Returns:
            torch.Tensor: A scalar for "mean"/"sum" reduction, or the
            unreduced loss for "none"; always scaled by ``loss_weight``.
        """
        ious = bbox_overlaps(
            output, target, is_aligned=True).clamp(min=self.eps)

        if self.mode == 'linear':
            loss = 1 - ious
        elif self.mode == 'square':
            loss = 1 - ious.pow(2)
        elif self.mode == 'log':
            loss = -ious.log()
        else:
            # Only reachable if ``mode`` was changed after construction.
            raise NotImplementedError(f'unsupported mode: {self.mode}')

        if target_weight is not None:
            # Append trailing singleton dims so the weight broadcasts
            # against the loss tensor.
            for _ in range(loss.ndim - target_weight.ndim):
                target_weight = target_weight.unsqueeze(-1)
            loss = loss * target_weight

        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()

        return loss * self.loss_weight