|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.autograd import Variable |
|
|
|
class JointsMSELoss(nn.Module):
    """Per-joint mean-squared-error loss between predicted and target heatmaps.

    For every joint the prediction and the target are flattened to
    ``(batch, -1)``, optionally multiplied by that joint's weight, and
    compared with ``nn.MSELoss(reduction='mean')``.  Each per-joint term is
    scaled by 0.5 and the sum is divided by the number of joints.

    Args:
        use_target_weight: when truthy, ``forward`` multiplies both the
            prediction and the target by ``target_weight`` before the MSE.
    """

    def __init__(self, use_target_weight):
        super(JointsMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight):
        """Return the joint-averaged loss.

        Args:
            output: predictions; dim 0 is the batch, dim 1 the joints, and
                all remaining dims are flattened together.
            target: ground truth, reshaped the same way as ``output``.
            target_weight: per-joint weights, reshaped to
                ``(batch, joints, -1)``; both ``(batch, joints)`` and
                ``(batch, joints, 1)`` layouts are accepted.  Only read when
                ``use_target_weight`` is truthy.

        Returns:
            The sum of the 0.5-scaled per-joint MSE terms divided by the
            number of joints.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)

        heatmaps_pred = output.reshape((batch_size, num_joints, -1))
        heatmaps_gt = target.reshape((batch_size, num_joints, -1))
        if self.use_target_weight:
            # Normalising the layout here keeps the weight broadcastable
            # against the (batch, -1) heatmaps for either accepted shape.
            weights = target_weight.reshape((batch_size, num_joints, -1))

        loss = 0
        for idx in range(num_joints):
            # Index the joint axis instead of split(...).squeeze(): a
            # dimensionless squeeze would also drop the batch axis when
            # batch_size == 1 or the flattened axis when the heatmap has a
            # single element, which then mis-broadcasts against the weight.
            heatmap_pred = heatmaps_pred[:, idx]
            heatmap_gt = heatmaps_gt[:, idx]
            if self.use_target_weight:
                weight = weights[:, idx]
                loss += 0.5 * self.criterion(
                    heatmap_pred.mul(weight),
                    heatmap_gt.mul(weight)
                )
            else:
                loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)

        return loss / num_joints
|
|
|
|
|
class NMTCritierion(nn.Module):
    """Classification loss over per-joint x/y logits with optional label smoothing.

    With ``label_smoothing > 0`` the target is a smoothed distribution and
    the loss is ``nn.KLDivLoss``; otherwise the loss is ``nn.NLLLoss`` on the
    integer labels.  Both are applied to ``LogSoftmax(dim=1)`` of the logits
    with ``reduction='none'``.

    Args:
        label_smoothing: probability mass spread evenly over the non-target
            classes; the target class receives ``1 - label_smoothing``.
    """

    def __init__(self, label_smoothing=0.0):
        super(NMTCritierion, self).__init__()
        self.label_smoothing = label_smoothing
        self.LogSoftmax = nn.LogSoftmax(dim=1)

        if label_smoothing > 0:
            self.criterion_ = nn.KLDivLoss(reduction='none')
        else:
            # Labels equal to 100000 are skipped by NLLLoss.
            self.criterion_ = nn.NLLLoss(reduction='none', ignore_index=100000)
        self.confidence = 1.0 - label_smoothing

    def _smooth_label(self, num_tokens):
        """Return a ``(1, num_tokens)`` row filled with the off-target probability.

        Every entry is ``label_smoothing / (num_tokens - 1)``; the caller
        overwrites the target position with ``self.confidence``.
        """
        return torch.full(
            (1, num_tokens), self.label_smoothing / (num_tokens - 1)
        )

    def _bottle(self, v):
        """Flatten ``v`` to 2-D, keeping the size of dim 2 as the last axis."""
        return v.view(-1, v.size(2))

    def criterion(self, dec_outs, labels):
        """Return one loss value per row of ``dec_outs``.

        Args:
            dec_outs: logits; log-softmax is taken over dim 1 and the size of
                the last dim is used as the number of classes.
            labels: integer class indices, flattened to 1-D.

        Returns:
            A 1-D tensor with one loss per sample.
        """
        scores = self.LogSoftmax(dec_outs)
        num_tokens = scores.size(-1)

        # reshape (not view) so non-contiguous label slices are accepted.
        gtruth = labels.reshape(-1)
        if self.confidence < 1:
            indices = gtruth.detach()
            # Follow the scores' device/dtype instead of special-casing CUDA.
            smoothed = self._smooth_label(num_tokens).to(
                device=scores.device, dtype=scores.dtype
            )
            smoothed = smoothed.repeat(indices.size(0), 1)
            smoothed.scatter_(1, indices.unsqueeze(1), self.confidence)
            gtruth = smoothed.detach()

        loss = self.criterion_(scores, gtruth)
        # KLDivLoss yields one value per class and must be summed over the
        # class axis; NLLLoss already yields one value per sample, and summing
        # it over dim 1 would raise IndexError.
        if loss.dim() > 1:
            loss = torch.sum(loss, dim=1)
        return loss

    def forward(self, output_x, output_y, target, target_weight):
        """Return the weighted x+y classification loss divided by the batch size.

        Args:
            output_x: logits for the x coordinate; dim 0 is the batch and
                dim 1 the joints.
            output_y: logits for the y coordinate, indexed like ``output_x``.
            target: integer labels indexed as ``target[:, joint][:, 0]`` for x
                and ``target[:, joint][:, 1]`` for y.
            target_weight: per-joint weights indexed as
                ``target_weight[:, joint]``; a trailing singleton dim is
                tolerated.

        Returns:
            Sum over joints and samples of the weighted losses, divided by
            ``output_x.size(0)``.
        """
        batch_size = output_x.size(0)
        num_joints = output_x.size(1)
        loss = 0

        for idx in range(num_joints):
            coord_x_pred = output_x[:, idx]
            coord_y_pred = output_y[:, idx]
            coord_gt = target[:, idx]
            # Flatten so a (batch, 1) weight cannot broadcast the (batch,)
            # loss into a (batch, batch) matrix.
            weight = target_weight[:, idx].reshape(-1)

            loss += self.criterion(coord_x_pred, coord_gt[:, 0]).mul(weight).sum()
            loss += self.criterion(coord_y_pred, coord_gt[:, 1]).mul(weight).sum()
        return loss / batch_size