File size: 922 Bytes
668bf5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from transformers import Trainer
import torch
import torch.nn as nn

"""



cross entropy loss는 [B,C] , [B] shape의 logits와 Labels를 받는다.



"""

class MyTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):  # ← **kwargs 추가
        # 1. inputs에서 라벨 추출
        labels = inputs.pop("labels")

        # 2. 모델 forward
        outputs = model(**inputs)
        logits = outputs.logits  # 예: language modeling일 경우

        # 3. Loss 함수 정의 및 계산
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100)  # 일반적으로 -100은 패딩 토큰 무시
        loss = loss_fct(
            logits.view(-1, logits.size(-1)),  # [batch*seq_len, vocab]
            labels.view(-1)                    # [batch*seq_len]
        )

        # 4. loss 반환
        return (loss, outputs) if return_outputs else loss