from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn


class JointsMSELoss(nn.Module):
    """Mean-squared-error loss between predicted and ground-truth heatmaps.

    Expects heatmaps of shape [batch, num_joints, H, W] and, when
    ``use_target_weight`` is True, a per-joint visibility weight of shape
    [batch, num_joints, 1].
    """

    def __init__(self, use_target_weight):
        super(JointsMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight):
        batch_size = output.size(0)
        num_joints = output.size(1)
        # Flatten each heatmap to a vector and split into per-joint tensors
        # of shape [batch, 1, H*W].
        heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
        loss = 0

        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx].squeeze()
            heatmap_gt = heatmaps_gt[idx].squeeze()
            if self.use_target_weight:
                # Mask out invisible joints before computing the MSE.
                loss += 0.5 * self.criterion(
                    heatmap_pred.mul(target_weight[:, idx]),
                    heatmap_gt.mul(target_weight[:, idx])
                )
            else:
                loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)

        return loss / num_joints


class NMTCritierion(nn.Module):
    """Cross-entropy loss with optional label smoothing.

    With ``label_smoothing > 0`` the target distribution places
    ``1 - label_smoothing`` on the ground-truth bin and spreads the remainder
    uniformly over the other bins; the loss is then the KL divergence between
    this distribution and the predicted log-probabilities. With
    ``label_smoothing == 0`` it reduces to a plain NLL loss.
    """

    def __init__(self, label_smoothing=0.0):
        super(NMTCritierion, self).__init__()
        self.label_smoothing = label_smoothing
        self.LogSoftmax = nn.LogSoftmax(dim=1)  # input logits are [B, LOGITS]

        if label_smoothing > 0:
            self.criterion_ = nn.KLDivLoss(reduction='none')
        else:
            # ignore_index is set to a value that never occurs, so no target
            # is actually ignored.
            self.criterion_ = nn.NLLLoss(reduction='none', ignore_index=100000)
        self.confidence = 1.0 - label_smoothing

    def _smooth_label(self, num_tokens):
        # Uniform smoothing mass spread over the num_tokens - 1 non-target bins.
        one_hot = torch.full((1, num_tokens),
                             self.label_smoothing / (num_tokens - 1))
        return one_hot

    def _bottle(self, v):
        return v.view(-1, v.size(2))

    def criterion(self, dec_outs, labels):
        scores = self.LogSoftmax(dec_outs)
        num_tokens = scores.size(-1)

        gtruth = labels.view(-1)
        if self.confidence < 1:
            # Label smoothing: build the smoothed target distribution [N, M]
            # and compare it with the log-probabilities via KL divergence.
            tdata = gtruth.detach()
            one_hot = self._smooth_label(num_tokens)  # shape [1, M]
            if labels.is_cuda:
                one_hot = one_hot.cuda()
            tmp_ = one_hot.repeat(gtruth.size(0), 1)  # [N, M]
            tmp_.scatter_(1, tdata.unsqueeze(1), self.confidence)  # indices are [N, 1]
            gtruth = tmp_.detach()
            # KLDivLoss with reduction='none' returns [N, M]; sum over the
            # class dimension to obtain a per-sample loss of shape [N].
            loss = torch.sum(self.criterion_(scores, gtruth), dim=1)
        else:
            # NLLLoss with reduction='none' already returns a per-sample loss [N].
            loss = self.criterion_(scores, gtruth)
        return loss

    def forward(self, output_x, output_y, target, target_weight):
        batch_size = output_x.size(0)
        num_joints = output_x.size(1)
        loss = 0

        for idx in range(num_joints):
            coord_x_pred = output_x[:, idx]
            coord_y_pred = output_y[:, idx]
            coord_gt = target[:, idx]
            weight = target_weight[:, idx]
            # Per-joint classification loss along x and y, weighted by joint
            # visibility and summed over the batch.
            loss += self.criterion(coord_x_pred, coord_gt[:, 0]).mul(weight).sum()
            loss += self.criterion(coord_y_pred, coord_gt[:, 1]).mul(weight).sum()
        return loss / batch_size
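

# --- Hypothetical usage sketch (not part of the original module) ---
# A minimal smoke test, assuming: JointsMSELoss consumes heatmaps of shape
# [B, K, H, W] with a [B, K, 1] visibility weight, while NMTCritierion consumes
# per-axis classification logits of shape [B, K, BINS] with integer bin targets
# stored in target[:, :, 0] (x) and target[:, :, 1] (y) and a [B, K] weight.
# All shapes and bin counts below are illustrative assumptions, not values
# taken from the original code.
if __name__ == '__main__':
    B, K, H, W, BINS = 2, 17, 64, 48, 192

    # Heatmap regression loss.
    heatmap_pred = torch.rand(B, K, H, W)
    heatmap_gt = torch.rand(B, K, H, W)
    weight_hm = torch.ones(B, K, 1)  # per-joint visibility, [B, K, 1]
    mse = JointsMSELoss(use_target_weight=True)
    print('heatmap MSE loss:', mse(heatmap_pred, heatmap_gt, weight_hm).item())

    # Coordinate classification loss with label smoothing.
    logits_x = torch.randn(B, K, BINS)
    logits_y = torch.randn(B, K, BINS)
    coord_gt = torch.randint(0, BINS, (B, K, 2))  # integer bin indices
    weight_cls = torch.ones(B, K)                 # per-joint visibility, [B, K]
    smooth_ce = NMTCritierion(label_smoothing=0.1)
    print('smoothed classification loss:',
          smooth_ce(logits_x, logits_y, coord_gt, weight_cls).item())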