Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,387 Bytes
a249588 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
from mmpose.registry import MODELS
from mmpose.structures.bbox import bbox_overlaps
@MODELS.register_module()
class IoULoss(nn.Module):
    """IoU-based loss between predicted and target bounding boxes.

    The per-pair overlap is obtained from ``bbox_overlaps`` (called with
    ``is_aligned=True``), clamped from below by ``eps``, and mapped to a loss
    according to ``mode``:

    - ``'linear'``: ``1 - iou``
    - ``'square'``: ``1 - iou ** 2``
    - ``'log'``: ``-log(iou)``

    Args:
        reduction (str): Options are "none", "mean" and "sum".
            Default: 'mean'
        mode (str): Loss scaling mode, including "linear", "square", and
            "log". Default: 'log'
        eps (float): Lower bound applied to the IoU so that ``log(0)`` is
            never evaluated. Default: 1e-16
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 reduction: str = 'mean',
                 mode: str = 'log',
                 eps: float = 1e-16,
                 loss_weight: float = 1.):
        super().__init__()

        assert reduction in ('mean', 'sum', 'none'), (
            "the argument `reduction` should be either 'mean', 'sum' or "
            f"'none', but got {reduction}")
        # The message previously named `reduction` here, which pointed users
        # at the wrong argument when an invalid `mode` was supplied.
        assert mode in ('linear', 'square', 'log'), (
            "the argument `mode` should be either 'linear', 'square' or "
            f"'log', but got {mode}")

        self.reduction = reduction
        # NOTE: not used by ``forward``; kept only so that external code
        # reading this attribute keeps working.
        self.criterion = partial(F.cross_entropy, reduction='none')
        self.loss_weight = loss_weight
        self.mode = mode
        self.eps = eps

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Args:
            output (torch.Tensor): Predicted bounding boxes, passed as the
                first argument of ``bbox_overlaps``.
                NOTE(review): presumably shaped (..., 4) in the box format
                expected by ``bbox_overlaps`` - confirm against that helper.
            target (torch.Tensor): Target bounding boxes, aligned one-to-one
                with ``output`` (``is_aligned=True``).
            target_weight (torch.Tensor, optional): Weights multiplied into
                the un-reduced loss. Trailing singleton dimensions are
                appended until its rank matches the loss. Default: None.

        Returns:
            torch.Tensor: The loss scaled by ``loss_weight``. It is reduced
            to a single value for ``reduction='mean'`` / ``'sum'`` and left
            un-reduced for ``reduction='none'``.
        """
        # Clamp so the 'log' mode never evaluates log(0).
        ious = bbox_overlaps(
            output, target, is_aligned=True).clamp(min=self.eps)

        if self.mode == 'linear':
            loss = 1 - ious
        elif self.mode == 'square':
            loss = 1 - ious.pow(2)
        elif self.mode == 'log':
            loss = -ious.log()
        else:
            # Only reachable if ``self.mode`` was changed after construction.
            raise NotImplementedError(
                f'unsupported IoU loss mode: {self.mode}')

        if target_weight is not None:
            # Append trailing axes so the weight broadcasts against the loss.
            for _ in range(loss.ndim - target_weight.ndim):
                target_weight = target_weight.unsqueeze(-1)
            loss = loss * target_weight

        if self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()

        return loss * self.loss_weight
|