import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class CustomConfig(PretrainedConfig):
    model_type = "roberta"

    def __init__(
        self,
        num_classes: int = 10,
        **kwargs,
    ):
        self.num_classes = num_classes
        super().__init__(**kwargs)
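
# Note (illustrative, not part of the original file): because CustomConfig
# subclasses PretrainedConfig, it round-trips through the standard
# Hugging Face config API, e.g.
#     CustomConfig(num_classes=5).save_pretrained("some_dir")  # writes config.json
#     cfg = CustomConfig.from_pretrained("some_dir")           # restores num_classes
# ("some_dir" is just a placeholder path).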


# ====================================================
# Model
# ====================================================
class MeanPooling(PreTrainedModel):
    """Mean-pools the last hidden states over non-padded token positions."""

    def __init__(self, config):
        super().__init__(config)

    def forward(self, last_hidden_state, attention_mask):
        # broadcast the attention mask to (batch, seq_len, hidden) so that
        # padded positions contribute zero to the sum
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        )
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        # divide by the number of real (non-padded) tokens; the clamp avoids
        # division by zero for fully masked rows
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings
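
# Illustrative shape check (a hedged sketch, not part of the original module):
# MeanPooling maps (batch, seq_len, hidden) hidden states plus a
# (batch, seq_len) attention mask to a (batch, hidden) sentence embedding,
# averaging only over non-padded positions, e.g.
#     pool = MeanPooling(AutoConfig.from_pretrained("roberta-base"))
#     emb = pool(torch.randn(2, 8, 768), torch.ones(2, 8, dtype=torch.long))
#     emb.shape  # -> torch.Size([2, 768])
# ("roberta-base" is assumed here only to provide a valid config object).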


class CustomModel(PreTrainedModel):
    config_class = CustomConfig

    def __init__(
        self,
        cfg,
        num_labels=10,
        config_path=None,
        pretrained=True,
        binary_classification=False,
        **kwargs,
    ):
        # the backbone config is resolved first and passed to super().__init__ below
        self.cfg = cfg
        self.num_labels = num_labels
        if config_path is None:
            self.config = AutoConfig.from_pretrained(
                self.cfg.model_name, output_hidden_states=True
            )
        else:
            # a config object previously serialized with torch.save
            self.config = torch.load(config_path)

        super().__init__(self.config)

        if pretrained:
            self.model = AutoModel.from_pretrained(
                self.cfg.model_name, config=self.config
            )
        else:
            # AutoModel cannot be instantiated directly; build an untrained
            # backbone from the config instead
            self.model = AutoModel.from_config(self.config)

        if self.cfg.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

        self.pool = MeanPooling(config=self.config)

        self.binary_classification = binary_classification

        if self.binary_classification:
            # for binary classification we only want to output a single value
            self.fc = nn.Linear(self.config.hidden_size, 1)
        else:
            self.fc = nn.Linear(self.config.hidden_size, self.num_labels)

        self._init_weights(self.fc)

        self.sigmoid_fn = nn.Sigmoid()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def feature(self, input_ids, attention_mask, token_type_ids):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        last_hidden_states = outputs[0]
        feature = self.pool(last_hidden_states, attention_mask)
        return feature

    def forward(self, input_ids, attention_mask, token_type_ids):
        feature = self.feature(input_ids, attention_mask, token_type_ids)
        output = self.fc(feature)
        if self.binary_classification:
            # for binary classification, squash the single logit through a sigmoid
            # https://towardsdatascience.com/sigmoid-and-softmax-functions-in-5-minutes-f516c80ea1f9
            # https://towardsdatascience.com/bert-to-the-rescue-17671379687f
            output = self.sigmoid_fn(output)

        return output
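

# ----------------------------------------------------------------------
# Usage sketch (illustrative only). The real training `cfg` object is not
# defined in this file, so the SimpleNamespace below -- with `model_name`
# and `gradient_checkpointing` -- is an assumption about its interface,
# and "roberta-base" is just a stand-in backbone. Running this requires
# the backbone weights to be downloadable or already cached.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    cfg = SimpleNamespace(model_name="roberta-base", gradient_checkpointing=False)
    model = CustomModel(cfg, num_labels=10, pretrained=True)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
    batch = tokenizer(
        ["a short example sentence", "another one"],
        padding=True,
        return_tensors="pt",
    )
    # RoBERTa tokenizers may not return token_type_ids; fall back to zeros
    token_type_ids = batch.get(
        "token_type_ids", torch.zeros_like(batch["input_ids"])
    )

    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"], token_type_ids)
    print(logits.shape)  # expected: torch.Size([2, 10])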