from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, MSELoss
from transformers.modeling_outputs import ModelOutput
from transformers.models.deberta_v2.modeling_deberta_v2 import (
    ContextPooler, DebertaV2Model, DebertaV2PreTrainedModel, StableDropout)


@dataclass
class SequenceClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    binary_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.deberta = DebertaV2Model(config)
        self.pooler = ContextPooler(config)
        output_dim = self.pooler.output_dim

        # Two heads on top of the pooled representation: a binary classifier
        # and a score regressor, each producing a single logit per example.
        self.binary_classifier = nn.Linear(output_dim, 1)
        self.regressor = nn.Linear(output_dim, 1)

        drop_out = getattr(config, "cls_dropout", None)
        drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
        self.dropout = StableDropout(drop_out)

        self.post_init()

    def freeze_embeddings(self) -> None:
        """Freezes the embedding layer."""
        for param in self.deberta.embeddings.parameters():
            param.requires_grad = False

    def get_input_embeddings(self):
        return self.deberta.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        self.deberta.set_input_embeddings(new_embeddings)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.Tensor` of shape `(batch_size,)`, *optional*):
            Scores used as regression targets for the regression head
            (Mean-Square loss). The same labels, thresholded at 3
            (`labels >= 3`), serve as targets for the binary head
            (BCE-with-logits loss). The total loss is the sum of both.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        encoder_layer = outputs[0]
        pooled_output = self.pooler(encoder_layer)
        pooled_output = self.dropout(pooled_output)

        binary_logits = self.binary_classifier(pooled_output)
        logits = self.regressor(pooled_output)

        loss = None
        if labels is not None:
            # Regression head: predict the raw score.
            regression_loss_fct = MSELoss()
            regression_loss = regression_loss_fct(logits.squeeze(), labels.squeeze().float())

            # Binary head: texts with a score of 3 or above count as positive.
            binary_loss_fct = BCEWithLogitsLoss()
            binary_labels = (labels >= 3).float()
            classification_loss = binary_loss_fct(binary_logits.squeeze(), binary_labels.squeeze())

            loss = regression_loss + classification_loss

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            binary_logits=binary_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
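

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the model definition above). It assumes
# the "microsoft/deberta-v3-base" checkpoint and a matching AutoTokenizer;
# the example text and gold score of 4.0 are purely illustrative. Note that
# the binary_classifier and regressor heads are newly initialized by
# from_pretrained, so meaningful predictions require fine-tuning first.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "microsoft/deberta-v3-base"  # assumed base checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = DebertaV2ForSequenceClassification.from_pretrained(model_name)
    model.freeze_embeddings()  # optional: keep embedding weights fixed

    inputs = tokenizer(
        ["An example text to score."],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    labels = torch.tensor([4.0])  # illustrative regression target

    model.eval()
    with torch.no_grad():
        out = model(**inputs, labels=labels)

    print("combined loss:", out.loss.item())
    print("predicted score:", out.logits.squeeze().item())
    print("p(score >= 3):", torch.sigmoid(out.binary_logits).squeeze().item())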