"""Fine-tune a pre-trained audio transformer for sound event detection on DCASE 2016 Task 2."""
import argparse
import random
from pathlib import Path
from typing import Dict

import pytorch_lightning as pl
import torch
import torch.nn as nn
import transformers
from einops import rearrange
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

import wandb
from data_util.dcase2016task2 import (get_training_dataset, get_validation_dataset, get_test_dataset,
                                      label_vocab_nlabels, label_vocab_as_dict)
from helpers.augment import frame_shift, time_mask, mixup, filter_augmentation, mixstyle, RandomResizeCrop
from helpers.score import get_events_for_all_files, combine_target_events, EventBasedScore, SegmentBasedScore
from helpers.utils import worker_init_fn
from models.asit.ASIT_wrapper import ASiTWrapper
from models.atstframe.ATSTF_wrapper import ATSTWrapper
from models.beats.BEATs_wrapper import BEATsWrapper
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
from models.m2d.M2D_wrapper import M2DWrapper
from models.prediction_wrapper import PredictionsWrapper

# Maps the --pretrained choice to the checkpoint suffix handed to PredictionsWrapper.
# The full checkpoint name is "<model_name>_<suffix>"; None means no checkpoint (train from scratch).
_PRETRAINED_TO_CHECKPOINT = {
    "scratch": None,
    "ssl": "ssl",
    "weak": "weak",
    "strong": "strong_1",
}

# Backbone wrapper classes keyed by --model_name. Classes (not instances), so only the selected one is built.
_BACKBONES = {
    "BEATs": BEATsWrapper,
    "ATST-F": ATSTWrapper,
    "fpasst": FPaSSTWrapper,
    "M2D": M2DWrapper,
    "ASIT": ASiTWrapper,
}

# Default number of warm-up steps, used when the config does not provide `warmup_steps`.
_DEFAULT_WARMUP_STEPS = 4000


