import torch
import torch.nn as nn
import torch.nn.functional as F


class DistillationLoss(nn.Module):
    """
    Distillation loss = alpha * CrossEntropy(student_logits, targets)
                      + beta * KLDiv(student_logits, teacher_logits)
    """
    def __init__(self, alpha=0.7, beta=0.3, temperature=4.0, reduction='batchmean'):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.temperature = temperature
        self.reduction = reduction

    def forward(self, student_logits, teacher_logits, targets):
        # Hard-label term: standard cross-entropy against the ground-truth labels.
        ce_loss = F.cross_entropy(student_logits, targets)
        # Soft-label term: KL divergence between temperature-softened distributions,
        # scaled by T^2 to keep gradient magnitudes comparable across temperatures.
        kl_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1),
            reduction=self.reduction
        ) * (self.temperature ** 2)
        return self.alpha * ce_loss + self.beta * kl_loss
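
A minimal usage sketch for the distillation loss, assuming a frozen teacher_model, a student_model, and a batch (inputs, targets); these names are illustrative and not part of the original listing:

# Usage sketch (teacher_model / student_model / inputs / targets are assumed to exist).
distill_criterion = DistillationLoss(alpha=0.7, beta=0.3, temperature=4.0)
with torch.no_grad():
    teacher_logits = teacher_model(inputs)   # teacher is not updated
student_logits = student_model(inputs)
loss = distill_criterion(student_logits, teacher_logits, targets)
loss.backward()
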
class FocalLoss(nn.Module):
    """
    Focal loss: down-weights well-classified examples by (1 - pt)^gamma so training
    focuses on hard, misclassified samples. `alpha` is an optional per-class weight
    tensor passed through to cross_entropy.
    """
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        # Per-sample cross-entropy (no reduction) so each term can be re-weighted.
        ce_loss = F.cross_entropy(logits, targets, reduction='none', weight=self.alpha)
        # pt is the predicted probability of the true class (exact when alpha is None).
        pt = torch.exp(-ce_loss)
        loss = ((1 - pt) ** self.gamma) * ce_loss
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
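
A small sketch of how the focal loss might be called, assuming a 5-class problem; the class-weight values and batch size are placeholders, not values from the original listing:

# Usage sketch (5 classes assumed; weights and batch are placeholder values).
class_weights = torch.tensor([1.0, 2.0, 1.0, 0.5, 1.5])
focal_criterion = FocalLoss(alpha=class_weights, gamma=2.0)
logits = torch.randn(8, 5)               # e.g. a batch of 8 predictions
targets = torch.randint(0, 5, (8,))      # ground-truth class indices
loss = focal_criterion(logits, targets)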