import torch
import torch.nn as nn
import torch.nn.functional as F


class DistillationLoss(nn.Module):
    """
    蒸馏损失 = α * CrossEntropy(student_logits, targets) + β * KLDiv(student_logits, teacher_logits)
    """
    def __init__(self, alpha=0.7, beta=0.3, temperature=4.0, reduction='batchmean'):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.temperature = temperature
        self.reduction = reduction

    def forward(self, student_logits, teacher_logits, targets):
        # Hard-label term: standard cross-entropy against the ground truth.
        ce_loss = F.cross_entropy(student_logits, targets)
        # Soft-label term: KL divergence between the temperature-softened
        # student and teacher distributions (kl_div expects log-probs as
        # its first argument and probs as its second).
        kl_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1),
            reduction=self.reduction
        ) * (self.temperature ** 2)  # T² rescales the soft term's gradients
        return self.alpha * ce_loss + self.beta * kl_loss
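
# Illustrative usage sketch (not part of the original file): `student`,
# `teacher`, `inputs`, and `targets` are assumed placeholders for a model
# pair and a data batch.
#
#   criterion = DistillationLoss(alpha=0.7, beta=0.3, temperature=4.0)
#   with torch.no_grad():                     # keep the teacher frozen
#       teacher_logits = teacher(inputs)
#   loss = criterion(student(inputs), teacher_logits, targets)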

class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        # Per-sample cross-entropy; `weight=self.alpha` applies optional
        # per-class weighting (alpha should be a tensor of class weights).
        ce_loss = F.cross_entropy(logits, targets, reduction='none', weight=self.alpha)
        # pt is the predicted probability of the true class, since CE = -log(pt).
        pt = torch.exp(-ce_loss)
        # The (1 - pt)^gamma factor down-weights easy, well-classified examples.
        loss = ((1 - pt) ** self.gamma) * ce_loss
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
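

if __name__ == "__main__":
    # Minimal smoke test, added for illustration (not from the original
    # file): random logits and labels stand in for real model outputs.
    torch.manual_seed(0)
    batch_size, num_classes = 8, 10
    student_logits = torch.randn(batch_size, num_classes)
    teacher_logits = torch.randn(batch_size, num_classes)
    targets = torch.randint(0, num_classes, (batch_size,))

    distill = DistillationLoss(alpha=0.7, beta=0.3, temperature=4.0)
    print("distillation loss:", distill(student_logits, teacher_logits, targets).item())

    focal = FocalLoss(gamma=2.0)
    print("focal loss:", focal(student_logits, targets).item())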