import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModel, PretrainedConfig
from transformers.modeling_utils import PreTrainedModel


class CustomConfig(PretrainedConfig):
    """Configuration for CustomModel: a roberta-type config extended with a class count."""

    model_type = "roberta"

    def __init__(
        self,
        num_classes: int = 10,
        **kwargs,
    ):
        self.num_classes = num_classes
        super().__init__(**kwargs)


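# Usage sketch (the directory name below is hypothetical): like any
# PretrainedConfig subclass, CustomConfig can be created, saved, and reloaded.
#
#     config = CustomConfig(num_classes=5)
#     config.save_pretrained("./custom_config")      # writes config.json
#     config = CustomConfig.from_pretrained("./custom_config")

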
class MeanPooling(nn.Module):
    """Masked mean pooling over the token dimension of the last hidden state."""

    def __init__(self):
        super().__init__()

    def forward(self, last_hidden_state, attention_mask):
        # Broadcast the attention mask over the hidden dimension so that
        # padded positions contribute nothing to the sum.
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        )
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)  # guard against division by zero
        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings


class CustomModel(PreTrainedModel):
    """Transformer backbone with masked mean pooling and a linear classification head."""

    config_class = CustomConfig

    def __init__(
        self,
        cfg,
        num_labels=10,
        config_path=None,
        pretrained=True,
        binary_classification=False,
        **kwargs,
    ):
        # Resolve the backbone config either from the hub or from a
        # torch-serialized config file on disk.
        if config_path is None:
            config = AutoConfig.from_pretrained(
                cfg.model_name, output_hidden_states=True
            )
        else:
            config = torch.load(config_path)

        super().__init__(config)

        self.cfg = cfg
        self.num_labels = num_labels

        # Load pretrained weights, or build the backbone from the config alone.
        if pretrained:
            self.model = AutoModel.from_pretrained(
                self.cfg.model_name, config=self.config
            )
        else:
            self.model = AutoModel.from_config(self.config)

        if self.cfg.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        self.pool = MeanPooling()

        self.binary_classification = binary_classification

        # The binary head uses one fewer output unit; its logits are passed
        # through a sigmoid in forward().
        if self.binary_classification:
            self.fc = nn.Linear(self.config.hidden_size, self.num_labels - 1)
        else:
            self.fc = nn.Linear(self.config.hidden_size, self.num_labels)

        self._init_weights(self.fc)

        self.sigmoid_fn = nn.Sigmoid()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, input_ids, attention_mask, token_type_ids):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        last_hidden_states = outputs[0]
        feature = self.pool(last_hidden_states, attention_mask)
        return feature

    def forward(self, input_ids, attention_mask, token_type_ids):
        feature = self.feature(input_ids, attention_mask, token_type_ids)
        output = self.fc(feature)
        if self.binary_classification:
            output = self.sigmoid_fn(output)
        return output
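
# Minimal smoke test, kept as a sketch: it assumes only torch is installed and
# exercises MeanPooling on random tensors. Building CustomModel additionally
# needs a small cfg object (a hypothetical SimpleNamespace with `model_name`
# and `gradient_checkpointing`) plus network access to download the backbone,
# so that part is left commented out.
if __name__ == "__main__":
    pool = MeanPooling()
    hidden = torch.randn(2, 6, 8)  # (batch, seq_len, hidden_size)
    mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                         [1, 1, 1, 1, 1, 1]])
    print(pool(hidden, mask).shape)  # torch.Size([2, 8])

    # from types import SimpleNamespace
    # cfg = SimpleNamespace(model_name="roberta-base", gradient_checkpointing=False)
    # model = CustomModel(cfg, num_labels=10, pretrained=True)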