# vfc-type / custom_models.py
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class CustomConfig(PretrainedConfig):
    """Configuration for CustomModel; stores the number of output classes."""

    model_type = "roberta"

    def __init__(
        self,
        num_classes: int = 10,
        **kwargs,
    ):
        self.num_classes = num_classes
        super().__init__(**kwargs)


# ====================================================
# Model
# ====================================================
class MeanPooling(PreTrainedModel):
    """Masked mean pooling over the token embeddings of a sequence."""

    def __init__(
        self,
        config,
    ):
        super(MeanPooling, self).__init__(config)

    def forward(self, last_hidden_state, attention_mask):
        # Expand the attention mask so padded positions contribute nothing.
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        )
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        # Clamp to avoid division by zero for fully padded sequences.
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings


class CustomModel(PreTrainedModel):
    """Transformer encoder with mean pooling and a linear classification head."""

    config_class = CustomConfig

    def __init__(
        self,
        cfg,
        num_labels=10,
        config_path=None,
        pretrained=True,
        binary_classification=False,
        **kwargs,
    ):
        self.cfg = cfg
        self.num_labels = num_labels
        # Load the backbone config either from the hub or from a saved file.
        if config_path is None:
            self.config = AutoConfig.from_pretrained(
                self.cfg.model_name, output_hidden_states=True
            )
        else:
            self.config = torch.load(config_path)
        super().__init__(self.config)
        # Backbone: pretrained weights or a randomly initialized encoder.
        if pretrained:
            self.model = AutoModel.from_pretrained(
                self.cfg.model_name, config=self.config
            )
        else:
            self.model = AutoModel.from_config(self.config)
        if self.cfg.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()
        self.pool = MeanPooling(config=self.config)
        self.binary_classification = binary_classification
        if self.binary_classification:
            # For binary classification we only want to output a single logit.
            self.fc = nn.Linear(self.config.hidden_size, self.num_labels - 1)
        else:
            self.fc = nn.Linear(self.config.hidden_size, self.num_labels)
        self._init_weights(self.fc)
        self.sigmoid_fn = nn.Sigmoid()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, input_ids, attention_mask, token_type_ids):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        last_hidden_states = outputs[0]
        # Mean-pool the final hidden states over the non-padded tokens.
        feature = self.pool(last_hidden_states, attention_mask)
        return feature

    def forward(self, input_ids, attention_mask, token_type_ids):
        feature = self.feature(input_ids, attention_mask, token_type_ids)
        output = self.fc(feature)
        if self.binary_classification:
            # For binary classification, apply a sigmoid to the single logit.
            # https://towardsdatascience.com/sigmoid-and-softmax-functions-in-5-minutes-f516c80ea1f9
            # https://towardsdatascience.com/bert-to-the-rescue-17671379687f
            output = self.sigmoid_fn(output)
        return output
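

# --------------------------------------------------------------------
# Usage sketch (not part of the original training code): a minimal,
# hedged example of how these classes might be exercised. The `cfg`
# object below (with `model_name` and `gradient_checkpointing`) and the
# "roberta-base" checkpoint are illustrative assumptions only; the real
# training pipeline supplies its own cfg and checkpoint.
# --------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    # Hypothetical cfg exposing the attributes CustomModel reads
    # (`model_name`, `gradient_checkpointing`). Downloads roberta-base.
    cfg = SimpleNamespace(model_name="roberta-base", gradient_checkpointing=False)
    model = CustomModel(cfg, num_labels=2, binary_classification=True)
    model.eval()

    # Sanity-check MeanPooling against a manually computed masked mean.
    hidden = torch.randn(2, 4, model.config.hidden_size)
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    pooled = model.pool(hidden, mask)
    manual = (hidden * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True)
    print(torch.allclose(pooled, manual, atol=1e-6))  # True

    # Forward pass: with binary_classification=True the model returns a
    # single sigmoid probability per example.
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
    batch = tokenizer(
        ["fix buffer overflow in the parser"],
        return_tensors="pt",
        return_token_type_ids=True,
    )
    with torch.no_grad():
        probs = model(
            batch["input_ids"], batch["attention_mask"], batch["token_type_ids"]
        )
    print(probs.shape)  # torch.Size([1, 1])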