# from transformers.models.led.modeling_led import LEDEncoder
import torch.nn as nn
from transformers import LEDConfig, LEDModel, LEDPreTrainedModel


# Subclasses LEDPreTrainedModel (rather than nn.Module) so that
# save_pretrained/from_pretrained and the Trainer work with this model.
class CustomLEDForQAModel(LEDPreTrainedModel):
    """LED encoder with a span-extraction (start/end) head for extractive QA."""

    config_class = LEDConfig

    def __init__(self, config: LEDConfig, checkpoint):
        super().__init__(config)
        # Two labels: one logit per token for the answer start, one for the end.
        config.num_labels = 2
        self.num_labels = config.num_labels
        # Only the encoder is needed; the LED decoder is discarded.
        if checkpoint:
            self.led = LEDModel.from_pretrained(checkpoint, config=config).get_encoder()
        else:
            self.led = LEDModel(config).get_encoder()
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        global_attention_mask=None,
        start_positions=None,
        end_positions=None,
    ):
        outputs = self.led(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        # Project every token's hidden state to a start logit and an end logit.
        logits = self.qa_outputs(outputs.last_hidden_state)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            loss_fct = nn.CrossEntropyLoss()
            # If the positions arrive with an extra trailing dimension, flatten them.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            # start_loss = loss_fct(start_logits[index], start_positions[index][0])
            # end_loss = loss_fct(end_logits[index], end_positions[index][0])
            total_loss = (start_loss + end_loss) / 2

        return {
            'loss': total_loss,
            'start_logits': start_logits,
            'end_logits': end_logits,
        }
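

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). Assumes the public
# "allenai/led-base-16384" checkpoint and its fast tokenizer; the question and
# context strings are placeholders. Swap in whatever checkpoint and data the
# surrounding project actually uses.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from transformers import LEDTokenizerFast

    checkpoint = "allenai/led-base-16384"
    tokenizer = LEDTokenizerFast.from_pretrained(checkpoint)
    config = LEDConfig.from_pretrained(checkpoint)
    model = CustomLEDForQAModel(config, checkpoint)
    model.eval()

    question = "What does the QA head predict?"
    context = "The linear head maps each encoder hidden state to start and end logits."
    inputs = tokenizer(question, context, return_tensors="pt")

    # LED uses sparse attention; give at least the first token global attention
    # (the question tokens are also a common choice for QA).
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    global_attention_mask[:, 0] = 1

    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            global_attention_mask=global_attention_mask,
        )

    # The head is untrained here, so the decoded span is only illustrative.
    start_idx = int(outputs["start_logits"].argmax(dim=-1))
    end_idx = int(outputs["end_logits"].argmax(dim=-1))
    print(tokenizer.decode(inputs["input_ids"][0, start_idx : end_idx + 1]))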