# NSynth-5K-pretrained-sed / ex_audioset_strong.py
# sohamc10 -- commit "gradio app" (9b0d6c2)
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import argparse
import torch.nn as nn
import wandb
import transformers
import random
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import sed_scores_eval
from helpers.decode import batched_decode_preds
from helpers.encode import ManyHotEncoder
from models.atstframe.ATSTF_wrapper import ATSTWrapper
from models.beats.BEATs_wrapper import BEATsWrapper
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
from models.m2d.M2D_wrapper import M2DWrapper
from models.asit.ASIT_wrapper import ASiTWrapper
from models.prediction_wrapper import PredictionsWrapper
from helpers.augment import frame_shift, time_mask, mixup, filter_augmentation, mixstyle, RandomResizeCrop
from helpers.utils import worker_init_fn
from data_util.audioset_strong import get_training_dataset, get_eval_dataset
from data_util.audioset_strong import get_temporal_count_balanced_sample_weights, get_uniform_sample_weights, \
get_weighted_sampler
from data_util.audioset_classes import as_strong_train_classes, as_strong_eval_classes
from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
from models.frame_mn.utils import NAME_TO_WIDTH
class PLModule(pl.LightningModule):
    """Lightning module for frame-level sound event detection on AudioSet Strong.

    Wraps the selected backbone in ``PredictionsWrapper``, applies the configured spectrogram
    augmentations during training, and after each validation epoch computes PSDS1 and
    segment-based pAUROC over the classes shared by the ground truth and the encoder's labels.

    :param config: argparse-style namespace with model, augmentation, optimizer and schedule
        settings (the attributes read here are those defined by this file's CLI)
    :param encoder: ``ManyHotEncoder``; its ``labels`` and ``net_pooling`` attributes are used
        and it is passed to ``batched_decode_preds``
    """

    def __init__(self, config, encoder):
        super().__init__()
        self.config = config
        self.encoder = encoder

        # translate the CLI choice into the checkpoint suffix used by PredictionsWrapper
        # (None = no pretrained weights)
        pretrained_to_checkpoint = {
            "scratch": None,
            "ssl": "ssl",
            "weak": "weak",
            "strong": "strong_1",
        }
        if config.pretrained not in pretrained_to_checkpoint:
            raise ValueError(f"Unknown pretrained checkpoint: {config.pretrained}")
        checkpoint = pretrained_to_checkpoint[config.pretrained]

        # load transformer model
        if config.model_name == "BEATs":
            beats = BEATsWrapper()
            model = PredictionsWrapper(beats, checkpoint=f"BEATs_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type)
        elif config.model_name == "ATST-F":
            atst = ATSTWrapper()
            model = PredictionsWrapper(atst, checkpoint=f"ATST-F_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type)
        elif config.model_name == "fpasst":
            fpasst = FPaSSTWrapper()
            model = PredictionsWrapper(fpasst, checkpoint=f"fpasst_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type)
        elif config.model_name == "M2D":
            m2d = M2DWrapper()
            model = PredictionsWrapper(m2d, checkpoint=f"M2D_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type,
                                       embed_dim=m2d.m2d.cfg.feature_d)
        elif config.model_name == "ASIT":
            asit = ASiTWrapper()
            model = PredictionsWrapper(asit, checkpoint=f"ASIT_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type)
        elif config.model_name.startswith("frame_mn"):
            width = NAME_TO_WIDTH(config.model_name)
            frame_mn = FrameMNWrapper(width)
            # embedding size taken from the shape of the 'frame_mn.features.16.1' bias
            embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
            # NOTE(review): this branch always loads the "<model_name>_strong_1" checkpoint and does not
            # forward config.pretrained or config.seq_model_type -- confirm that is intended
            model = PredictionsWrapper(frame_mn, checkpoint=f"{config.model_name}_strong_1", embed_dim=embed_dim)
        else:
            raise NotImplementedError(f"Model {config.model_name} not (yet) implemented")
        self.model = model

        # prepare ingredients for knowledge distillation
        assert 0 <= config.distillation_loss_weight <= 1, "Lambda for Knowledge Distillation must be between 0 and 1."
        self.strong_loss = nn.BCEWithLogitsLoss()

        self.freq_warp = RandomResizeCrop((1, 1.0), time_scale=(1.0, 1.0))

        self.val_durations_df = pd.read_csv("resources/eval_durations.csv",
                                            sep=",", header=None, names=["filename", "duration"])
        # filename -> duration lookup built once so validation_step does not scan the whole table
        # for every file; keep="first" returns the first matching row for duplicated filenames
        self._duration_lookup = (
            self.val_durations_df.drop_duplicates(subset="filename", keep="first")
            .set_index("filename")["duration"]
            .to_dict()
        )

        # per-epoch accumulators: filled in validation_step, consumed and reset in on_validation_epoch_end
        self.val_predictions_strong = {}
        self.val_ground_truth = {}
        self.val_duration = {}
        self.val_loss = []

    def forward(self, batch):
        """Return strong (frame-level) logits for ``batch["audio"]``."""
        x = batch["audio"]
        mel = self.model.mel_forward(x)
        y_strong, _ = self.model(mel)
        return y_strong

    def get_optimizer(
            self, lr, adamw=False, weight_decay=0.01, betas=(0.9, 0.999)
    ):
        """Create an Adam or AdamW optimizer over all trainable parameters of ``self.model``.

        Parameters are split into two groups by dimensionality and weight decay is applied only to
        the >=2-dimensional ones, thus excluding biases and batch norms (an idea from NanoGPT).

        :param lr: learning rate for both parameter groups
        :param adamw: use AdamW instead of Adam; required when ``weight_decay > 0``
        :param weight_decay: decay applied to the >=2-dimensional parameter group
        :param betas: Adam(W) beta coefficients
        :return: the configured optimizer
        """
        params_leq1D = []
        params_geq2D = []
        for param in self.model.parameters():
            if not param.requires_grad:
                continue
            if param.ndimension() >= 2:
                params_geq2D.append(param)
            else:
                params_leq1D.append(param)

        param_groups = [
            # weight_decay must be set explicitly here: a group without the key inherits the
            # optimizer default, which is 0.01 for AdamW and would decay biases/norm parameters
            {'params': params_leq1D, 'lr': lr, 'weight_decay': 0.0},
            {'params': params_geq2D, 'lr': lr, 'weight_decay': weight_decay},
        ]

        if weight_decay > 0:
            assert adamw, "weight_decay > 0 requires adamw=True"
        if adamw:
            print(f"\nUsing adamw weight_decay={weight_decay}!\n")
            return torch.optim.AdamW(param_groups, lr=lr, betas=betas)
        return torch.optim.Adam(param_groups, lr=lr, betas=betas)

    def get_lr_scheduler(
            self,
            optimizer,
            num_training_steps,
            schedule_mode="cos",
            gamma: float = 0.999996,
            num_warmup_steps=20000,
            lr_end=2e-7,
    ):
        """Create the learning-rate scheduler selected by ``schedule_mode``.

        :param optimizer: optimizer to schedule
        :param num_training_steps: total number of optimizer steps ("cos"/"cosine" and "linear")
        :param schedule_mode: "exp", "cos"/"cosine" or "linear"
        :param gamma: per-step multiplicative decay for "exp"
        :param num_warmup_steps: warm-up steps for "cos"/"cosine" and "linear"
        :param lr_end: final learning rate for "linear"
        :raises RuntimeError: for an unknown ``schedule_mode``
        """
        if schedule_mode in {"exp"}:
            return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
        if schedule_mode in {"cosine", "cos"}:
            return transformers.get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
            )
        if schedule_mode in {"linear"}:
            print("Linear schedule!")
            return transformers.get_polynomial_decay_schedule_with_warmup(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                power=1.0,
                lr_end=lr_end,
            )
        raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")

    def configure_optimizers(self):
        """
        Define the optimizer and learning-rate scheduler in the form PyTorch Lightning expects.
        Lightning calls optimizer.step() itself; the scheduler is stepped once per optimizer step.
        :return: ([optimizer], [scheduler config dict])
        """
        optimizer = self.get_optimizer(self.config.max_lr, adamw=self.config.adamw,
                                       weight_decay=self.config.weight_decay)

        num_training_steps = self.trainer.estimated_stepping_batches

        # warm-up length comes from config.warmup_steps; configs without that attribute keep
        # get_lr_scheduler's default of 20000 steps
        scheduler = self.get_lr_scheduler(optimizer, num_training_steps,
                                          schedule_mode=self.config.schedule_mode,
                                          num_warmup_steps=getattr(self.config, "warmup_steps", 20000),
                                          lr_end=self.config.lr_end)
        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1
        }
        return [optimizer], [lr_scheduler_config]

    def training_step(self, train_batch, batch_idx):
        """
        Augment one batch, predict frame-level logits and combine the supervised and the optional
        distillation BCE losses, weighted by config.distillation_loss_weight.
        :param train_batch: one batch from the train dataloader; reads "audio", "strong" and,
            if present, "pseudo_strong"
        :param batch_idx: index of the batch (unused)
        :return: the combined loss tensor that Lightning backpropagates
        """
        x = train_batch["audio"]
        labels = train_batch['strong']
        if 'pseudo_strong' in train_batch:
            pseudo_labels = train_batch['pseudo_strong']
        else:
            # create dummy pseudo labels
            pseudo_labels = torch.zeros_like(labels)
            assert self.config.distillation_loss_weight == 0

        mel = self.model.mel_forward(x)

        # time rolling
        if self.config.frame_shift_range > 0:
            mel, labels, pseudo_labels = frame_shift(
                mel,
                labels,
                pseudo_labels=pseudo_labels,
                net_pooling=self.encoder.net_pooling,
                shift_range=self.config.frame_shift_range
            )

        # mixup
        if self.config.mixup_p > random.random():
            mel, labels, pseudo_labels = mixup(
                mel,
                targets=labels,
                pseudo_strong=pseudo_labels
            )

        # mixstyle
        if self.config.mixstyle_p > random.random():
            mel = mixstyle(
                mel
            )

        # time masking
        if self.config.max_time_mask_size > 0:
            mel, labels, pseudo_labels = time_mask(
                mel,
                labels,
                pseudo_labels=pseudo_labels,
                net_pooling=self.encoder.net_pooling,
                max_mask_ratio=self.config.max_time_mask_size
            )

        # frequency masking
        if self.config.filter_augment_p > random.random():
            mel, _ = filter_augmentation(
                mel
            )

        # frequency warping (freq_warp is applied with dim 1 squeezed out, then restored)
        if self.config.freq_warp_p > random.random():
            mel = mel.squeeze(1)
            mel = self.freq_warp(mel)
            mel = mel.unsqueeze(1)

        # forward through network; use strong head
        y_hat_strong, _ = self.model(mel)

        strong_supervised_loss = self.strong_loss(y_hat_strong, labels)

        if self.config.distillation_loss_weight > 0:
            strong_distillation_loss = self.strong_loss(y_hat_strong, pseudo_labels)
        else:
            strong_distillation_loss = torch.tensor(0., device=y_hat_strong.device, dtype=y_hat_strong.dtype)

        loss = self.config.distillation_loss_weight * strong_distillation_loss \
            + (1 - self.config.distillation_loss_weight) * strong_supervised_loss

        # logging
        self.log('epoch', self.current_epoch)
        for i, param_group in enumerate(self.trainer.optimizers[0].param_groups):
            self.log(f'trainer/lr_optimizer_{i}', param_group['lr'])
        self.log("train/loss", loss.detach().cpu(), prog_bar=True)
        self.log("train/strong_supervised_loss", strong_supervised_loss.detach().cpu())
        self.log("train/strong_distillation_loss", strong_distillation_loss.detach().cpu())

        return loss

    def validation_step(self, val_batch, batch_idx):
        """Collect ground truth, durations, loss and decoded predictions for one validation batch.

        Ground-truth strings are parsed as "++"-separated events, each "onset;;offset;;label".
        """
        # bring ground truth into shape needed for evaluation
        for f, gt_string in zip(val_batch["filename"], val_batch["gt_string"]):
            f = f[:-len(".mp3")]
            events = [e.split(";;") for e in gt_string.split("++")]
            self.val_ground_truth[f] = [(float(e[0]), float(e[1]), e[2]) for e in events]
            self.val_duration[f] = self._duration_lookup[f]

        y_hat_strong = self(val_batch)
        y_strong = val_batch["strong"]
        loss = self.strong_loss(y_hat_strong, y_strong)
        self.val_loss.append(loss.cpu())

        scores_raw, scores_postprocessed, prediction_dfs = batched_decode_preds(
            y_hat_strong.float(),
            val_batch['filename'],
            self.encoder,
            median_filter=self.config.median_window
        )

        self.val_predictions_strong.update(
            scores_postprocessed
        )

    def on_validation_epoch_end(self):
        """Compute and log PSDS1 and segment-based pAUROC over the collected validation outputs.

        Evaluation is restricted to classes occurring both in the ground truth and in the encoder's
        labels; files left without events are dropped. All per-epoch accumulators are reset.
        """
        gt_unique_events = {e[2] for events in self.val_ground_truth.values() for e in events}
        train_unique_events = set(self.encoder.labels)
        # evaluate on all classes that are in both train and test sets (407 classes)
        class_intersection = gt_unique_events.intersection(train_unique_events)
        expected_classes = len(set(as_strong_train_classes).intersection(as_strong_eval_classes))
        assert len(class_intersection) == expected_classes == 407, \
            f"Intersection unique events. Expected: {expected_classes}," \
            f" Actual: {len(class_intersection)}"

        # filter ground truth according to class_intersection
        val_ground_truth = {fid: [event for event in events if event[2] in class_intersection]
                            for fid, events in self.val_ground_truth.items()}
        # drop audios without events - aligned with DESED evaluation procedure
        val_ground_truth = {fid: events for fid, events in val_ground_truth.items() if events}

        # keep only corresponding audio durations
        audio_durations = {fid: self.val_duration[fid] for fid in val_ground_truth}

        # filter files in predictions
        as_strong_preds = {fid: self.val_predictions_strong[fid] for fid in val_ground_truth}

        # filter classes in predictions; the stored DataFrames are modified in place and are
        # cleared at the end of this method
        unused_classes = list(train_unique_events.difference(class_intersection))
        for df in as_strong_preds.values():
            df.drop(columns=unused_classes, inplace=True)

        segment_based_pauroc = sed_scores_eval.segment_based.auroc(
            as_strong_preds,
            val_ground_truth,
            audio_durations,
            max_fpr=0.1,
            segment_length=1.0,
            num_jobs=1
        )

        psds1 = sed_scores_eval.intersection_based.psds(
            as_strong_preds,
            val_ground_truth,
            audio_durations,
            dtc_threshold=0.7,
            gtc_threshold=0.7,
            cttc_threshold=None,
            alpha_ct=0,
            alpha_st=1,
            num_jobs=1
        )

        # "val/psds1_macro_averaged" is psds1 without penalization for performance
        # variations across classes
        # the loss is moved to self.device rather than .cuda() so CPU-only runs also work
        logs = {"val/loss": torch.as_tensor(self.val_loss).mean().to(self.device),
                "val/psds1": psds1[0],
                "val/psds1_macro_averaged": np.mean(list(psds1[1].values())),
                "val/pauroc": segment_based_pauroc[0]['mean'],
                }

        self.log_dict(logs, sync_dist=False)

        self.val_predictions_strong = {}
        self.val_ground_truth = {}
        self.val_duration = {}
        self.val_loss = []
def train(config):
    """Train a SED model on the temporally-strong portion of AudioSet, logging to Weights & Biases.

    :param config: parsed CLI namespace (see the ``__main__`` block of this file)
    """
    run_logger = WandbLogger(
        project="PTSED",
        notes="Pre-Training Transformers for Sound Event Detection on AudioSet Strong.",
        tags=["AudioSet Strong", "Sound Event Detection", "Pseudo Labels", "Knowledge Disitillation"],
        config=config,
        name=config.experiment_name,
    )

    # encoder manages encoding and decoding of model predictions
    label_encoder = ManyHotEncoder(as_strong_train_classes)
    training_data = get_training_dataset(
        label_encoder,
        wavmix_p=config.wavmix_p,
        pseudo_labels_file=config.pseudo_labels_file,
    )
    validation_data = get_eval_dataset(label_encoder)

    # per-sample weights drive a weighted sampler that yields config.epoch_len samples per epoch
    weights = (
        get_temporal_count_balanced_sample_weights(training_data, save_folder="resources")
        if config.use_balanced_sampler
        else get_uniform_sample_weights(training_data)
    )
    sampler = get_weighted_sampler(weights, epoch_len=config.epoch_len)

    shared_loader_kwargs = dict(
        worker_init_fn=worker_init_fn,
        num_workers=config.num_workers,
        batch_size=config.batch_size,
    )
    train_loader = DataLoader(dataset=training_data, sampler=sampler, shuffle=False, **shared_loader_kwargs)
    val_loader = DataLoader(dataset=validation_data, **shared_loader_kwargs)

    module = PLModule(config, label_encoder)

    trainer = pl.Trainer(
        max_epochs=config.n_epochs,
        logger=run_logger,
        accelerator='auto',
        devices=config.num_devices,
        precision=config.precision,
        num_sanity_val_steps=0,
        check_val_every_n_epoch=config.check_val_every_n_epoch,
    )
    trainer.fit(module, train_loader, val_loader)
    wandb.finish()
def evaluate(config):
    """Evaluate a (pre-trained) model on the AudioSet Strong eval split without training.

    :param config: parsed CLI namespace (see the ``__main__`` block of this file)
    """
    # encoder manages encoding and decoding of model predictions
    label_encoder = ManyHotEncoder(as_strong_train_classes)
    eval_loader = DataLoader(
        dataset=get_eval_dataset(label_encoder),
        worker_init_fn=worker_init_fn,
        num_workers=config.num_workers,
        batch_size=config.batch_size,
    )

    module = PLModule(config, label_encoder)

    # same trainer settings as train(), minus the W&B logger
    trainer_kwargs = dict(
        max_epochs=config.n_epochs,
        accelerator='auto',
        devices=config.num_devices,
        precision=config.precision,
        num_sanity_val_steps=0,
        check_val_every_n_epoch=config.check_val_every_n_epoch,
    )
    pl.Trainer(**trainer_kwargs).validate(module, eval_loader)
if __name__ == '__main__':
    # Command-line entry point: parse the configuration, then train (default) or only evaluate.
    parser = argparse.ArgumentParser(description='Configuration Parser. ')

    # general
    parser.add_argument('--experiment_name', type=str, default="AudioSet_Strong")
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--num_workers', type=int, default=16)
    parser.add_argument('--num_devices', type=int, default=1)
    parser.add_argument('--precision', type=int, default=16)
    parser.add_argument('--evaluate', action='store_true', default=False)
    parser.add_argument('--check_val_every_n_epoch', type=int, default=5)

    # model
    parser.add_argument('--model_name', type=str,
                        choices=["ATST-F", "BEATs", "fpasst", "M2D", "ASIT"] +
                                [f"frame_mn{width}" for width in ["06", "10"]],
                        default="ATST-F")  # used also for training
    # "scratch" = no pretraining
    # "ssl" = SSL pre-trained
    # "weak" = AudioSet Weak pre-trained
    # "strong" = AudioSet Strong pre-trained
    parser.add_argument('--pretrained', type=str, choices=["scratch", "ssl", "weak", "strong"],
                        default="weak")
    parser.add_argument('--seq_model_type', type=str, choices=["rnn"],
                        default=None)

    # training
    parser.add_argument('--n_epochs', type=int, default=30)
    parser.add_argument('--use_balanced_sampler', action='store_true', default=False)
    parser.add_argument('--distillation_loss_weight', type=float, default=0.0)
    parser.add_argument('--epoch_len', type=int, default=100000)
    parser.add_argument('--median_window', type=int, default=9)

    # augmentation
    parser.add_argument('--wavmix_p', type=float, default=0.8)
    parser.add_argument('--freq_warp_p', type=float, default=0.8)
    parser.add_argument('--filter_augment_p', type=float, default=0.8)
    parser.add_argument('--frame_shift_range', type=float, default=0.125)  # in seconds
    parser.add_argument('--mixup_p', type=float, default=0.3)
    parser.add_argument('--mixstyle_p', type=float, default=0.3)
    # passed to time_mask as max_mask_ratio; 0 disables time masking
    parser.add_argument('--max_time_mask_size', type=float, default=0.0)

    # optimizer
    parser.add_argument('--adamw', action='store_true', default=False)
    parser.add_argument('--weight_decay', type=float, default=0.0)

    # lr schedule
    # restricted to the modes PLModule.get_lr_scheduler implements, so a typo fails here at parse
    # time instead of after the datasets are built and the W&B logger is created
    parser.add_argument('--schedule_mode', type=str, choices=["cos", "cosine", "exp", "linear"],
                        default="cos")
    parser.add_argument('--max_lr', type=float, default=7e-5)
    parser.add_argument('--lr_end', type=float, default=2e-7)
    parser.add_argument('--warmup_steps', type=int, default=5000)

    # knowledge distillation
    parser.add_argument('--pseudo_labels_file', type=str,
                        default=None)

    args = parser.parse_args()
    if args.evaluate:
        evaluate(args)
    else:
        train(args)