File size: 1,868 Bytes
bb8ce2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9313e67
 
a9c0c7a
bb8ce2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from transformers import BertTokenizer, BertModel
from transformers import PretrainedConfig, PreTrainedModel
import torch
import torch.nn as nn


class TypeBERTConfig(PretrainedConfig):
    """Configuration for the six-class TypeBERT sequence classifier.

    The label set is fixed: ``agent``, ``event``, ``place``, ``item``,
    ``virtual`` and ``concept`` are mapped to ids 0-5 in that order.
    """

    model_type = "type_bert"

    def __init__(self, **kwargs):
        """Build the configuration.

        Args:
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
                ``id2label``, ``label2id``, ``architectures`` and
                ``tokenizer_class`` are assigned after that call, so any
                values supplied for them here are overwritten.
        """
        super().__init__(**kwargs)

        # Single source of truth for the label order; both maps are derived
        # from it so they cannot drift apart.
        labels = ("agent", "event", "place", "item", "virtual", "concept")
        self.id2label = dict(enumerate(labels))
        self.label2id = {label: index for index, label in self.id2label.items()}

        self.architectures = ['TypeBERTForSequenceClassification']
        # ``tokenizer_class`` holds the *name of a tokenizer class*, not a
        # checkpoint id; "bert-base-uncased" is a checkpoint and cannot be
        # resolved to a class.
        self.tokenizer_class = 'BertTokenizer'



class TypeBERTForSequenceClassification(PreTrainedModel):
    """BERT encoder with masked mean pooling and a feed-forward classifier.

    ``forward`` averages the encoder's last hidden states over the positions
    selected by the attention mask, applies ``tanh`` and feeds the result to
    a four-layer MLP that ends in ``LogSoftmax`` over six classes.
    """

    config_class = TypeBERTConfig

    def __init__(self, config):
        """Create the encoder and the classification head.

        Args:
            config: Configuration object handed to ``PreTrainedModel``.
        """
        super().__init__(config)
        # Loads the pretrained "bert-base-uncased" checkpoint each time the
        # model is constructed. The encoder parameters are left trainable;
        # nothing is frozen here.
        self.bert = BertModel.from_pretrained("bert-base-uncased")

        self.tanh = nn.Tanh()

        # Input width follows the encoder (768 for bert-base-uncased).
        hidden_size = self.bert.config.hidden_size
        self.dff = nn.Sequential(
            nn.Linear(hidden_size, 2048),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 6),  # one output per label id 0-5
            nn.LogSoftmax(dim=1)
        )

        # The model starts in inference mode (dropout disabled).
        self.eval()

    def forward(self, **kwargs):
        """Classify a batch of tokenized sequences.

        Args:
            **kwargs: Passed unchanged to the BERT encoder. ``attention_mask``
                (expected shape ``(batch, seq_len)``, 1 for real tokens and 0
                for padding) is also used for pooling; when it is absent or
                ``None`` every position is treated as a real token.

        Returns:
            dict: ``{'logits': tensor}`` where the tensor has one row per
            sequence. Because the head ends in ``LogSoftmax``, the values are
            log-probabilities rather than raw scores.
        """
        attention_mask = kwargs.get('attention_mask')
        embs = self.bert(**kwargs)['last_hidden_state']

        if attention_mask is None:
            # Same default the encoder uses: attend to every position.
            attention_mask = embs.new_ones(embs.shape[:2])

        # Broadcast the mask over the hidden dimension in the embeddings'
        # dtype. The product is out-of-place so the encoder output is not
        # mutated (in-place edits can invalidate autograd state).
        mask = attention_mask.unsqueeze(-1).to(embs.dtype)
        summed = (embs * mask).sum(dim=1)

        # Clamp the token count so a fully padded row yields zeros instead
        # of NaN from a division by zero.
        token_counts = mask.sum(dim=1).clamp(min=1)
        pooled = summed / token_counts

        return {'logits': self.dff(self.tanh(pooled))}