import torch
import torch.nn as nn
import torch.nn.functional as F


class DistillationLoss(nn.Module):
    """
    Distillation loss = alpha * CrossEntropy(student_logits, targets)
                      + beta * KLDiv(student_logits, teacher_logits)
    """
    def __init__(self, alpha=0.7, beta=0.3, temperature=4.0, reduction='batchmean'):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.temperature = temperature
        self.reduction = reduction

    def forward(self, student_logits, teacher_logits, targets):
        # Hard-label loss against the ground-truth targets.
        ce_loss = F.cross_entropy(student_logits, targets)
        # Soft-label loss: KL divergence between the temperature-softened
        # student and teacher distributions, scaled by T^2 so the gradient
        # magnitude stays comparable across temperatures.
        kl_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1),
            reduction=self.reduction
        ) * (self.temperature ** 2)
        return self.alpha * ce_loss + self.beta * kl_loss


class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha          # optional per-class weight tensor, or None
        self.gamma = gamma          # focusing parameter; gamma=0 recovers plain CE
        self.reduction = reduction

    def forward(self, logits, targets):
        # Per-sample cross entropy; pt approximates the model's probability
        # of the true class, so (1 - pt)^gamma down-weights easy examples.
        ce_loss = F.cross_entropy(logits, targets, reduction='none', weight=self.alpha)
        pt = torch.exp(-ce_loss)
        loss = ((1 - pt) ** self.gamma) * ce_loss
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
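

# Minimal usage sketch for both losses. Assumptions not in the original:
# a batch of 8 samples over 10 classes, with random tensors standing in
# for real student/teacher model outputs.
if __name__ == '__main__':
    student_logits = torch.randn(8, 10)
    teacher_logits = torch.randn(8, 10)
    targets = torch.randint(0, 10, (8,))

    distill_loss = DistillationLoss(alpha=0.7, beta=0.3, temperature=4.0)
    print('distillation loss:', distill_loss(student_logits, teacher_logits, targets).item())

    focal_loss = FocalLoss(gamma=2.0)
    print('focal loss:', focal_loss(student_logits, targets).item())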