# models/prediction_wrapper.py
import os
import torch
import torch.nn as nn
from torch.hub import download_url_to_file
from config import RESOURCES_FOLDER, CHECKPOINT_URLS
from models.seq_models import BidirectionalLSTM, BidirectionalGRU


class PredictionsWrapper(nn.Module):
"""
A wrapper module that adds an optional sequence model and classification heads on top of a transformer.
It implements equations (1), (2), and (3) in the paper.
Args:
base_model (BaseModelWrapper): The base model (transformer) providing sequence embeddings
checkpoint (str, optional): checkpoint name for loading pre-trained weights. Default is None.
n_classes_strong (int): Number of classes for strong predictions. Default is 447.
n_classes_weak (int, optional): Number of classes for weak predictions. Default is None,
which sets it equal to n_classes_strong.
embed_dim (int, optional): Embedding dimension of the base model output. Default is 768.
seq_len (int, optional): Desired sequence length. Default is 250 (40 ms resolution).
        seq_model_type (str, optional): Type of sequence model to use. Choices are ["rnn", None].
            Default is None, which means no additional sequence model is used.
        head_type (str, optional): Type of classification head. Choices are ["linear", "attention", None].
            Default is "linear". None means that the sequence embeddings are returned instead of predictions.
rnn_layers (int, optional): Number of RNN layers if seq_model_type is "rnn". Default is 2.
rnn_type (str, optional): Type of RNN to use. Choices are ["BiGRU", "BiLSTM"]. Default is "BiGRU".
        rnn_dim (int, optional): Dimension of RNN hidden state if seq_model_type is "rnn". Default is 2048.
rnn_dropout (float, optional): Dropout rate for RNN layers. Default is 0.0.
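
    Example (illustrative sketch; `base` stands in for any base model wrapper
    from this repo that outputs a (batch, time, embed_dim) sequence):
        >>> model = PredictionsWrapper(base, n_classes_strong=447, head_type="linear")
        >>> strong, weak = model(waveform)  # strong: (B, 447, 250), weak: (B, 447)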
"""
def __init__(self,
base_model,
checkpoint=None,
n_classes_strong=447,
n_classes_weak=None,
embed_dim=768,
seq_len=250,
seq_model_type=None,
head_type="linear",
rnn_layers=2,
rnn_type="BiGRU",
rnn_dim=2048,
rnn_dropout=0.0
):
        super().__init__()
self.model = base_model
self.seq_len = seq_len
self.embed_dim = embed_dim
self.n_classes_strong = n_classes_strong
self.n_classes_weak = n_classes_weak if n_classes_weak is not None else n_classes_strong
self.seq_model_type = seq_model_type
self.head_type = head_type
if self.seq_model_type == "rnn":
if rnn_type == "BiGRU":
self.seq_model = BidirectionalGRU(
n_in=self.embed_dim,
n_hidden=rnn_dim,
dropout=rnn_dropout,
num_layers=rnn_layers
)
            elif rnn_type == "BiLSTM":
                self.seq_model = BidirectionalLSTM(
                    nIn=self.embed_dim,
                    nHidden=rnn_dim,
                    nOut=rnn_dim * 2,
                    dropout=rnn_dropout,
                    num_layers=rnn_layers
                )
            else:
                raise ValueError(f"Unknown rnn_type: {rnn_type}")
            # bidirectional RNNs output features of twice the hidden dimension
            num_features = rnn_dim * 2
        elif self.seq_model_type is None:
            # no additional sequence model
            self.seq_model = nn.Identity()
            num_features = self.embed_dim
else:
raise ValueError(f"Unknown seq_model_type: {self.seq_model_type}")
if self.head_type == "attention":
assert self.n_classes_strong == self.n_classes_weak, "head_type=='attention' requires number of strong and " \
"weak classes to be the same!"
if self.head_type is not None:
self.strong_head = nn.Linear(num_features, self.n_classes_strong)
self.weak_head = nn.Linear(num_features, self.n_classes_weak)
if checkpoint is not None:
print("Loading pretrained checkpoint: ", checkpoint)
self.load_checkpoint(checkpoint)

    def load_checkpoint(self, checkpoint):
        """Download (if necessary) and load a pre-trained checkpoint by name.

        Checkpoint keys are remapped to match the wrapper structure, and head /
        sequence-model weights are dropped whenever they do not match the
        current configuration.
        """
ckpt_file = os.path.join(RESOURCES_FOLDER, checkpoint + ".pt")
if not os.path.exists(ckpt_file):
download_url_to_file(CHECKPOINT_URLS[checkpoint], ckpt_file)
state_dict = torch.load(ckpt_file, map_location="cpu", weights_only=True)
# compatibility with uniform wrapper structure we introduced for the public repo
if 'fpasst' in checkpoint:
state_dict = {("model.fpasst." + k[len("model."):] if k.startswith("model.")
else k): v for k, v in state_dict.items()}
elif 'M2D' in checkpoint:
state_dict = {("model.m2d." + k[len("model."):] if not k.startswith("model.m2d.") and k.startswith("model.")
else k): v for k, v in state_dict.items()}
        elif 'BEATs' in checkpoint:
            state_dict = {("model.beats." + k[len("model.model."):] if k.startswith("model.model.")
                           else k): v for k, v in state_dict.items()}
elif 'ASIT' in checkpoint:
state_dict = {("model.asit." + k[len("model."):] if k.startswith("model.")
else k): v for k, v in state_dict.items()}
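        # e.g., for an fPaSST checkpoint: "model.<param>" -> "model.fpasst.<param>"

        # inspect the checkpoint to decide which keys can be loaded as-is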
n_classes_weak_in_sd = state_dict['weak_head.bias'].shape[0] if 'weak_head.bias' in state_dict else -1
n_classes_strong_in_sd = state_dict['strong_head.bias'].shape[0] if 'strong_head.bias' in state_dict else -1
        seq_model_in_sd = any('seq_model.' in key for key in state_dict)
keys_to_remove = []
strict = True
expected_missing = 0
if self.head_type is None:
# remove all keys related to head
keys_to_remove.append('weak_head.bias')
keys_to_remove.append('weak_head.weight')
keys_to_remove.append('strong_head.bias')
keys_to_remove.append('strong_head.weight')
elif self.seq_model_type is not None and not seq_model_in_sd:
# we want to train a sequence model (e.g., rnn) on top of a
# pre-trained transformer (e.g., AS weak pretrained)
keys_to_remove.append('weak_head.bias')
keys_to_remove.append('weak_head.weight')
keys_to_remove.append('strong_head.bias')
keys_to_remove.append('strong_head.weight')
            num_seq_model_keys = len(self.seq_model.state_dict())
expected_missing = len(keys_to_remove) + num_seq_model_keys
strict = False
else:
# head type is not None
if n_classes_weak_in_sd != self.n_classes_weak:
# remove weak head from sd
keys_to_remove.append('weak_head.bias')
keys_to_remove.append('weak_head.weight')
strict = False
if n_classes_strong_in_sd != self.n_classes_strong:
# remove strong head from sd
keys_to_remove.append('strong_head.bias')
keys_to_remove.append('strong_head.weight')
strict = False
expected_missing = len(keys_to_remove)
# allow missing mel parameters for compatibility
        num_mel_keys = sum(1 for key in self.state_dict() if 'mel_transform' in key)
if num_mel_keys > 0:
expected_missing += num_mel_keys
strict = False
state_dict = {k: v for k, v in state_dict.items() if k not in keys_to_remove}
missing, unexpected = self.load_state_dict(state_dict, strict=strict)
        assert len(missing) == expected_missing, f"Unexpected missing keys: {missing}"
        assert len(unexpected) == 0, f"Unexpected keys in state dict: {unexpected}"

    def separate_params(self):
        if hasattr(self.model, "separate_params"):
            return self.model.separate_params()
        else:
            raise NotImplementedError("The base model has no 'separate_params' method!")

    def has_separate_params(self):
        return hasattr(self.model, "separate_params")

    def mel_forward(self, x):
        return self.model.mel_forward(x)

    def forward(self, x):
# base model is expected to output a sequence (see Eq. (1) in paper)
# (batch size x sequence length x embedding dimension)
x = self.model(x)
# ATST: x.shape: batch size x 250 x 768
# PaSST: x.shape: batch size x 250 x 768
# ASiT: x.shape: batch size x 497 x 768
# M2D: x.shape: batch size x 62 x 3840
# BEATs: x.shape: batch size x 496 x 768
assert len(x.shape) == 3
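        # resample the time axis to self.seq_len: average-pool if the base model's
        # sequence is longer, linearly interpolate if it is shorter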
if x.size(-2) > self.seq_len:
x = torch.nn.functional.adaptive_avg_pool1d(x.transpose(1, 2), self.seq_len).transpose(1, 2)
elif x.size(-2) < self.seq_len:
x = torch.nn.functional.interpolate(x.transpose(1, 2), size=self.seq_len,
mode='linear').transpose(1, 2)
# Eq. (3) in the paper
# for teachers this is an RNN, for students it is nn.Identity
x = self.seq_model(x)
if self.head_type == "attention":
# attention head to obtain weak from strong predictions
# this is typically used for the DESED task, which requires both
# weak and strong predictions
strong = torch.sigmoid(self.strong_head(x))
sof = torch.softmax(self.weak_head(x), dim=-1)
sof = torch.clamp(sof, min=1e-7, max=1)
weak = (strong * sof).sum(1) / sof.sum(1)
return strong.transpose(1, 2), weak
elif self.head_type == "linear":
# simple linear layers as head (see Eq. (3) in the paper)
# on AudioSet strong, only strong predictions are used
# on AudioSet weak, only weak predictions are used
            # both exist because we also experimented with training on AudioSet weak and strong simultaneously (less successful)
strong = self.strong_head(x)
weak = self.weak_head(x.mean(dim=1))
return strong.transpose(1, 2), weak
else:
# no head means the sequence is returned instead of strong and weak predictions
return x
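

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): `_StubBase` is a hypothetical
    # stand-in for a real base model wrapper; it returns a random
    # (batch, time, embed_dim) sequence with ASiT-like dimensions.
    class _StubBase(nn.Module):
        def forward(self, x):
            return torch.randn(x.shape[0], 497, 768)

    model = PredictionsWrapper(_StubBase(), n_classes_strong=447, head_type="linear")
    strong, weak = model(torch.randn(2, 16000))
    print(strong.shape)  # torch.Size([2, 447, 250]) - frame-level logits
    print(weak.shape)    # torch.Size([2, 447]) - clip-level logits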