Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmpose.registry import MODELS | |
class KDLoss(nn.Module):
    """PyTorch version of logit-based distillation from DWPose.

    Modified from the official implementation.

    <https://github.com/IDEA-Research/DWPose>

    The loss is the KL divergence between the teacher's softmax distribution
    and the student's log-softmax distribution, both computed over the last
    (flattened) dimension after the logits are multiplied by ``beta``.

    Args:
        name: Not used by this class (presumably consumed by the surrounding
            distillation config -- TODO confirm against the caller).
        use_this: Not used by this class (see ``name``).
        weight (float, optional): Weight of dis_loss. Defaults to 1.0
    """

    def __init__(
        self,
        name,
        use_this,
        weight=1.0,
    ):
        super().__init__()
        # Softmax is taken over dim=1, i.e. over the bins once the logits
        # have been flattened to 2-D inside ``loss``.
        self.log_softmax = nn.LogSoftmax(dim=1)
        # reduction='none' keeps the element-wise terms so that ``loss`` can
        # sum over bins and joints itself before averaging over the batch.
        self.kl_loss = nn.KLDivLoss(reduction='none')
        self.weight = weight

    def forward(self, pred, pred_t, beta, target_weight):
        """Compute the distillation loss for a pair of logit tensors.

        Args:
            pred (tuple[Tensor, Tensor]): Student logits, unpacked as
                ``(ls_x, ls_y)`` (presumably the x- and y-axis
                classification logits -- TODO confirm). Each tensor is
                either ``(N, K, W)`` or ``(N, W)``.
            pred_t (tuple[Tensor, Tensor]): Teacher logits with the same
                layout as ``pred``. They are detached, so no gradient flows
                into the teacher.
            beta (float | Tensor): Multiplier applied to the logits before
                the softmax.
            target_weight: Forwarded to :meth:`loss`, which currently
                ignores it.

        Returns:
            Tensor: Scalar loss, i.e. the sum of both per-axis losses divided
            by the number of joints ``K`` (``1`` for 2-D logits).
        """
        student_x, student_y = pred
        teacher_x, teacher_y = pred_t
        teacher_x = teacher_x.detach()
        teacher_y = teacher_y.detach()

        # For 2-D logits there is no joint dimension; size(1) would be the
        # number of bins, so normalise by a single "joint" instead.
        num_joints = student_x.size(1) if student_x.dim() == 3 else 1

        total = self.loss(student_x, teacher_x, beta, target_weight)
        total = total + self.loss(student_y, teacher_y, beta, target_weight)
        return total / num_joints

    def loss(self, logit_s, logit_t, beta, weight):
        """KL-divergence distillation loss for one set of logits.

        Args:
            logit_s (Tensor): Student logits of shape ``(N, K, W)`` or
                ``(N, W)``.
            logit_t (Tensor): Teacher logits with the same shape as
                ``logit_s``.
            beta (float | Tensor): Multiplier applied to the logits before
                the softmax.
            weight: Unused; kept so the signature stays compatible with
                existing callers.

        Returns:
            Tensor: Scalar loss. The KL terms are summed over bins and
            joints, averaged over the batch and scaled by ``self.weight``.

        Raises:
            ValueError: If ``logit_s`` is neither 2-D nor 3-D.
        """
        N = logit_s.shape[0]
        ndim = logit_s.dim()

        if ndim == 3:
            K = logit_s.shape[1]
            # Fold joints into the batch so that softmax runs over the bins.
            logit_s = logit_s.reshape(N * K, -1)
            logit_t = logit_t.reshape(N * K, -1)
        elif ndim == 2:
            # Already (N, W): treat it as a single joint. Previously ``K`` was
            # left undefined on this path and the final reshape crashed.
            K = 1
        else:
            raise ValueError(
                'KDLoss expects logits of shape (N, K, W) or (N, W), '
                f'but got a tensor with {ndim} dimensions.')

        # N*W(H)
        s_i = self.log_softmax(logit_s * beta)
        t_i = F.softmax(logit_t * beta, dim=1)

        # kd
        loss_all = torch.sum(self.kl_loss(s_i, t_i), dim=1)
        loss_all = loss_all.reshape(N, K).sum(dim=1).mean()
        loss_all = self.weight * loss_all
        return loss_all