from transformers import BertTokenizer, BertModel
from transformers import PretrainedConfig, PreTrainedModel
import torch
import torch.nn as nn
class TypeBERTConfig(PretrainedConfig):
    model_type = "type_bert"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Label mappings for the six classification targets.
        self.id2label = {
            0: "agent",
            1: "event",
            2: "place",
            3: "item",
            4: "virtual",
            5: "concept"
        }
        self.label2id = {
            "agent": 0,
            "event": 1,
            "place": 2,
            "item": 3,
            "virtual": 4,
            "concept": 5
        }
        self.architectures = ['TypeBERTForSequenceClassification']
        # PretrainedConfig.tokenizer_class expects a tokenizer class name,
        # not a checkpoint name, so point it at BertTokenizer.
        self.tokenizer_class = 'BertTokenizer'
class TypeBERTForSequenceClassification(PreTrainedModel):
    config_class = TypeBERTConfig

    def __init__(self, config):
        super().__init__(config)
        # Pretrained BERT encoder, fine-tuned together with the classification head.
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        # Uncomment to freeze the encoder and train only the classification head:
        # for param in self.bert.base_model.parameters():
        #     param.requires_grad = False
        # self.bert.eval()
        self.tanh = nn.Tanh()
        # Feed-forward classification head: 768-d pooled BERT output -> 6 class log-probabilities.
        self.dff = nn.Sequential(
            nn.Linear(768, 2048),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 6),
            nn.LogSoftmax(dim=1)
        )
        # Start in evaluation mode (dropout disabled); call .train() before fine-tuning.
        self.eval()
    def forward(self, **kwargs):
        attention_mask = kwargs['attention_mask']
        # Token-level embeddings from BERT: (batch, seq_len, 768).
        embs = self.bert(**kwargs)['last_hidden_state']
        # Mean-pool over real tokens only: zero out padded positions,
        # then divide by the number of unmasked tokens per sequence.
        embs = embs * attention_mask.unsqueeze(2)
        out = embs.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)
        return {'logits': self.dff(self.tanh(out))}
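

# A minimal usage sketch (not part of the original listing): instantiate the config and
# model, tokenize a sentence with the bert-base-uncased tokenizer, and read off the
# predicted label. The example sentence and variable names are illustrative only.
if __name__ == "__main__":
    config = TypeBERTConfig()
    model = TypeBERTForSequenceClassification(config)

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    inputs = tokenizer("The museum opened a new exhibition hall.", return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)  # {'logits': log-probabilities of shape (1, 6)}

    predicted_id = outputs['logits'].argmax(dim=1).item()
    print(config.id2label[predicted_id])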