import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig
from transformers.modeling_utils import PreTrainedModel


class CustomConfig(PretrainedConfig):
    """Configuration wrapper so the model can be saved/loaded via the Hugging Face API."""

    model_type = "roberta"

    def __init__(self, num_classes: int = 10, **kwargs):
        self.num_classes = num_classes
        super().__init__(**kwargs)


# ====================================================
# Model
# ====================================================
class MeanPooling(PreTrainedModel):
    """Masked mean pooling over the token dimension of the last hidden state."""

    def __init__(self, config):
        super().__init__(config)

    def forward(self, last_hidden_state, attention_mask):
        # Broadcast the attention mask to the hidden size so padding tokens
        # contribute neither to the sum nor to the token count.
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        )
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)  # avoid division by zero
        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings


class CustomModel(PreTrainedModel):
    """Transformer backbone + mean pooling + linear classification head."""

    config_class = CustomConfig

    def __init__(
        self,
        cfg,
        num_labels=10,
        config_path=None,
        pretrained=True,
        binary_classification=False,
        **kwargs,
    ):
        self.cfg = cfg
        self.num_labels = num_labels
        if config_path is None:
            self.config = AutoConfig.from_pretrained(
                self.cfg.model_name, output_hidden_states=True
            )
        else:
            # config_path points at a config object previously saved with torch.save
            self.config = torch.load(config_path)
        super().__init__(self.config)

        if pretrained:
            self.model = AutoModel.from_pretrained(
                self.cfg.model_name, config=self.config
            )
        else:
            # build the architecture from the config without loading pretrained weights
            self.model = AutoModel.from_config(self.config)
        if self.cfg.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        self.pool = MeanPooling(config=self.config)
        self.binary_classification = binary_classification
        if self.binary_classification:
            # For binary classification we only want a single output;
            # with num_labels=2 this gives an output dimension of 1.
            self.fc = nn.Linear(self.config.hidden_size, self.num_labels - 1)
        else:
            self.fc = nn.Linear(self.config.hidden_size, self.num_labels)
        self._init_weights(self.fc)
        self.sigmoid_fn = nn.Sigmoid()

    def _init_weights(self, module):
        """Initialize the classification head the same way the backbone is initialized."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, input_ids, attention_mask, token_type_ids):
        """Run the backbone and pool the last hidden state into a sentence embedding."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        last_hidden_states = outputs[0]
        feature = self.pool(last_hidden_states, attention_mask)
        return feature

    def forward(self, input_ids, attention_mask, token_type_ids):
        feature = self.feature(input_ids, attention_mask, token_type_ids)
        output = self.fc(feature)
        if self.binary_classification:
            # For binary classification, apply a sigmoid so the single logit
            # becomes a probability.
            # https://towardsdatascience.com/sigmoid-and-softmax-functions-in-5-minutes-f516c80ea1f9
            # https://towardsdatascience.com/bert-to-the-rescue-17671379687f
            output = self.sigmoid_fn(output)
        return output
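

# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the class definitions
# above). The cfg object, "roberta-base", and the sample sentences are
# assumptions made for this example; swap in your own cfg and tokenizer.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace
    from transformers import AutoTokenizer

    # Hypothetical cfg: only the two attributes CustomModel reads are needed.
    cfg = SimpleNamespace(model_name="roberta-base", gradient_checkpointing=False)

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
    model = CustomModel(
        cfg, num_labels=2, pretrained=True, binary_classification=True
    )
    model.eval()

    batch = tokenizer(
        ["a short example sentence", "another example"],
        padding=True,
        truncation=True,
        return_tensors="pt",
        return_token_type_ids=True,  # forward() expects token_type_ids explicitly
    )

    with torch.no_grad():
        probs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            token_type_ids=batch["token_type_ids"],
        )
    print(probs.shape)  # (batch_size, 1) sigmoid probabilities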