import argparse
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import transformers
import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import sed_scores_eval

from helpers.decode import batched_decode_preds
from helpers.encode import ManyHotEncoder
from helpers.augment import frame_shift, time_mask, mixup, filter_augmentation, mixstyle, RandomResizeCrop
from helpers.utils import worker_init_fn
from models.atstframe.ATSTF_wrapper import ATSTWrapper
from models.beats.BEATs_wrapper import BEATsWrapper
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
from models.m2d.M2D_wrapper import M2DWrapper
from models.asit.ASIT_wrapper import ASiTWrapper
from models.prediction_wrapper import PredictionsWrapper
from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
from models.frame_mn.utils import NAME_TO_WIDTH
from data_util.audioset_strong import get_training_dataset, get_eval_dataset
from data_util.audioset_strong import get_temporal_count_balanced_sample_weights, get_uniform_sample_weights, \
    get_weighted_sampler
from data_util.audioset_classes import as_strong_train_classes, as_strong_eval_classes


class PLModule(pl.LightningModule):
    def __init__(self, config, encoder):
        super().__init__()
        self.config = config
        self.encoder = encoder

        # map the pretrained choice to the checkpoint name suffix
        if config.pretrained == "scratch":
            checkpoint = None
        elif config.pretrained == "ssl":
            checkpoint = "ssl"
        elif config.pretrained == "weak":
            checkpoint = "weak"
        elif config.pretrained == "strong":
            checkpoint = "strong_1"
        else:
            raise ValueError(f"Unknown pretrained checkpoint: {config.pretrained}")

        # load transformer model
        if config.model_name == "BEATs":
            beats = BEATsWrapper()
            model = PredictionsWrapper(beats, checkpoint=f"BEATs_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type)
        elif config.model_name == "ATST-F":
            atst = ATSTWrapper()
            model = PredictionsWrapper(atst, checkpoint=f"ATST-F_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type)
        elif config.model_name == "fpasst":
            fpasst = FPaSSTWrapper()
            model = PredictionsWrapper(fpasst, checkpoint=f"fpasst_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type)
        elif config.model_name == "M2D":
            m2d = M2DWrapper()
            model = PredictionsWrapper(m2d, checkpoint=f"M2D_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type,
                                       embed_dim=m2d.m2d.cfg.feature_d)
        elif config.model_name == "ASIT":
            asit = ASiTWrapper()
            model = PredictionsWrapper(asit, checkpoint=f"ASIT_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type)
        elif config.model_name.startswith("frame_mn"):
            width = NAME_TO_WIDTH(config.model_name)
            frame_mn = FrameMNWrapper(width)
            # infer the embedding dimension from the classifier width stored in the state dict
            embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
            model = PredictionsWrapper(frame_mn, checkpoint=f"{config.model_name}_strong_1", embed_dim=embed_dim)
        else:
            raise NotImplementedError(f"Model {config.model_name} not (yet) implemented")
        self.model = model

        # prepare ingredients for knowledge distillation
        assert 0 <= config.distillation_loss_weight <= 1, "Lambda for Knowledge Distillation must be between 0 and 1."
        self.strong_loss = nn.BCEWithLogitsLoss()

        self.freq_warp = RandomResizeCrop((1, 1.0), time_scale=(1.0, 1.0))

        self.val_durations_df = pd.read_csv("resources/eval_durations.csv",
                                            sep=",", header=None, names=["filename", "duration"])
        self.val_predictions_strong = {}
        self.val_ground_truth = {}
        self.val_duration = {}
        self.val_loss = []

    def forward(self, batch):
        x = batch["audio"]
        mel = self.model.mel_forward(x)
        y_strong, _ = self.model(mel)
        return y_strong
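
    # Note on usage: `y_strong` holds frame-level logits, as implied by BCEWithLogitsLoss
    # and `batched_decode_preds` below. A minimal, hypothetical inference sketch (the
    # sampling rate and input length are assumptions, not verified against the repo):
    #
    #   module = PLModule(config, ManyHotEncoder(as_strong_train_classes))
    #   batch = {"audio": torch.randn(1, 10 * 16000)}  # 10 s of mono audio at 16 kHz (assumed)
    #   with torch.no_grad():
    #       frame_probs = module(batch).sigmoid()  # per-frame, per-class probabilities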

    def get_optimizer(
            self, lr, adamw=False, weight_decay=0.01, betas=(0.9, 0.999)
    ):
        # we split the parameters into two groups by dimensionality: weight decay is applied
        # only to the >=2-dimensional parameters, thereby excluding biases and norm
        # parameters, an idea from nanoGPT
        params_leq1D = []
        params_geq2D = []
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                if param.ndimension() >= 2:
                    params_geq2D.append(param)
                else:
                    params_leq1D.append(param)
        param_groups = [
            {'params': params_leq1D, 'lr': lr},
            {'params': params_geq2D, 'lr': lr, 'weight_decay': weight_decay},
        ]
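
        # For illustration: a Linear layer's weight (2-D) lands in params_geq2D and gets
        # weight decay, while its bias (1-D) and any norm scales do not:
        #
        #   lin = nn.Linear(4, 8)
        #   lin.weight.ndimension()  # 2 -> decayed
        #   lin.bias.ndimension()    # 1 -> not decayed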
        if weight_decay > 0:
            assert adamw
        assert len(param_groups) > 0
        if adamw:
            print(f"\nUsing adamw weight_decay={weight_decay}!\n")
            return torch.optim.AdamW(param_groups, lr=lr, betas=betas)
        return torch.optim.Adam(param_groups, lr=lr, betas=betas)

    def get_lr_scheduler(
            self,
            optimizer,
            num_training_steps,
            schedule_mode="cos",
            gamma: float = 0.999996,
            num_warmup_steps=20000,
            lr_end=2e-7,
    ):
        if schedule_mode in {"exp"}:
            return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
        if schedule_mode in {"cosine", "cos"}:
            return transformers.get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
            )
        if schedule_mode in {"linear"}:
            print("Linear schedule!")
            return transformers.get_polynomial_decay_schedule_with_warmup(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                power=1.0,
                lr_end=lr_end,
            )
        raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")

    def configure_optimizers(self):
        """
        This is the way PyTorch Lightning requires optimizers and learning rate schedulers to be defined.
        The specified items are used automatically in the optimization loop (no need to call optimizer.step() yourself).
        :return: optimizers and lr scheduler configs in the format PyTorch Lightning expects
        """
        optimizer = self.get_optimizer(self.config.max_lr, adamw=self.config.adamw,
                                       weight_decay=self.config.weight_decay)
        num_training_steps = self.trainer.estimated_stepping_batches
        scheduler = self.get_lr_scheduler(optimizer, num_training_steps,
                                          schedule_mode=self.config.schedule_mode,
                                          num_warmup_steps=self.config.warmup_steps,
                                          lr_end=self.config.lr_end)

        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1
        }
        return [optimizer], [lr_scheduler_config]

    def training_step(self, train_batch, batch_idx):
        """
        :param train_batch: contains one batch from the train dataloader
        :param batch_idx
        :return: a dict containing at least the loss that is used to update model parameters; it can also
            contain other items to be processed in 'training_epoch_end' for logging metrics other than the loss
        """
        x = train_batch["audio"]
        labels = train_batch['strong']
        if 'pseudo_strong' in train_batch:
            pseudo_labels = train_batch['pseudo_strong']
        else:
            # create dummy pseudo labels
            pseudo_labels = torch.zeros_like(labels)
            assert self.config.distillation_loss_weight == 0

        mel = self.model.mel_forward(x)
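
        # Each augmentation below fires stochastically: `p > random.random()` is true with
        # probability p, so e.g. mixup_p=0.3 applies mixup to roughly 30% of the batches.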
        # time rolling
        if self.config.frame_shift_range > 0:
            mel, labels, pseudo_labels = frame_shift(
                mel,
                labels,
                pseudo_labels=pseudo_labels,
                net_pooling=self.encoder.net_pooling,
                shift_range=self.config.frame_shift_range
            )

        # mixup
        if self.config.mixup_p > random.random():
            mel, labels, pseudo_labels = mixup(
                mel,
                targets=labels,
                pseudo_strong=pseudo_labels
            )

        # mixstyle
        if self.config.mixstyle_p > random.random():
            mel = mixstyle(
                mel
            )

        # time masking
        if self.config.max_time_mask_size > 0:
            mel, labels, pseudo_labels = time_mask(
                mel,
                labels,
                pseudo_labels=pseudo_labels,
                net_pooling=self.encoder.net_pooling,
                max_mask_ratio=self.config.max_time_mask_size
            )

        # filter augmentation
        if self.config.filter_augment_p > random.random():
            mel, _ = filter_augmentation(
                mel
            )

        # frequency warping
        if self.config.freq_warp_p > random.random():
            mel = mel.squeeze(1)
            mel = self.freq_warp(mel)
            mel = mel.unsqueeze(1)

        # forward through network; use strong head
        y_hat_strong, _ = self.model(mel)

        strong_supervised_loss = self.strong_loss(y_hat_strong, labels)
        if self.config.distillation_loss_weight > 0:
            strong_distillation_loss = self.strong_loss(y_hat_strong, pseudo_labels)
        else:
            strong_distillation_loss = torch.tensor(0., device=y_hat_strong.device, dtype=y_hat_strong.dtype)
        loss = self.config.distillation_loss_weight * strong_distillation_loss \
            + (1 - self.config.distillation_loss_weight) * strong_supervised_loss
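
        # i.e., loss = lambda * L_distill + (1 - lambda) * L_supervised with
        # lambda = distillation_loss_weight: lambda=0 trains purely on the AudioSet Strong
        # labels, lambda=1 purely on the pseudo labels.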
        # logging
        self.log('epoch', self.current_epoch)
        for i, param_group in enumerate(self.trainer.optimizers[0].param_groups):
            self.log(f'trainer/lr_optimizer_{i}', param_group['lr'])
        self.log("train/loss", loss.detach().cpu(), prog_bar=True)
        self.log("train/strong_supervised_loss", strong_supervised_loss.detach().cpu())
        self.log("train/strong_distillation_loss", strong_distillation_loss.detach().cpu())
        return loss

    def validation_step(self, val_batch, batch_idx):
        # bring ground truth into the shape needed for evaluation
        for f, gt_string in zip(val_batch["filename"], val_batch["gt_string"]):
            f = f[:-len(".mp3")]
            events = [e.split(";;") for e in gt_string.split("++")]
            self.val_ground_truth[f] = [(float(e[0]), float(e[1]), e[2]) for e in events]
            self.val_duration[f] = self.val_durations_df[self.val_durations_df["filename"] == f]["duration"].values[0]
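
        # As the parsing above implies, the serialized ground truth is a "++"-separated list
        # of events with ";;"-separated (onset, offset, label) fields, e.g. (made-up values):
        #   "0.0;;1.2;;Speech++3.4;;5.0;;Dog" -> [(0.0, 1.2, "Speech"), (3.4, 5.0, "Dog")]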
        y_hat_strong = self(val_batch)
        y_strong = val_batch["strong"]
        loss = self.strong_loss(y_hat_strong, y_strong)
        self.val_loss.append(loss.cpu())

        scores_raw, scores_postprocessed, prediction_dfs = batched_decode_preds(
            y_hat_strong.float(),
            val_batch['filename'],
            self.encoder,
            median_filter=self.config.median_window
        )
        self.val_predictions_strong.update(
            scores_postprocessed
        )

    def on_validation_epoch_end(self):
        gt_unique_events = {e[2] for f, events in self.val_ground_truth.items() for e in events}
        train_unique_events = set(self.encoder.labels)
        # evaluate on all classes that are in both the train and test sets (407 classes)
        class_intersection = gt_unique_events.intersection(train_unique_events)
        assert len(class_intersection) == len(set(as_strong_train_classes).intersection(as_strong_eval_classes)) == 407, \
            f"Intersection unique events. Expected: {len(set(as_strong_train_classes).intersection(as_strong_eval_classes))}," \
            f" Actual: {len(class_intersection)}"

        # filter ground truth according to class_intersection
        val_ground_truth = {fid: [event for event in self.val_ground_truth[fid] if event[2] in class_intersection]
                            for fid in self.val_ground_truth}
        # drop audios without events - aligned with the DESED evaluation procedure
        val_ground_truth = {fid: events for fid, events in val_ground_truth.items() if len(events) > 0}

        # keep only the corresponding audio durations
        audio_durations = {
            fid: self.val_duration[fid] for fid in val_ground_truth.keys()
        }

        # filter files in predictions
        as_strong_preds = {
            fid: self.val_predictions_strong[fid] for fid in val_ground_truth.keys()
        }

        # filter classes in predictions
        unused_classes = list(set(self.encoder.labels).difference(class_intersection))
        for f, df in as_strong_preds.items():
            df.drop(columns=unused_classes, inplace=True)

        segment_based_pauroc = sed_scores_eval.segment_based.auroc(
            as_strong_preds,
            val_ground_truth,
            audio_durations,
            max_fpr=0.1,
            segment_length=1.0,
            num_jobs=1
        )

        psds1 = sed_scores_eval.intersection_based.psds(
            as_strong_preds,
            val_ground_truth,
            audio_durations,
            dtc_threshold=0.7,
            gtc_threshold=0.7,
            cttc_threshold=None,
            alpha_ct=0,
            alpha_st=1,
            num_jobs=1
        )
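
        # These parameters mirror PSDS scenario 1 from the DCASE Task 4 / DESED setup:
        # strict temporal collars (dtc/gtc = 0.7), no cross-trigger penalty (cttc=None,
        # alpha_ct=0), and alpha_st=1 penalizing performance variation across classes.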
# "val/psds1_macro_averaged" is psds1 without penalization for performance | |
# variations across classes | |
logs = {"val/loss": torch.as_tensor(self.val_loss).mean().cuda(), | |
"val/psds1": psds1[0], | |
"val/psds1_macro_averaged": np.array([v for k, v in psds1[1].items()]).mean(), | |
"val/pauroc": segment_based_pauroc[0]['mean'], | |
} | |
self.log_dict(logs, sync_dist=False) | |
self.val_predictions_strong = {} | |
self.val_ground_truth = {} | |
self.val_duration = {} | |
self.val_loss = [] | |


def train(config):
    # Train models on the temporally-strong portion of AudioSet.
    # Logging is done using wandb.
    wandb_logger = WandbLogger(
        project="PTSED",
        notes="Pre-Training Transformers for Sound Event Detection on AudioSet Strong.",
        tags=["AudioSet Strong", "Sound Event Detection", "Pseudo Labels", "Knowledge Distillation"],
        config=config,
        name=config.experiment_name
    )

    # encoder manages encoding and decoding of model predictions
    encoder = ManyHotEncoder(as_strong_train_classes)

    train_set = get_training_dataset(encoder, wavmix_p=config.wavmix_p,
                                     pseudo_labels_file=config.pseudo_labels_file)
    eval_set = get_eval_dataset(encoder)

    if config.use_balanced_sampler:
        sample_weights = get_temporal_count_balanced_sample_weights(train_set, save_folder="resources")
    else:
        sample_weights = get_uniform_sample_weights(train_set)
    train_sampler = get_weighted_sampler(sample_weights, epoch_len=config.epoch_len)
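
    # The weighted sampler draws `epoch_len` clips per epoch according to the sample weights;
    # presumably it wraps torch.utils.data.WeightedRandomSampler (sampling with replacement),
    # which is why the train DataLoader below sets shuffle=False.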
    # train dataloader
    train_dl = DataLoader(dataset=train_set,
                          sampler=train_sampler,
                          worker_init_fn=worker_init_fn,
                          num_workers=config.num_workers,
                          batch_size=config.batch_size,
                          shuffle=False)

    # eval dataloader
    eval_dl = DataLoader(dataset=eval_set,
                         worker_init_fn=worker_init_fn,
                         num_workers=config.num_workers,
                         batch_size=config.batch_size)

    # create the PyTorch Lightning module
    pl_module = PLModule(config, encoder)

    # create the PyTorch Lightning trainer by specifying the number of epochs to train, the logger,
    # which kind of device(s) to train on, and possible callbacks
    trainer = pl.Trainer(max_epochs=config.n_epochs,
                         logger=wandb_logger,
                         accelerator='auto',
                         devices=config.num_devices,
                         precision=config.precision,
                         num_sanity_val_steps=0,
                         check_val_every_n_epoch=config.check_val_every_n_epoch
                         )

    # start training and validation for the specified number of epochs
    trainer.fit(pl_module, train_dl, eval_dl)
    wandb.finish()


def evaluate(config):
    # evaluation-only mode for pre-trained models

    # encoder manages encoding and decoding of model predictions
    encoder = ManyHotEncoder(as_strong_train_classes)
    eval_set = get_eval_dataset(encoder)

    # eval dataloader
    eval_dl = DataLoader(dataset=eval_set,
                         worker_init_fn=worker_init_fn,
                         num_workers=config.num_workers,
                         batch_size=config.batch_size)

    # create the PyTorch Lightning module
    pl_module = PLModule(config, encoder)

    # create the PyTorch Lightning trainer; max_epochs is irrelevant here since we only run validation
    trainer = pl.Trainer(max_epochs=config.n_epochs,
                         accelerator='auto',
                         devices=config.num_devices,
                         precision=config.precision,
                         num_sanity_val_steps=0,
                         check_val_every_n_epoch=config.check_val_every_n_epoch)

    # start evaluation
    trainer.validate(pl_module, eval_dl)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Configuration Parser.')

    # general
    parser.add_argument('--experiment_name', type=str, default="AudioSet_Strong")
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--num_workers', type=int, default=16)
    parser.add_argument('--num_devices', type=int, default=1)
    parser.add_argument('--precision', type=int, default=16)
    parser.add_argument('--evaluate', action='store_true', default=False)
    parser.add_argument('--check_val_every_n_epoch', type=int, default=5)

    # model
    parser.add_argument('--model_name', type=str,
                        choices=["ATST-F", "BEATs", "fpasst", "M2D", "ASIT"] +
                                [f"frame_mn{width}" for width in ["06", "10"]],
                        default="ATST-F")  # used also for training

    # "scratch" = no pretraining
    # "ssl" = SSL pre-trained
    # "weak" = AudioSet Weak pre-trained
    # "strong" = AudioSet Strong pre-trained
    parser.add_argument('--pretrained', type=str, choices=["scratch", "ssl", "weak", "strong"],
                        default="weak")
    parser.add_argument('--seq_model_type', type=str, choices=["rnn"],
                        default=None)

    # training
    parser.add_argument('--n_epochs', type=int, default=30)
    parser.add_argument('--use_balanced_sampler', action='store_true', default=False)
    parser.add_argument('--distillation_loss_weight', type=float, default=0.0)
    parser.add_argument('--epoch_len', type=int, default=100000)
    parser.add_argument('--median_window', type=int, default=9)

    # augmentation
    parser.add_argument('--wavmix_p', type=float, default=0.8)
    parser.add_argument('--freq_warp_p', type=float, default=0.8)
    parser.add_argument('--filter_augment_p', type=float, default=0.8)
    parser.add_argument('--frame_shift_range', type=float, default=0.125)  # in seconds
    parser.add_argument('--mixup_p', type=float, default=0.3)
    parser.add_argument('--mixstyle_p', type=float, default=0.3)
    parser.add_argument('--max_time_mask_size', type=float, default=0.0)

    # optimizer
    parser.add_argument('--adamw', action='store_true', default=False)
    parser.add_argument('--weight_decay', type=float, default=0.0)

    # lr schedule
    parser.add_argument('--schedule_mode', type=str, default="cos")
    parser.add_argument('--max_lr', type=float, default=7e-5)
    parser.add_argument('--lr_end', type=float, default=2e-7)
    parser.add_argument('--warmup_steps', type=int, default=5000)

    # knowledge distillation
    parser.add_argument('--pseudo_labels_file', type=str,
                        default=None)

    args = parser.parse_args()

    if args.evaluate:
        evaluate(args)
    else:
        train(args)
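
# Example invocations (hypothetical script name; adjust to the actual file name):
#   training:   python ex_audioset_strong.py --model_name ATST-F --pretrained weak
#   evaluation: python ex_audioset_strong.py --evaluate --model_name ATST-F --pretrained strong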