import os

import torch
import torch.nn as nn
from torch.hub import download_url_to_file

from config import RESOURCES_FOLDER, CHECKPOINT_URLS
from models.seq_models import BidirectionalLSTM, BidirectionalGRU


class PredictionsWrapper(nn.Module):
    """
    A wrapper module that adds an optional sequence model and classification heads on top of a transformer.
    It implements equations (1), (2), and (3) in the paper.

    Args:
        base_model (BaseModelWrapper): The base model (transformer) providing sequence embeddings
        checkpoint (str, optional): checkpoint name for loading pre-trained weights. Default is None.
        n_classes_strong (int): Number of classes for strong predictions. Default is 447.
        n_classes_weak (int, optional): Number of classes for weak predictions. Default is None,
            which sets it equal to n_classes_strong.
        embed_dim (int, optional): Embedding dimension of the base model output. Default is 768.
        seq_len (int, optional): Desired sequence length. Default is 250 (40 ms resolution).
        seq_model_type (str, optional): Type of sequence model to use.
            Default is None, which means no additional sequence model is used.
        head_type (str, optional): Type of classification head. Choices are ["linear", "attention", None].
            Default is "linear". None means that sequence embeddings are returned instead of predictions.
        rnn_layers (int, optional): Number of RNN layers if seq_model_type is "rnn". Default is 2.
        rnn_type (str, optional): Type of RNN to use. Choices are ["BiGRU", "BiLSTM"]. Default is "BiGRU".
        rnn_dim (int, optional): Dimension of the RNN hidden state if seq_model_type is "rnn". Default is 2048.
        rnn_dropout (float, optional): Dropout rate for RNN layers. Default is 0.0.
    """

    def __init__(self,
                 base_model,
                 checkpoint=None,
                 n_classes_strong=447,
                 n_classes_weak=None,
                 embed_dim=768,
                 seq_len=250,
                 seq_model_type=None,
                 head_type="linear",
                 rnn_layers=2,
                 rnn_type="BiGRU",
                 rnn_dim=2048,
                 rnn_dropout=0.0
                 ):
        super(PredictionsWrapper, self).__init__()
        self.model = base_model
        self.seq_len = seq_len
        self.embed_dim = embed_dim
        self.n_classes_strong = n_classes_strong
        self.n_classes_weak = n_classes_weak if n_classes_weak is not None else n_classes_strong
        self.seq_model_type = seq_model_type
        self.head_type = head_type
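
        # build the optional sequence model (Eq. (2) in the paper) that refines
        # the transformer's embedding sequence before the classification heads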
        if self.seq_model_type == "rnn":
            if rnn_type == "BiGRU":
                self.seq_model = BidirectionalGRU(
                    n_in=self.embed_dim,
                    n_hidden=rnn_dim,
                    dropout=rnn_dropout,
                    num_layers=rnn_layers
                )
            elif rnn_type == "BiLSTM":
                self.seq_model = BidirectionalLSTM(
                    nIn=self.embed_dim,
                    nHidden=rnn_dim,
                    nOut=rnn_dim * 2,
                    dropout=rnn_dropout,
                    num_layers=rnn_layers
                )
            else:
                raise ValueError(f"Unknown rnn_type: {rnn_type}")
            num_features = rnn_dim * 2
        elif self.seq_model_type is None:
            # no additional sequence model
            self.seq_model = nn.Identity()
            num_features = self.embed_dim
        else:
            raise ValueError(f"Unknown seq_model_type: {self.seq_model_type}")

        if self.head_type == "attention":
            assert self.n_classes_strong == self.n_classes_weak, \
                "head_type=='attention' requires number of strong and weak classes to be the same!"

        if self.head_type is not None:
            self.strong_head = nn.Linear(num_features, self.n_classes_strong)
            self.weak_head = nn.Linear(num_features, self.n_classes_weak)

        if checkpoint is not None:
            print("Loading pretrained checkpoint: ", checkpoint)
            self.load_checkpoint(checkpoint)

    def load_checkpoint(self, checkpoint):
        ckpt_file = os.path.join(RESOURCES_FOLDER, checkpoint + ".pt")
        if not os.path.exists(ckpt_file):
            download_url_to_file(CHECKPOINT_URLS[checkpoint], ckpt_file)
        state_dict = torch.load(ckpt_file, map_location="cpu", weights_only=True)

        # compatibility with uniform wrapper structure we introduced for the public repo
        if 'fpasst' in checkpoint:
            state_dict = {("model.fpasst." + k[len("model."):] if k.startswith("model.")
                           else k): v for k, v in state_dict.items()}
        elif 'M2D' in checkpoint:
            state_dict = {("model.m2d." + k[len("model."):] if not k.startswith("model.m2d.") and k.startswith("model.")
                           else k): v for k, v in state_dict.items()}
        elif 'BEATs' in checkpoint:
            state_dict = {("model.beats." + k[len("model.model."):] if k.startswith("model.model")
                           else k): v for k, v in state_dict.items()}
        elif 'ASIT' in checkpoint:
            state_dict = {("model.asit." + k[len("model."):] if k.startswith("model.")
                           else k): v for k, v in state_dict.items()}
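
        # inspect what the checkpoint actually contains (head sizes, presence of a
        # sequence model) so that incompatible weights can be dropped before loading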
        n_classes_weak_in_sd = state_dict['weak_head.bias'].shape[0] if 'weak_head.bias' in state_dict else -1
        n_classes_strong_in_sd = state_dict['strong_head.bias'].shape[0] if 'strong_head.bias' in state_dict else -1
        seq_model_in_sd = any('seq_model.' in key for key in state_dict)

        keys_to_remove = []
        strict = True
        expected_missing = 0
        if self.head_type is None:
            # remove all keys related to head
            keys_to_remove.append('weak_head.bias')
            keys_to_remove.append('weak_head.weight')
            keys_to_remove.append('strong_head.bias')
            keys_to_remove.append('strong_head.weight')
        elif self.seq_model_type is not None and not seq_model_in_sd:
            # we want to train a sequence model (e.g., rnn) on top of a
            # pre-trained transformer (e.g., AS weak pretrained)
            keys_to_remove.append('weak_head.bias')
            keys_to_remove.append('weak_head.weight')
            keys_to_remove.append('strong_head.bias')
            keys_to_remove.append('strong_head.weight')
            num_seq_model_keys = len(self.seq_model.state_dict())
            expected_missing = len(keys_to_remove) + num_seq_model_keys
            strict = False
        else:
            # head type is not None
            if n_classes_weak_in_sd != self.n_classes_weak:
                # remove weak head from sd
                keys_to_remove.append('weak_head.bias')
                keys_to_remove.append('weak_head.weight')
                strict = False
            if n_classes_strong_in_sd != self.n_classes_strong:
                # remove strong head from sd
                keys_to_remove.append('strong_head.bias')
                keys_to_remove.append('strong_head.weight')
                strict = False
            expected_missing = len(keys_to_remove)

        # allow missing mel parameters for compatibility
        num_mel_keys = len([key for key in self.state_dict() if 'mel_transform' in key])
        if num_mel_keys > 0:
            expected_missing += num_mel_keys
            strict = False

        state_dict = {k: v for k, v in state_dict.items() if k not in keys_to_remove}
        missing, unexpected = self.load_state_dict(state_dict, strict=strict)
        assert len(missing) == expected_missing
        assert len(unexpected) == 0

    def separate_params(self):
        if hasattr(self.model, "separate_params"):
            return self.model.separate_params()
        else:
            raise NotImplementedError("The base model has no 'separate_params' method!")

    def has_separate_params(self):
        return hasattr(self.model, "separate_params")

    def mel_forward(self, x):
        return self.model.mel_forward(x)

    def forward(self, x):
        # base model is expected to output a sequence (see Eq. (1) in the paper)
        # of shape: batch size x sequence length x embedding dimension
        x = self.model(x)
        # ATST:  x.shape: batch size x 250 x 768
        # PaSST: x.shape: batch size x 250 x 768
        # ASiT:  x.shape: batch size x 497 x 768
        # M2D:   x.shape: batch size x 62 x 3840
        # BEATs: x.shape: batch size x 496 x 768
        assert len(x.shape) == 3
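
        # bring the embedding sequence to the common temporal resolution of
        # self.seq_len frames (default: 250 frames, i.e., 40 ms resolution)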
        if x.size(-2) > self.seq_len:
            x = torch.nn.functional.adaptive_avg_pool1d(x.transpose(1, 2), self.seq_len).transpose(1, 2)
        elif x.size(-2) < self.seq_len:
            x = torch.nn.functional.interpolate(x.transpose(1, 2), size=self.seq_len,
                                                mode='linear').transpose(1, 2)

        # Eq. (3) in the paper
        # for teachers this is an RNN, for students it is nn.Identity
        x = self.seq_model(x)

        if self.head_type == "attention":
            # attention head to obtain weak from strong predictions
            # this is typically used for the DESED task, which requires both
            # weak and strong predictions
            strong = torch.sigmoid(self.strong_head(x))
            sof = torch.softmax(self.weak_head(x), dim=-1)
            sof = torch.clamp(sof, min=1e-7, max=1)
            weak = (strong * sof).sum(1) / sof.sum(1)
            return strong.transpose(1, 2), weak
        elif self.head_type == "linear":
            # simple linear layers as head (see Eq. (3) in the paper)
            # on AudioSet strong, only strong predictions are used
            # on AudioSet weak, only weak predictions are used
            # why both? because we tried to simultaneously train on AudioSet weak and strong (less successful)
            strong = self.strong_head(x)
            weak = self.weak_head(x.mean(dim=1))
            return strong.transpose(1, 2), weak
        else:
            # no head means the sequence is returned instead of strong and weak predictions
            return x
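

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original training pipeline).
    # "DummyBaseModel" is a hypothetical stand-in for one of the real transformer
    # wrappers (e.g., ATST, PaSST, BEATs); it only mimics the expected output
    # shape of (batch size x sequence length x embedding dimension).
    class DummyBaseModel(nn.Module):
        def __init__(self, embed_dim=768, out_seq_len=250):
            super().__init__()
            self.out_seq_len = out_seq_len
            self.proj = nn.Linear(1, embed_dim)

        def forward(self, x):
            # x: raw waveform of shape (batch size, samples)
            frames = torch.nn.functional.adaptive_avg_pool1d(
                x.unsqueeze(1), self.out_seq_len
            ).transpose(1, 2)  # (batch size, out_seq_len, 1)
            return self.proj(frames)  # (batch size, out_seq_len, embed_dim)

    wrapper = PredictionsWrapper(DummyBaseModel(), n_classes_strong=447, head_type="linear")
    waveform = torch.randn(2, 16000)
    strong, weak = wrapper(waveform)
    print(strong.shape)  # torch.Size([2, 447, 250])
    print(weak.shape)    # torch.Size([2, 447])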