# aicomp_demo/utils/losses.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    """Knowledge-distillation loss.

    loss = alpha * CrossEntropy(student_logits, targets)
         + beta * T^2 * KLDiv(log_softmax(student_logits / T), softmax(teacher_logits / T))
    """

    def __init__(self, alpha=0.7, beta=0.3, temperature=4.0, reduction='batchmean'):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.temperature = temperature
        self.reduction = reduction

    def forward(self, student_logits, teacher_logits, targets):
        # Hard-label term: standard cross-entropy against the ground-truth labels.
        ce_loss = F.cross_entropy(student_logits, targets)
        # Soft-label term: KL divergence between temperature-scaled student and
        # teacher distributions, rescaled by T^2 so the gradient magnitude stays
        # comparable as the temperature changes.
        kl_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1),
            reduction=self.reduction,
        ) * (self.temperature ** 2)
        return self.alpha * ce_loss + self.beta * kl_loss
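
# Minimal usage sketch (not part of the original module): how DistillationLoss is
# typically wired into a training step. The batch size, class count, and random
# logits below are hypothetical placeholders; a real pipeline would take the
# teacher logits from a frozen teacher model under torch.no_grad().
def _demo_distillation_loss():
    batch_size, num_classes = 8, 5
    student_logits = torch.randn(batch_size, num_classes, requires_grad=True)
    teacher_logits = torch.randn(batch_size, num_classes)
    targets = torch.randint(0, num_classes, (batch_size,))

    criterion = DistillationLoss(alpha=0.7, beta=0.3, temperature=4.0)
    loss = criterion(student_logits, teacher_logits, targets)
    loss.backward()  # gradients flow through the student logits only
    return loss.item()
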
class FocalLoss(nn.Module):
    """Focal loss for class-imbalanced classification.

    loss = (1 - p_t) ** gamma * CrossEntropy(logits, targets)

    `alpha` is an optional per-class weight tensor forwarded to
    `F.cross_entropy`; `gamma` down-weights well-classified examples.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        # Per-sample cross-entropy (optionally class-weighted via `alpha`).
        ce_loss = F.cross_entropy(logits, targets, reduction='none', weight=self.alpha)
        # p_t is the predicted probability of the true class.
        pt = torch.exp(-ce_loss)
        # Modulating factor suppresses easy examples (high p_t).
        loss = ((1 - pt) ** self.gamma) * ce_loss
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
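
# Minimal usage sketch (not part of the original module): FocalLoss on an
# imbalanced classification batch. The per-class weights below are hypothetical;
# in practice they are often derived from inverse class frequencies.
def _demo_focal_loss():
    batch_size, num_classes = 8, 5
    logits = torch.randn(batch_size, num_classes, requires_grad=True)
    targets = torch.randint(0, num_classes, (batch_size,))

    # Optional per-class weights, e.g. to up-weight rare classes.
    class_weights = torch.tensor([1.0, 1.0, 1.0, 2.0, 4.0])
    criterion = FocalLoss(alpha=class_weights, gamma=2.0, reduction='mean')
    loss = criterion(logits, targets)
    loss.backward()
    return loss.item()


if __name__ == "__main__":
    print("distillation loss:", _demo_distillation_loss())
    print("focal loss:", _demo_focal_loss())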