import os

import torch
import torch.nn as nn
from torch.hub import download_url_to_file

from config import RESOURCES_FOLDER, CHECKPOINT_URLS
from models.seq_models import BidirectionalLSTM, BidirectionalGRU


class PredictionsWrapper(nn.Module):
    """
        A wrapper module that adds an optional sequence model and classification heads on top of a transformer.
        It implements equations (1), (2), and (3) in the paper.

        Args:
            base_model (BaseModelWrapper): The base model (transformer) providing sequence embeddings.
            checkpoint (str, optional): Checkpoint name for loading pre-trained weights. Default is None.
            n_classes_strong (int, optional): Number of classes for strong predictions. Default is 447.
            n_classes_weak (int, optional): Number of classes for weak predictions. Default is None,
                                            which sets it equal to n_classes_strong.
            embed_dim (int, optional): Embedding dimension of the base model output. Default is 768.
            seq_len (int, optional): Desired sequence length. Default is 250 (40 ms resolution).
            seq_model_type (str, optional): Type of sequence model to use. Choices are ["rnn", None].
                                            Default is None, which means no additional sequence model is used.
            head_type (str, optional): Type of classification head. Choices are ["linear", "attention", None].
                                       Default is "linear". None means that sequence embeddings are returned.
            rnn_layers (int, optional): Number of RNN layers if seq_model_type is "rnn". Default is 2.
            rnn_type (str, optional): Type of RNN to use. Choices are ["BiGRU", "BiLSTM"]. Default is "BiGRU".
            rnn_dim (int, optional): Dimension of RNN hidden state if seq_model_type is "rnn". Default is 2048.
            rnn_dropout (float, optional): Dropout rate for RNN layers. Default is 0.0.
        """

    def __init__(self,
                 base_model,
                 checkpoint=None,
                 n_classes_strong=447,
                 n_classes_weak=None,
                 embed_dim=768,
                 seq_len=250,
                 seq_model_type=None,
                 head_type="linear",
                 rnn_layers=2,
                 rnn_type="BiGRU",
                 rnn_dim=2048,
                 rnn_dropout=0.0
                 ):
        super(PredictionsWrapper, self).__init__()
        self.model = base_model
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.n_classes_strong = n_classes_strong
        self.n_classes_weak = n_classes_weak if n_classes_weak is not None else n_classes_strong
        self.seq_model_type = seq_model_type
        self.head_type = head_type

        if self.seq_model_type == "rnn":
            if rnn_type == "BiGRU":
                self.seq_model = BidirectionalGRU(
                    n_in=self.embed_dim,
                    n_hidden=rnn_dim,
                    dropout=rnn_dropout,
                    num_layers=rnn_layers
                )
            elif rnn_type == "BiLSTM":
                self.seq_model = BidirectionalLSTM(
                    nIn=self.embed_dim,
                    nHidden=rnn_dim,
                    nOut=rnn_dim * 2,
                    dropout=rnn_dropout,
                    num_layers=rnn_layers
                )
            num_features = rnn_dim * 2
        elif self.seq_model_type is None:
            self.seq_model = nn.Identity()
            # no additional sequence model
            num_features = self.embed_dim
        else:
            raise ValueError(f"Unknown seq_model_type: {self.seq_model_type}")

        if self.head_type == "attention":
            assert self.n_classes_strong == self.n_classes_weak, "head_type=='attention' requires number of strong and " \
                                                                 "weak classes to be the same!"

        if self.head_type is not None:
            self.strong_head = nn.Linear(num_features, self.n_classes_strong)
            self.weak_head = nn.Linear(num_features, self.n_classes_weak)
        if checkpoint is not None:
            print("Loading pretrained checkpoint: ", checkpoint)
            self.load_checkpoint(checkpoint)

    def load_checkpoint(self, checkpoint):
        ckpt_file = os.path.join(RESOURCES_FOLDER, checkpoint + ".pt")
        if not os.path.exists(ckpt_file):
            download_url_to_file(CHECKPOINT_URLS[checkpoint], ckpt_file)
        state_dict = torch.load(ckpt_file, map_location="cpu", weights_only=True)

        # compatibility with uniform wrapper structure we introduced for the public repo
        if 'fpasst' in checkpoint:
            state_dict = {("model.fpasst." + k[len("model."):] if k.startswith("model.")
                           else k): v for k, v in state_dict.items()}
        elif 'M2D' in checkpoint:
            state_dict = {("model.m2d." + k[len("model."):] if not k.startswith("model.m2d.") and k.startswith("model.")
                           else k): v for k, v in state_dict.items()}
        elif 'BEATs' in checkpoint:
            state_dict = {("model.beats." + k[len("model.model."):] if k.startswith("model.model")
                           else k): v for k, v in state_dict.items()}
        elif 'ASIT' in checkpoint:
            state_dict = {("model.asit." + k[len("model."):] if k.startswith("model.")
                           else k): v for k, v in state_dict.items()}

        n_classes_weak_in_sd = state_dict['weak_head.bias'].shape[0] if 'weak_head.bias' in state_dict else -1
        n_classes_strong_in_sd = state_dict['strong_head.bias'].shape[0] if 'strong_head.bias' in state_dict else -1
        seq_model_in_sd = any(['seq_model.' in key for key in state_dict.keys()])
        keys_to_remove = []
        strict = True
        expected_missing = 0
        if self.head_type is None:
            # remove all keys related to head
            keys_to_remove.append('weak_head.bias')
            keys_to_remove.append('weak_head.weight')
            keys_to_remove.append('strong_head.bias')
            keys_to_remove.append('strong_head.weight')
        elif self.seq_model_type is not None and not seq_model_in_sd:
            # we want to train a sequence model (e.g., rnn) on top of a
            #   pre-trained transformer (e.g., AS weak pretrained)
            keys_to_remove.append('weak_head.bias')
            keys_to_remove.append('weak_head.weight')
            keys_to_remove.append('strong_head.bias')
            keys_to_remove.append('strong_head.weight')
            num_seq_model_keys = len(self.seq_model.state_dict())
            expected_missing = len(keys_to_remove) + num_seq_model_keys
            strict = False
        else:
            # head type is not None
            if n_classes_weak_in_sd != self.n_classes_weak:
                # remove weak head from sd
                keys_to_remove.append('weak_head.bias')
                keys_to_remove.append('weak_head.weight')
                strict = False
            if n_classes_strong_in_sd != self.n_classes_strong:
                # remove strong head from sd
                keys_to_remove.append('strong_head.bias')
                keys_to_remove.append('strong_head.weight')
                strict = False
            expected_missing = len(keys_to_remove)

        # allow missing mel parameters for compatibility
        num_mel_keys = len([key for key in self.state_dict() if 'mel_transform' in key])
        if num_mel_keys > 0:
            expected_missing += num_mel_keys
            strict = False

        state_dict = {k: v for k, v in state_dict.items() if k not in keys_to_remove}
        missing, unexpected = self.load_state_dict(state_dict, strict=strict)
        assert len(missing) == expected_missing, \
            f"Expected {expected_missing} missing keys when loading the checkpoint, but got {len(missing)}: {missing}"
        assert len(unexpected) == 0, f"Unexpected keys in checkpoint: {unexpected}"

    def separate_params(self):
        # delegate to the base model's parameter grouping, if it provides one
        if hasattr(self.model, "separate_params"):
            return self.model.separate_params()
        else:
            raise NotImplementedError("The base model has no 'separate_params' method!")

    def has_separate_params(self):
        return hasattr(self.model, "separate_params")

    def mel_forward(self, x):
        return self.model.mel_forward(x)

    def forward(self, x):
        # base model is expected to output a sequence (see Eq. (1) in paper)
        # (batch size x sequence length x embedding dimension)
        x = self.model(x)

        # ATST: x.shape: batch size x 250 x 768
        # PaSST: x.shape: batch size x 250 x 768
        # ASiT: x.shape: batch size x 497 x 768
        # M2D: x.shape: batch size x 62 x 3840
        # BEATs: x.shape: batch size x 496 x 768

        assert len(x.shape) == 3

        if x.size(-2) > self.seq_len:
            x = torch.nn.functional.adaptive_avg_pool1d(x.transpose(1, 2), self.seq_len).transpose(1, 2)
        elif x.size(-2) < self.seq_len:
            x = torch.nn.functional.interpolate(x.transpose(1, 2), size=self.seq_len,
                                                mode='linear').transpose(1, 2)

        # Eq. (3) in the paper
        # for teachers this is an RNN, for students it is nn.Identity
        x = self.seq_model(x)

        if self.head_type == "attention":
            # attention head to obtain weak from strong predictions
            # this is typically used for the DESED task, which requires both
            # weak and strong predictions
            strong = torch.sigmoid(self.strong_head(x))
            sof = torch.softmax(self.weak_head(x), dim=-1)
            sof = torch.clamp(sof, min=1e-7, max=1)
            weak = (strong * sof).sum(1) / sof.sum(1)
            return strong.transpose(1, 2), weak
        elif self.head_type == "linear":
            # simple linear layers as head (see Eq. (3) in the paper)
            # on AudioSet strong, only strong predictions are used
            # on AudioSet weak, only weak predictions are used
            # why both? because we tried to simultaneously train on AudioSet weak and strong (less successful)
            strong = self.strong_head(x)
            weak = self.weak_head(x.mean(dim=1))
            return strong.transpose(1, 2), weak
        else:
            # no head means the sequence is returned instead of strong and weak predictions
            return x
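

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original training pipeline: it wires a
    # stand-in base model into PredictionsWrapper and runs one forward pass to
    # illustrate the expected tensor shapes. _DummyBaseModel is a hypothetical
    # placeholder; in the repository, the base model would be one of the transformer
    # wrappers (e.g., ATST, PaSST, BEATs) that returns a sequence of embeddings with
    # shape (batch size, sequence length, embed_dim).
    class _DummyBaseModel(nn.Module):
        def __init__(self, seq_len=250, embed_dim=768):
            super().__init__()
            self.seq_len = seq_len
            self.embed_dim = embed_dim
            self.proj = nn.Linear(embed_dim, embed_dim)

        def forward(self, x):
            # pretend the transformer already produced (batch, seq_len, embed_dim)
            return self.proj(torch.randn(x.shape[0], self.seq_len, self.embed_dim))

    model = PredictionsWrapper(
        base_model=_DummyBaseModel(),
        checkpoint=None,        # no pre-trained weights in this sketch
        n_classes_strong=447,
        seq_model_type=None,    # no additional RNN on top of the transformer
        head_type="linear"
    )
    dummy_input = torch.randn(2, 16000)  # placeholder input; real base models consume audio/mel features
    strong, weak = model(dummy_input)
    print(strong.shape)  # (2, 447, 250): frame-level (strong) predictions, classes x time
    print(weak.shape)    # (2, 447): clip-level (weak) predictions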