class PLModule(pl.LightningModule):
    """Lightning module that fine-tunes a backbone + strong (frame-level) prediction head.

    Trained with BCEWithLogitsLoss on frame labels; validation/test compute event-based onset F-measures
    (200 ms and 50 ms collars) and a 1 s segment-based error rate.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        if config.pretrained not in _PRETRAINED_TO_CHECKPOINT:
            raise ValueError(f"Unknown pretrained checkpoint: {config.pretrained}")
        checkpoint = _PRETRAINED_TO_CHECKPOINT[config.pretrained]

        # load transformer model
        if config.model_name not in _BACKBONES:
            raise NotImplementedError(f"Model {config.model_name} not (yet) implemented")
        backbone = _BACKBONES[config.model_name]()

        wrapper_kwargs = {}
        if config.model_name == "M2D":
            # Only M2D passes its embedding size explicitly, read from the backbone's own config.
            wrapper_kwargs["embed_dim"] = backbone.m2d.cfg.feature_d

        self.model = PredictionsWrapper(
            backbone,
            checkpoint=f"{config.model_name}_{checkpoint}" if checkpoint else None,
            seq_model_type=config.seq_model_type,
            n_classes_strong=self.config.n_classes,
            **wrapper_kwargs,
        )

        self.strong_loss = nn.BCEWithLogitsLoss()
        self.freq_warp = RandomResizeCrop((1, 1.0), time_scale=(1.0, 1.0))

        task_path = Path(self.config.task_path)
        label_vocab, _ = label_vocab_nlabels(task_path)
        self.label_to_idx = label_vocab_as_dict(label_vocab, key="label", value="idx")
        self.idx_to_label: Dict[int, str] = {idx: label for (label, idx) in self.label_to_idx.items()}

        self.event_onset_200ms_fms = EventBasedScore(
            label_to_idx=self.label_to_idx,
            name="event_onset_200ms_fms",
            scores=("f_measure", "precision", "recall"),
            params={"evaluate_onset": True, "evaluate_offset": False, "t_collar": 0.2},
        )
        self.event_onset_50ms_fms = EventBasedScore(
            label_to_idx=self.label_to_idx,
            name="event_onset_50ms_fms",
            scores=("f_measure", "precision", "recall"),
            params={"evaluate_onset": True, "evaluate_offset": False, "t_collar": 0.05},
        )
        self.segment_1s_er = SegmentBasedScore(
            label_to_idx=self.label_to_idx,
            name="segment_1s_er",
            scores=("error_rate",),
            params={"time_resolution": 1.0},
            maximize=False,
        )

        # a single post-processing configuration (values in ms)
        self.postprocessing_grid = {
            "median_filter_ms": [250],
            "min_duration": [125],
        }

        # buffers filled by _score_step and consumed/cleared by _score_epoch_end
        self.preds, self.tgts, self.fnames, self.timestamps = [], [], [], []

    def forward(self, audio):
        """Compute strong (frame-level) outputs of the strong head for raw audio."""
        mel = self.model.mel_forward(audio)
        y_strong, _ = self.model(mel)
        return y_strong

    def separate_params(self):
        """Split trainable parameters into (backbone, sequence model, prediction heads).

        If the wrapped model provides its own grouping (`has_separate_params()`), the backbone
        parameters are replaced by `self.model.separate_params()` so layer-wise lr decay can be applied.

        :raises ValueError: on an unexpected parameter name, or if lr_decay != 1.0 but the model
            cannot provide layer-wise parameter groups.
        """
        pt_params = []
        seq_params = []
        head_params = []

        for name, p in self.named_parameters():
            name = name[len("model."):]
            if name.startswith('model'):
                # the transformer
                pt_params.append(p)
            elif name.startswith('seq_model'):
                # the optional sequence model
                seq_params.append(p)
            elif name.startswith(('strong_head', 'weak_head')):
                # the prediction head
                head_params.append(p)
            else:
                raise ValueError(f"Unexpected key in model: {name}")

        if self.model.has_separate_params():
            # split parameters into groups according to their depth in the network
            # based on this, we can apply layer-wise learning rate decay
            pt_params = self.model.separate_params()
        else:
            if self.config.lr_decay != 1.0:
                raise ValueError(f"Model has no separate_params function. Can't apply layer-wise lr decay, but "
                                 f"learning rate decay is set to {self.config.lr_decay}.")

        return pt_params, seq_params, head_params

    def get_optimizer(
            self, lr, lr_decay=1.0, transformer_lr=None, transformer_frozen=False,
            adamw=False, weight_decay=0.01, betas=(0.9, 0.999)
    ):
        """Build an Adam/AdamW optimizer with per-component learning rates.

        :param lr: learning rate of the prediction head(s) and the sequence model
        :param lr_decay: layer-wise decay factor; backbone group i (1-based) gets transformer_lr * lr_decay ** i
        :param transformer_lr: backbone learning rate (defaults to `lr`)
        :param transformer_frozen: if True, backbone and sequence-model params are detached and not optimized
        :param adamw: use AdamW instead of Adam
        :param weight_decay: decay for parameters with ndim >= 2; parameters with ndim <= 1 get none
        :param betas: Adam betas
        :raises ValueError: if weight_decay > 0 is requested without AdamW
        """
        if weight_decay > 0 and not adamw:
            # checked up-front (before detaching anything); replaces an `assert` that -O would strip
            raise ValueError(f"weight_decay > 0 is only supported with AdamW (adamw=True), "
                             f"got weight_decay={weight_decay}.")

        pt_params, seq_params, head_params = self.separate_params()

        param_groups = [
            {'params': head_params, 'lr': lr},  # model head (besides base model and seq model)
        ]

        if transformer_frozen:
            for p in pt_params + seq_params:
                if isinstance(p, list):
                    for p_i in p:
                        p_i.detach_()
                else:
                    p.detach_()
        else:
            if transformer_lr is None:
                transformer_lr = lr
            if isinstance(pt_params, list) and pt_params and isinstance(pt_params[0], list):
                # apply layer-wise lr decay: the i-th group (1-based) gets transformer_lr * lr_decay ** i
                param_groups += [
                    {"params": group, "lr": transformer_lr * (lr_decay ** depth)}
                    for depth, group in enumerate(pt_params, start=1)
                ]
            else:
                param_groups.append(
                    {'params': pt_params, 'lr': transformer_lr},  # pretrained model
                )
            param_groups.append(
                {'params': seq_params, 'lr': lr},  # optional sequence model
            )

        # Do not apply weight decay to biases and norm parameters (ndim <= 1). Their groups set
        # weight_decay=0.0 explicitly; otherwise they would inherit the optimizer-level default below.
        param_groups_split = []
        for param_group in param_groups:
            group_params = list(param_group['params'])
            group_lr = param_group['lr']
            params_2d = [p for p in group_params if p.ndimension() >= 2]
            params_1d = [p for p in group_params if p.ndimension() <= 1]
            param_groups_split += [
                {'params': params_2d, 'lr': group_lr, 'weight_decay': weight_decay},
                {'params': params_1d, 'lr': group_lr, 'weight_decay': 0.0},
            ]

        if adamw:
            print(f"\nUsing adamw weight_decay={weight_decay}!\n")
            return torch.optim.AdamW(param_groups_split, lr=lr, weight_decay=weight_decay, betas=betas)
        return torch.optim.Adam(param_groups_split, lr=lr, betas=betas)

    def get_lr_scheduler(
            self,
            optimizer,
            num_training_steps,
            schedule_mode="cos",
            gamma: float = 0.999996,
            num_warmup_steps=_DEFAULT_WARMUP_STEPS,
            lr_end=1e-7,
    ):
        """Create a per-step learning rate scheduler.

        :param schedule_mode: "exp" (ExponentialLR with `gamma`), "cos"/"cosine" (cosine with warm-up),
            or "linear" (polynomial decay with power 1 down to `lr_end`, with warm-up)
        :param num_warmup_steps: warm-up steps for the "cos" and "linear" schedules
        :raises RuntimeError: on an unknown schedule_mode
        """
        if schedule_mode in {"exp"}:
            return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)
        if schedule_mode in {"cosine", "cos"}:
            return transformers.get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
            )
        if schedule_mode in {"linear"}:
            print("Linear schedule!")
            return transformers.get_polynomial_decay_schedule_with_warmup(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
                power=1.0,
                lr_end=lr_end,
            )
        raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.")

    def configure_optimizers(self):
        """Define the optimizer and learning rate scheduler the way PyTorch Lightning expects.

        Lightning steps both automatically; the scheduler is stepped after every optimizer step.

        :return: ([optimizer], [lr scheduler config])
        """
        optimizer = self.get_optimizer(
            self.config.max_lr,
            lr_decay=self.config.lr_decay,
            transformer_lr=self.config.transformer_lr,
            transformer_frozen=self.config.transformer_frozen,
            adamw=not self.config.no_adamw,
            weight_decay=self.config.weight_decay,
        )

        num_training_steps = self.trainer.estimated_stepping_batches

        scheduler = self.get_lr_scheduler(
            optimizer,
            num_training_steps,
            schedule_mode=self.config.schedule_mode,
            # honour --warmup_steps (previously parsed but never used)
            num_warmup_steps=getattr(self.config, "warmup_steps", _DEFAULT_WARMUP_STEPS),
            lr_end=self.config.lr_end,
        )
        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
        }
        return [optimizer], [lr_scheduler_config]

    def training_step(self, train_batch, batch_idx):
        """Run one training step: mel features, config-gated augmentations, strong-head BCE loss.

        :param train_batch: (audios, labels, fnames, timestamps) from the train dataloader
        :param batch_idx: index of the batch (unused)
        :return: the loss tensor used to update the model parameters
        """
        audios, labels, fnames, timestamps = train_batch

        if self.config.transformer_frozen:
            # re-applied every step so the frozen sub-modules stay in eval mode while training
            self.model.model.eval()
            self.model.seq_model.eval()

        mel = self.model.mel_forward(audios)

        # time rolling
        if self.config.frame_shift_range > 0:
            mel, labels = frame_shift(mel, labels, shift_range=self.config.frame_shift_range)

        # mixup
        if self.config.mixup_p > random.random():
            mel, labels = mixup(mel, targets=labels)

        # mixstyle
        if self.config.mixstyle_p > random.random():
            mel = mixstyle(mel)

        # time masking (third return value is not used here)
        if self.config.max_time_mask_size > 0:
            mel, labels, _ = time_mask(mel, labels, max_mask_ratio=self.config.max_time_mask_size)

        # frequency masking
        if self.config.filter_augment_p > random.random():
            mel, _ = filter_augmentation(mel)

        # frequency warping (operates on mel with dim 1 squeezed out)
        if self.config.freq_warp_p > random.random():
            mel = mel.squeeze(1)
            mel = self.freq_warp(mel)
            mel = mel.unsqueeze(1)

        # forward through network; use strong head
        y_hat_strong, _ = self.model(mel)

        loss = self.strong_loss(y_hat_strong, labels)

        # logging
        self.log('epoch', self.current_epoch)
        for i, param_group in enumerate(self.trainer.optimizers[0].param_groups):
            self.log(f'trainer/lr_optimizer_{i}', param_group['lr'])
        self.log("train/loss", loss.detach().cpu(), prog_bar=True)

        return loss

    def _score_step(self, batch):
        """Buffer strong predictions, targets, file names and timestamps for epoch-level scoring."""
        audios, labels, fnames, timestamps = batch
        strong_preds = self.forward(audios)
        self.preds.append(strong_preds)
        self.tgts.append(labels)
        self.fnames.append(fnames)
        self.timestamps.append(timestamps)

    def _score_epoch_end(self, name="val"):
        """Compute and log loss and event/segment metrics over the buffered outputs, then clear buffers.

        :param name: "val" scores against the "valid" split ground truth, anything else against "test"
        """
        preds = torch.cat(self.preds)
        tgts = torch.cat(self.tgts)
        fnames = [item for sublist in self.fnames for item in sublist]
        timestamps = torch.cat(self.timestamps)

        val_loss = self.strong_loss(preds, tgts)
        self.log(f"{name}/loss", val_loss, prog_bar=True)

        # the following function expects one prediction per timestamp (sequence dimension must be flattened)
        seq_len = preds.size(-1)
        # NOTE(review): preds are raw strong-head outputs (the loss above is BCEWithLogitsLoss, so these are
        # logits, not sigmoid probabilities) -- confirm get_events_for_all_files' thresholding expects logits.
        preds = rearrange(preds, 'bs c t -> (bs t) c').float()
        timestamps = rearrange(timestamps, 'bs t -> (bs t)').float()
        fnames = [fname for fname in fnames for _ in range(seq_len)]

        predicted_events_by_postprocessing = get_events_for_all_files(
            preds,
            fnames,
            timestamps,
            self.idx_to_label,
            self.postprocessing_grid,
        )

        # we only have one postprocessing configuration (aligned with HEAR challenge)
        key = list(predicted_events_by_postprocessing.keys())[0]
        predicted_events = predicted_events_by_postprocessing[key]

        # load ground truth for the evaluated split
        task_path = Path(self.config.task_path)
        test_target_events = combine_target_events(["valid" if name == "val" else "test"], task_path)

        onset_fms = self.event_onset_200ms_fms(predicted_events, test_target_events)
        onset_fms_50 = self.event_onset_50ms_fms(predicted_events, test_target_events)
        segment_1s_er = self.segment_1s_er(predicted_events, test_target_events)

        self.log(f"{name}/onset_fms", onset_fms[0][1])
        self.log(f"{name}/onset_fms_50", onset_fms_50[0][1])
        self.log(f"{name}/segment_1s_er", segment_1s_er[0][1])

        # free buffers
        self.preds, self.tgts, self.fnames, self.timestamps = [], [], [], []

    def validation_step(self, batch, batch_idx):
        self._score_step(batch)

    def on_validation_epoch_end(self):
        self._score_epoch_end(name="val")

    def test_step(self, batch, batch_idx):
        self._score_step(batch)

    def on_test_epoch_end(self):
        self._score_epoch_end(name="test")


def train(config):
    """Fine-tune on the downstream task, validate periodically, then evaluate on the test split.

    Logging is done using wandb.
    """
    wandb_logger = WandbLogger(
        project="PTSED",
        notes="Downstream Training on office sound event detection.",
        tags=["DCASE 2016 Task 2", "Sound Event Detection"],
        config=config,
        name=config.experiment_name,
    )

    train_set = get_training_dataset(config.task_path, wavmix_p=config.wavmix_p)
    val_ds = get_validation_dataset(config.task_path)
    test_ds = get_test_dataset(config.task_path)

    # train dataloader
    train_dl = DataLoader(dataset=train_set,
                          worker_init_fn=worker_init_fn,
                          num_workers=config.num_workers,
                          batch_size=config.batch_size,
                          shuffle=True)

    # validation dataloader
    valid_dl = DataLoader(dataset=val_ds,
                          worker_init_fn=worker_init_fn,
                          num_workers=config.num_workers,
                          batch_size=config.batch_size,
                          shuffle=False,
                          drop_last=False)

    # test dataloader
    test_dl = DataLoader(dataset=test_ds,
                         worker_init_fn=worker_init_fn,
                         num_workers=config.num_workers,
                         batch_size=config.batch_size,
                         shuffle=False,
                         drop_last=False)

    # create the PyTorch Lightning module
    pl_module = PLModule(config)

    # create the PyTorch Lightning trainer: number of epochs, logger, device(s), precision and validation cadence
    trainer = pl.Trainer(max_epochs=config.n_epochs,
                         logger=wandb_logger,
                         accelerator='auto',
                         devices=config.num_devices,
                         precision=config.precision,
                         num_sanity_val_steps=0,
                         check_val_every_n_epoch=config.check_val_every_n_epoch)

    # start training and validation for the specified number of epochs
    trainer.fit(
        pl_module,
        train_dataloaders=train_dl,
        val_dataloaders=valid_dl,
    )

    test_results = trainer.test(pl_module, dataloaders=test_dl)
    print(test_results)

    wandb.finish()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Configuration Parser. ')

    # general
    parser.add_argument('--task_path', type=str, required=True)
    parser.add_argument('--experiment_name', type=str, default="DCASE2016Task2")
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--num_workers', type=int, default=16)
    parser.add_argument('--num_devices', type=int, default=1)
    parser.add_argument('--precision', type=int, default=16)
    parser.add_argument('--check_val_every_n_epoch', type=int, default=10)

    # model
    parser.add_argument('--model_name', type=str,
                        choices=["ATST-F", "BEATs", "fpasst", "M2D", "ASIT"],
                        default="ATST-F")  # used also for training
    # "scratch" = no pretraining
    # "ssl" = SSL pre-trained
    # "weak" = AudioSet Weak pre-trained
    # "strong" = AudioSet Strong pre-trained
    parser.add_argument('--pretrained', type=str, choices=["scratch", "ssl", "weak", "strong"],
                        default="strong")
    parser.add_argument('--seq_model_type', type=str, choices=["rnn"],
                        default=None)
    parser.add_argument('--n_classes', type=int, default=11)

    # training
    parser.add_argument('--n_epochs', type=int, default=300)

    # augmentation
    parser.add_argument('--wavmix_p', type=float, default=0.5)
    parser.add_argument('--freq_warp_p', type=float, default=0.0)
    parser.add_argument('--filter_augment_p', type=float, default=0.0)
    parser.add_argument('--frame_shift_range', type=float, default=0.0)  # in seconds
    parser.add_argument('--mixup_p', type=float, default=0.5)
    parser.add_argument('--mixstyle_p', type=float, default=0.0)
    parser.add_argument('--max_time_mask_size', type=float, default=0.0)

    # optimizer
    parser.add_argument('--no_adamw', action='store_true', default=False)
    parser.add_argument('--weight_decay', type=float, default=0.001)
    parser.add_argument('--transformer_frozen', action='store_true', dest='transformer_frozen',
                        default=False,
                        help='Disable training for the transformer.')

    # lr schedule
    parser.add_argument('--schedule_mode', type=str, default="cos")
    parser.add_argument('--max_lr', type=float, default=1.06e-4)
    parser.add_argument('--transformer_lr', type=float, default=None)
    parser.add_argument('--lr_decay', type=float, default=1.0)
    parser.add_argument('--lr_end', type=float, default=1e-7)
    parser.add_argument('--warmup_steps', type=int, default=100,
                        help="Number of warm-up steps for the 'cos' and 'linear' schedules.")

    args = parser.parse_args()
    train(args)