# promptpix/model/loss.py
# RingL's picture
# updated gradio application
# 4258b21
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class JointsMSELoss(nn.Module):
    """Mean-squared-error loss over per-joint heatmaps.

    Predicted and ground-truth heatmaps are flattened per joint, optionally
    scaled by a per-joint weight, and compared with
    ``nn.MSELoss(reduction='mean')``.  The returned value is ``0.5 * MSE``
    averaged over joints.

    Args:
        use_target_weight: when true, both prediction and target are
            multiplied by ``target_weight`` before the MSE is taken.
    """

    def __init__(self, use_target_weight):
        super(JointsMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight):
        """Return the scalar heatmap loss.

        Args:
            output: predictions whose first two axes are (batch, joints);
                all remaining axes are flattened.
            target: ground truth with the same number of elements as
                ``output``.
            target_weight: per-joint weights; only read when
                ``use_target_weight`` is true.  Anything reshapeable to
                ``(batch, joints, -1)`` is accepted, e.g. ``(batch, joints)``
                or ``(batch, joints, 1)``.

        Returns:
            A zero-dimensional tensor holding the loss.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)
        heatmaps_pred = output.reshape(batch_size, num_joints, -1)
        heatmaps_gt = target.reshape(batch_size, num_joints, -1)

        if self.use_target_weight:
            # Explicit reshape instead of a dimension-less squeeze(): the
            # broadcast stays (batch, joints, 1) x (batch, joints, pixels)
            # even when batch size or pixel count is 1, and 2-D weights work.
            weight = target_weight.reshape(batch_size, num_joints, -1)
            heatmaps_pred = heatmaps_pred * weight
            heatmaps_gt = heatmaps_gt * weight

        # Every joint contributes the same number of elements, so one global
        # mean equals the average of the per-joint means.
        return 0.5 * self.criterion(heatmaps_pred, heatmaps_gt)
class NMTCritierion(nn.Module):
    """Classification-style loss over per-joint x/y logits.

    Logits are passed through ``LogSoftmax(dim=1)``.  With
    ``label_smoothing > 0`` the target becomes a smoothed distribution scored
    with ``KLDivLoss``; otherwise the integer labels are scored directly with
    ``NLLLoss`` (``ignore_index=100000``).

    Args:
        label_smoothing: probability mass taken from the true class and
            spread evenly over the remaining classes.
    """

    def __init__(self, label_smoothing=0.0):
        super(NMTCritierion, self).__init__()
        self.label_smoothing = label_smoothing
        self.LogSoftmax = nn.LogSoftmax(dim=1)  # input is [N, num_classes]

        if label_smoothing > 0:
            self.criterion_ = nn.KLDivLoss(reduction='none')
        else:
            self.criterion_ = nn.NLLLoss(reduction='none', ignore_index=100000)
        self.confidence = 1.0 - label_smoothing

    def _smooth_label(self, num_tokens, device=None, dtype=None):
        """Return a ``[1, num_tokens]`` row filled with the off-target mass.

        ``device`` and ``dtype`` are optional so existing one-argument calls
        keep working.
        """
        # max(..., 1) avoids a ZeroDivisionError for a single-class output.
        off_value = self.label_smoothing / max(num_tokens - 1, 1)
        return torch.full((1, num_tokens), off_value, device=device, dtype=dtype)

    def _bottle(self, v):
        """Flatten the first two axes of a 3-D tensor into one."""
        return v.view(-1, v.size(2))

    def criterion(self, dec_outs, labels):
        """Return one loss value per row of ``dec_outs``.

        Args:
            dec_outs: logits of shape ``[N, num_classes]``.
            labels: integer class indices with ``N`` elements.

        Returns:
            Tensor of shape ``[N]``.
        """
        scores = self.LogSoftmax(dec_outs)
        num_tokens = scores.size(-1)

        gtruth = labels.reshape(-1)
        if self.confidence < 1:
            # Build the smoothed target on the scores' device and dtype so it
            # works on any device, not only CPU or the default CUDA device.
            smoothed = self._smooth_label(
                num_tokens, device=scores.device, dtype=scores.dtype
            ).repeat(gtruth.size(0), 1)  # [N, num_classes]
            smoothed.scatter_(1, gtruth.detach().unsqueeze(1), self.confidence)
            gtruth = smoothed

        loss = self.criterion_(scores, gtruth)
        # KLDivLoss yields [N, num_classes] and must be summed over classes;
        # NLLLoss already yields [N], where sum(dim=1) would raise IndexError.
        if loss.dim() > 1:
            loss = loss.sum(dim=1)
        return loss

    def forward(self, output_x, output_y, target, target_weight):
        """Return the weighted x+y loss summed over joints, per batch item.

        Args:
            output_x: x-axis logits, indexed as ``[batch, joint, class]``.
            output_y: y-axis logits, indexed as ``[batch, joint, class]``.
            target: integer labels indexed as ``[batch, joint, 2]`` with x at
                index 0 and y at index 1 of the last axis.
            target_weight: per-joint weights, ``[batch, joint]`` or
                ``[batch, joint, 1]``.

        Returns:
            The summed weighted loss divided by the batch size.
        """
        batch_size = output_x.size(0)
        num_joints = output_x.size(1)
        loss = 0

        for idx in range(num_joints):
            coord_gt = target[:, idx]
            # Flatten to [batch]: a [batch, 1] weight would broadcast against
            # the [batch] loss into [batch, batch] and inflate the sum.
            weight = target_weight[:, idx].reshape(batch_size)
            loss += self.criterion(output_x[:, idx], coord_gt[:, 0]).mul(weight).sum()
            loss += self.criterion(output_y[:, idx], coord_gt[:, 1]).mul(weight).sum()
        return loss / batch_size