from typing import Any, Optional, Tuple, Union
import torch
import transformers


class DistilBertTransferLearningModel(torch.nn.Module):
    """DistilBERT backbone with a configurable classification head on top.

    Each entry in `layers` is either a layer name or a (name, args) tuple;
    for a 'linear' layer the args give the in/out sizes, with the
    placeholders 'in' and 'out' resolving to the backbone hidden size and
    dim_out respectively.
    """

    def __init__(
        self,
        pretrained_model: str = "distilbert-base-uncased",
        layers: list[Union[str, Tuple[str, Optional[list[Any]]]]] = [
            ('linear', ['in', 'out']),
            'softmax',
        ],
        dim_out: int = 2,
        use_local_file: bool = False,
        device: str = 'cpu',
        state_dict: Optional[Union[str, dict]] = None,
    ):
        super().__init__()
        # Tokenizer and pretrained DistilBERT backbone from the Hugging Face
        # hub (or from the local cache when use_local_file is True).
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained_model, local_files_only=use_local_file
        )
        self.base_model = transformers.AutoModel.from_pretrained(
            pretrained_model, local_files_only=use_local_file
        )
        # Build the classification head from the layer specification.
        clf_layers = []
        for layer in layers:
            layer_type = layer[0] if isinstance(layer, tuple) else layer
            if layer_type == 'linear':
                # Resolve symbolic sizes: 'in' -> backbone hidden size,
                # 'out' -> dim_out; any other value is used as a literal size.
                layer_in, layer_out = [
                    (
                        self.base_model.config.hidden_size
                        if x == 'in'
                        else dim_out if x == 'out' else x
                    )
                    for x in layer[1]
                ]
                clf_layers.append(torch.nn.Linear(layer_in, layer_out))
            elif layer_type == 'softmax':
                clf_layers.append(torch.nn.Softmax(dim=-1))
        self.clf = torch.nn.Sequential(*clf_layers)

        # Optionally restore weights, either from an in-memory state dict or
        # from a .pt checkpoint path, mapping tensors onto the target device.
        if state_dict is not None:
            if isinstance(state_dict, str) and state_dict.endswith('.pt'):
                state_dict = torch.load(state_dict, map_location=device)
            self.load_state_dict(state_dict)

    def forward(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Use the hidden state of the first ([CLS]) token as the sequence
        # representation, then apply the classification head.
        y = self.base_model(ids, attention_mask=mask, return_dict=False)[0][:, 0]
        return self.clf(y)

    def predict(self, text: str, device: str) -> torch.Tensor:
        # Tokenize a single string, padding/truncating to DistilBERT's
        # 512-token limit, and return the head's output for it.
        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            return_token_type_ids=False,
            return_attention_mask=True,
            max_length=512,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        with torch.no_grad():
            ids = encoded['input_ids'].to(device)
            mask = encoded['attention_mask'].to(device)
            output = self.forward(ids, mask)
        return output
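

# Minimal usage sketch (an illustration, not part of the original module):
# the classification head here is freshly initialized, so the printed
# probabilities are only meaningful after fine-tuning, and downloading
# "distilbert-base-uncased" requires hub access unless the files are cached
# locally and use_local_file=True.
if __name__ == '__main__':
    model = DistilBertTransferLearningModel(dim_out=2, device='cpu')
    model.eval()
    probs = model.predict("This is an example sentence.", device='cpu')
    print(probs)  # shape (1, 2): per-class probabilities from the Softmax layer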