from transformers import BertTokenizer, BertModel
from transformers import PretrainedConfig, PreTrainedModel
import torch
import torch.nn as nn


class TypeBERTConfig(PretrainedConfig):
    model_type = "type_bert"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Six coarse entity types predicted by the classifier.
        self.id2label = {
            0: "agent",
            1: "event",
            2: "place",
            3: "item",
            4: "virtual",
            5: "concept"
        }
        self.label2id = {
            "agent": 0,
            "event": 1,
            "place": 2,
            "item": 3,
            "virtual": 4,
            "concept": 5
        }
        self.architectures = ['TypeBERTForSequenceClassification']
        self.tokenizer_class = 'bert-base-uncased'


class TypeBERTForSequenceClassification(PreTrainedModel):
    config_class = TypeBERTConfig

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        # Uncomment to freeze the BERT encoder and train only the classification head:
        # for param in self.bert.base_model.parameters():
        #     param.requires_grad = False
        #
        # self.bert.eval()
        self.tanh = nn.Tanh()
        # Feed-forward head: 768-d pooled embedding -> 6 log-probabilities.
        self.dff = nn.Sequential(
            nn.Linear(768, 2048),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 6),
            nn.LogSoftmax(dim=1)
        )
        self.eval()

    def forward(self, **kwargs):
        # Mean-pool token embeddings, using the attention mask to ignore padding.
        a = kwargs['attention_mask']
        embs = self.bert(**kwargs)['last_hidden_state']
        embs = embs * a.unsqueeze(2)  # zero out padded positions (avoid in-place op on the BERT output)
        out = embs.sum(dim=1) / a.sum(dim=1, keepdim=True)  # average over real tokens only
        return {'logits': self.dff(self.tanh(out))}
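

# --- Minimal usage sketch (illustrative assumption, not part of the original code) ---
# Shows how the config, tokenizer, and model above could be wired together for a single
# prediction. The example sentence is made up, and a freshly constructed model has an
# untrained classification head, so the predicted label is only meaningful after training.
if __name__ == "__main__":
    config = TypeBERTConfig()
    model = TypeBERTForSequenceClassification(config)

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    inputs = tokenizer(
        ["Barack Obama visited Paris."],  # hypothetical example input
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    with torch.no_grad():
        log_probs = model(**inputs)['logits']   # shape (batch, 6), log-softmax outputs
    pred_id = log_probs.argmax(dim=1).item()    # index of the most likely type
    print(config.id2label[pred_id])             # e.g. "agent" once the head is trained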