import os
import torch
import torch.nn as nn
from torch.hub import download_url_to_file
from config import RESOURCES_FOLDER, CHECKPOINT_URLS
from models.seq_models import BidirectionalLSTM, BidirectionalGRU
class PredictionsWrapper(nn.Module):
"""
A wrapper module that adds an optional sequence model and classification heads on top of a transformer.
It implements equations (1), (2), and (3) in the paper.
Args:
base_model (BaseModelWrapper): The base model (transformer) providing sequence embeddings
checkpoint (str, optional): checkpoint name for loading pre-trained weights. Default is None.
n_classes_strong (int): Number of classes for strong predictions. Default is 447.
n_classes_weak (int, optional): Number of classes for weak predictions. Default is None,
which sets it equal to n_classes_strong.
embed_dim (int, optional): Embedding dimension of the base model output. Default is 768.
seq_len (int, optional): Desired sequence length. Default is 250 (40 ms resolution).
        seq_model_type (str, optional): Type of sequence model to use. Choices are ["rnn", None].
            Default is None, which means no additional sequence model is used.
        head_type (str, optional): Type of classification head. Choices are ["linear", "attention", None].
            Default is "linear". If None, the sequence embeddings are returned instead of predictions.
rnn_layers (int, optional): Number of RNN layers if seq_model_type is "rnn". Default is 2.
rnn_type (str, optional): Type of RNN to use. Choices are ["BiGRU", "BiLSTM"]. Default is "BiGRU".
        rnn_dim (int, optional): Dimension of the RNN hidden state if seq_model_type is "rnn". Default is 2048.
rnn_dropout (float, optional): Dropout rate for RNN layers. Default is 0.0.
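
    Example:
        Illustrative sketch; assumes `base_model` follows the BaseModelWrapper
        interface and maps an input batch to (batch, time, embed_dim) sequence
        embeddings::

            wrapper = PredictionsWrapper(base_model, n_classes_strong=447,
                                         seq_model_type="rnn", head_type="linear")
            strong, weak = wrapper(x)
            # strong: (batch, n_classes_strong, seq_len), weak: (batch, n_classes_weak)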
"""
def __init__(self,
base_model,
checkpoint=None,
n_classes_strong=447,
n_classes_weak=None,
embed_dim=768,
seq_len=250,
seq_model_type=None,
head_type="linear",
rnn_layers=2,
rnn_type="BiGRU",
rnn_dim=2048,
rnn_dropout=0.0
):
super(PredictionsWrapper, self).__init__()
self.model = base_model
self.seq_len = seq_len
self.embed_dim = embed_dim
self.n_classes_strong = n_classes_strong
self.n_classes_weak = n_classes_weak if n_classes_weak is not None else n_classes_strong
self.seq_model_type = seq_model_type
self.head_type = head_type
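        # optional sequence model stacked on top of the transformer's sequence embeddings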
if self.seq_model_type == "rnn":
if rnn_type == "BiGRU":
self.seq_model = BidirectionalGRU(
n_in=self.embed_dim,
n_hidden=rnn_dim,
dropout=rnn_dropout,
num_layers=rnn_layers
)
elif rnn_type == "BiLSTM":
self.seq_model = BidirectionalLSTM(
nIn=self.embed_dim,
nHidden=rnn_dim,
nOut=rnn_dim * 2,
dropout=rnn_dropout,
num_layers=rnn_layers
)
num_features = rnn_dim * 2
elif self.seq_model_type is None:
self.seq_model = nn.Identity()
# no additional sequence model
num_features = self.embed_dim
else:
raise ValueError(f"Unknown seq_model_type: {self.seq_model_type}")
if self.head_type == "attention":
assert self.n_classes_strong == self.n_classes_weak, "head_type=='attention' requires number of strong and " \
"weak classes to be the same!"
if self.head_type is not None:
self.strong_head = nn.Linear(num_features, self.n_classes_strong)
self.weak_head = nn.Linear(num_features, self.n_classes_weak)
if checkpoint is not None:
print("Loading pretrained checkpoint: ", checkpoint)
self.load_checkpoint(checkpoint)
def load_checkpoint(self, checkpoint):
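        """
        Download (if missing) and load a pre-trained checkpoint into this wrapper.

        Checkpoint keys are remapped to the uniform wrapper naming used in this repo,
        and head / sequence-model parameters that do not match the current configuration
        are dropped; `expected_missing` keeps track of how many keys may legitimately be
        absent when loading non-strictly.
        """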
ckpt_file = os.path.join(RESOURCES_FOLDER, checkpoint + ".pt")
if not os.path.exists(ckpt_file):
download_url_to_file(CHECKPOINT_URLS[checkpoint], ckpt_file)
state_dict = torch.load(ckpt_file, map_location="cpu", weights_only=True)
# compatibility with uniform wrapper structure we introduced for the public repo
if 'fpasst' in checkpoint:
state_dict = {("model.fpasst." + k[len("model."):] if k.startswith("model.")
else k): v for k, v in state_dict.items()}
elif 'M2D' in checkpoint:
state_dict = {("model.m2d." + k[len("model."):] if not k.startswith("model.m2d.") and k.startswith("model.")
else k): v for k, v in state_dict.items()}
elif 'BEATs' in checkpoint:
state_dict = {("model.beats." + k[len("model.model."):] if k.startswith("model.model")
else k): v for k, v in state_dict.items()}
elif 'ASIT' in checkpoint:
state_dict = {("model.asit." + k[len("model."):] if k.startswith("model.")
else k): v for k, v in state_dict.items()}
n_classes_weak_in_sd = state_dict['weak_head.bias'].shape[0] if 'weak_head.bias' in state_dict else -1
n_classes_strong_in_sd = state_dict['strong_head.bias'].shape[0] if 'strong_head.bias' in state_dict else -1
        seq_model_in_sd = any('seq_model.' in key for key in state_dict)
keys_to_remove = []
strict = True
expected_missing = 0
if self.head_type is None:
# remove all keys related to head
keys_to_remove.append('weak_head.bias')
keys_to_remove.append('weak_head.weight')
keys_to_remove.append('strong_head.bias')
keys_to_remove.append('strong_head.weight')
elif self.seq_model_type is not None and not seq_model_in_sd:
# we want to train a sequence model (e.g., rnn) on top of a
# pre-trained transformer (e.g., AS weak pretrained)
keys_to_remove.append('weak_head.bias')
keys_to_remove.append('weak_head.weight')
keys_to_remove.append('strong_head.bias')
keys_to_remove.append('strong_head.weight')
            num_seq_model_keys = len(self.seq_model.state_dict())
expected_missing = len(keys_to_remove) + num_seq_model_keys
strict = False
else:
# head type is not None
if n_classes_weak_in_sd != self.n_classes_weak:
# remove weak head from sd
keys_to_remove.append('weak_head.bias')
keys_to_remove.append('weak_head.weight')
strict = False
if n_classes_strong_in_sd != self.n_classes_strong:
# remove strong head from sd
keys_to_remove.append('strong_head.bias')
keys_to_remove.append('strong_head.weight')
strict = False
expected_missing = len(keys_to_remove)
# allow missing mel parameters for compatibility
num_mel_keys = len([key for key in self.state_dict() if 'mel_transform' in key])
if num_mel_keys > 0:
expected_missing += num_mel_keys
strict = False
state_dict = {k: v for k, v in state_dict.items() if k not in keys_to_remove}
missing, unexpected = self.load_state_dict(state_dict, strict=strict)
assert len(missing) == expected_missing
assert len(unexpected) == 0
    def separate_params(self):
        if hasattr(self.model, "separate_params"):
            return self.model.separate_params()
        else:
            raise NotImplementedError("The base model has no 'separate_params' method!")
def has_separate_params(self):
return hasattr(self.model, "separate_params")
def mel_forward(self, x):
return self.model.mel_forward(x)
def forward(self, x):
# base model is expected to output a sequence (see Eq. (1) in paper)
# (batch size x sequence length x embedding dimension)
x = self.model(x)
# ATST: x.shape: batch size x 250 x 768
# PaSST: x.shape: batch size x 250 x 768
# ASiT: x.shape: batch size x 497 x 768
# M2D: x.shape: batch size x 62 x 3840
# BEATs: x.shape: batch size x 496 x 768
assert len(x.shape) == 3
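        # unify the time resolution: average-pool longer sequences down to seq_len,
        # linearly interpolate shorter ones up to seq_len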
if x.size(-2) > self.seq_len:
x = torch.nn.functional.adaptive_avg_pool1d(x.transpose(1, 2), self.seq_len).transpose(1, 2)
elif x.size(-2) < self.seq_len:
x = torch.nn.functional.interpolate(x.transpose(1, 2), size=self.seq_len,
mode='linear').transpose(1, 2)
# Eq. (3) in the paper
# for teachers this is an RNN, for students it is nn.Identity
x = self.seq_model(x)
if self.head_type == "attention":
# attention head to obtain weak from strong predictions
# this is typically used for the DESED task, which requires both
# weak and strong predictions
strong = torch.sigmoid(self.strong_head(x))
sof = torch.softmax(self.weak_head(x), dim=-1)
sof = torch.clamp(sof, min=1e-7, max=1)
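            # the per-class softmax scores act as attention weights over time frames;
            # the weak prediction is the attention-weighted average of the strong predictions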
weak = (strong * sof).sum(1) / sof.sum(1)
return strong.transpose(1, 2), weak
elif self.head_type == "linear":
# simple linear layers as head (see Eq. (3) in the paper)
# on AudioSet strong, only strong predictions are used
# on AudioSet weak, only weak predictions are used
# why both? because we tried to simultaneously train on AudioSet weak and strong (less successful)
strong = self.strong_head(x)
weak = self.weak_head(x.mean(dim=1))
return strong.transpose(1, 2), weak
else:
# no head means the sequence is returned instead of strong and weak predictions
return x
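

if __name__ == "__main__":
    # Minimal sanity-check sketch: `_DummyBase` is a hypothetical stand-in for a
    # transformer wrapper that returns (batch, time, embed_dim) sequence embeddings.
    class _DummyBase(nn.Module):
        def forward(self, x):
            # pretend the transformer produces 497 frames of 768-dim embeddings (ASiT-like)
            return torch.randn(x.shape[0], 497, 768)

    # attention head: weak predictions are derived from strong ones,
    # so the number of strong and weak classes must match
    wrapper = PredictionsWrapper(_DummyBase(), n_classes_strong=10, n_classes_weak=10,
                                 head_type="attention")
    strong, weak = wrapper(torch.randn(2, 16000))
    print(strong.shape, weak.shape)  # expected: torch.Size([2, 10, 250]) torch.Size([2, 10])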