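"""
Fine-tune pre-trained audio transformers (ATST-F, BEATs, fpasst, M2D, ASIT) on
DCASE 2016 Task 2 sound event detection. Training, validation, and testing are
implemented as a PyTorch Lightning module; logging is done with Weights & Biases.
"""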
import argparse
import random
from pathlib import Path
from typing import Dict

import pytorch_lightning as pl
import torch
import torch.nn as nn
import transformers
import wandb
from einops import rearrange
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

from data_util.dcase2016task2 import (get_training_dataset, get_validation_dataset, get_test_dataset,
                                      label_vocab_nlabels, label_vocab_as_dict)
from helpers.augment import frame_shift, time_mask, mixup, filter_augmentation, mixstyle, RandomResizeCrop
from helpers.score import get_events_for_all_files, combine_target_events, EventBasedScore, SegmentBasedScore
from helpers.utils import worker_init_fn
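
# model wrappers: each adapts a pre-trained audio transformer to a common interface;
# PredictionsWrapper adds the frame-level ("strong") prediction head and an optional sequence model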
from models.asit.ASIT_wrapper import ASiTWrapper
from models.atstframe.ATSTF_wrapper import ATSTWrapper
from models.beats.BEATs_wrapper import BEATsWrapper
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
from models.m2d.M2D_wrapper import M2DWrapper
from models.prediction_wrapper import PredictionsWrapper


class PLModule(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # map the CLI choice to the checkpoint identifier expected by the wrappers
        if config.pretrained == "scratch":
            checkpoint = None
        elif config.pretrained == "ssl":
            checkpoint = "ssl"
        elif config.pretrained == "weak":
            checkpoint = "weak"
        elif config.pretrained == "strong":
            checkpoint = "strong_1"
        else:
            raise ValueError(f"Unknown pretrained checkpoint: {config.pretrained}")

        # load transformer model
        if config.model_name == "BEATs":
            beats = BEATsWrapper()
            model = PredictionsWrapper(beats, checkpoint=f"BEATs_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type,
                                       n_classes_strong=self.config.n_classes)
        elif config.model_name == "ATST-F":
            atst = ATSTWrapper()
            model = PredictionsWrapper(atst, checkpoint=f"ATST-F_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type,
                                       n_classes_strong=self.config.n_classes)
        elif config.model_name == "fpasst":
            fpasst = FPaSSTWrapper()
            model = PredictionsWrapper(fpasst, checkpoint=f"fpasst_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type,
                                       n_classes_strong=self.config.n_classes)
        elif config.model_name == "M2D":
            m2d = M2DWrapper()
            model = PredictionsWrapper(m2d, checkpoint=f"M2D_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type,
                                       n_classes_strong=self.config.n_classes,
                                       embed_dim=m2d.m2d.cfg.feature_d)
        elif config.model_name == "ASIT":
            asit = ASiTWrapper()
            model = PredictionsWrapper(asit, checkpoint=f"ASIT_{checkpoint}" if checkpoint else None,
                                       seq_model_type=config.seq_model_type,
                                       n_classes_strong=self.config.n_classes)
        else:
            raise NotImplementedError(f"Model {config.model_name} not (yet) implemented")

        self.model = model
        self.strong_loss = nn.BCEWithLogitsLoss()
        self.freq_warp = RandomResizeCrop((1, 1.0), time_scale=(1.0, 1.0))

        # label vocabulary of the DCASE 2016 Task 2 dataset
        task_path = Path(self.config.task_path)
        label_vocab, nlabels = label_vocab_nlabels(task_path)
        self.label_to_idx = label_vocab_as_dict(label_vocab, key="label", value="idx")
        self.idx_to_label: Dict[int, str] = {
            idx: label for (label, idx) in self.label_to_idx.items()
        }

        # evaluation metrics: event-based f-measures (onset-only, with 200 ms and 50 ms collars)
        # and a segment-based error rate at 1 s resolution
        self.event_onset_200ms_fms = EventBasedScore(
            label_to_idx=self.label_to_idx,
            name="event_onset_200ms_fms",
            scores=("f_measure", "precision", "recall"),
            params={"evaluate_onset": True, "evaluate_offset": False, "t_collar": 0.2}
        )
        self.event_onset_50ms_fms = EventBasedScore(
            label_to_idx=self.label_to_idx,
            name="event_onset_50ms_fms",
            scores=("f_measure", "precision", "recall"),
            params={"evaluate_onset": True, "evaluate_offset": False, "t_collar": 0.05}
        )
        self.segment_1s_er = SegmentBasedScore(
            label_to_idx=self.label_to_idx,
            name="segment_1s_er",
            scores=("error_rate",),
            params={"time_resolution": 1.0},
            maximize=False,
        )
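
        # post-processing applied to the frame-level predictions before decoding events;
        # a single configuration (median filter length in ms, minimum event duration), aligned with the HEAR challenge setup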
        self.postprocessing_grid = {
            "median_filter_ms": [250],
            "min_duration": [125]
        }

        # buffers collecting predictions and targets across validation/test batches
        self.preds, self.tgts, self.fnames, self.timestamps = [], [], [], []
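
    # frame-level ("strong") inference: raw audio -> mel spectrogram -> per-frame class logits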
    def forward(self, audio):
        mel = self.model.mel_forward(audio)
        y_strong, _ = self.model(mel)
        return y_strong

    def separate_params(self):
        pt_params = []
        seq_params = []
        head_params = []
        for name, p in self.named_parameters():
            # strip the leading "model." added by this LightningModule's attribute name
            name = name[len("model."):]
            if name.startswith('model'):
                # the transformer
                pt_params.append(p)
            elif name.startswith('seq_model'):
                # the optional sequence model
                seq_params.append(p)
            elif name.startswith('strong_head') or name.startswith('weak_head'):
                # the prediction head
                head_params.append(p)
            else:
                raise ValueError(f"Unexpected key in model: {name}")

        if self.model.has_separate_params():
            # split the transformer parameters into groups according to their depth in the network;
            # based on this, we can apply layer-wise learning rate decay
            pt_params = self.model.separate_params()
        else:
            if self.config.lr_decay != 1.0:
                raise ValueError(f"Model has no separate_params function. Can't apply layer-wise lr decay, but "
                                 f"learning rate decay is set to {self.config.lr_decay}.")
        return pt_params, seq_params, head_params

    def get_optimizer(
            self,
            lr,
            lr_decay=1.0,
            transformer_lr=None,
            transformer_frozen=False,
            adamw=False,
            weight_decay=0.01,
            betas=(0.9, 0.999)
    ):
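        """
        Build the optimizer with separate parameter groups (and learning rates) for the pretrained
        transformer, the optional sequence model, and the prediction head. Optionally applies
        layer-wise learning rate decay to the transformer; biases and other 1-D parameters are
        excluded from weight decay.
        """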
        pt_params, seq_params, head_params = self.separate_params()

        param_groups = [
            {'params': head_params, 'lr': lr},  # prediction head (besides base model and seq model)
        ]

        if transformer_frozen:
            for p in pt_params + seq_params:
                if isinstance(p, list):
                    for p_i in p:
                        p_i.detach_()
                else:
                    p.detach_()
        else:
            if transformer_lr is None:
                transformer_lr = lr
            if isinstance(pt_params, list) and isinstance(pt_params[0], list):
                # apply layer-wise lr decay: deeper parameter groups get progressively smaller learning rates
                scale_lrs = [transformer_lr * (lr_decay ** i) for i in range(1, len(pt_params) + 1)]
                param_groups = param_groups + [{"params": pt_params[i], "lr": scale_lrs[i]}
                                               for i in range(len(pt_params))]
            else:
                param_groups.append(
                    {'params': pt_params, 'lr': transformer_lr},  # pretrained transformer
                )
            param_groups.append(
                {'params': seq_params, 'lr': lr},  # optional sequence model
            )

        # do not apply weight decay to biases and batch norms (1-D parameters)
        param_groups_split = []
        for param_group in param_groups:
            params_1D, params_2D = [], []
            lr = param_group['lr']
            for param in param_group['params']:
                if param.ndimension() >= 2:
                    params_2D.append(param)
                elif param.ndimension() <= 1:
                    params_1D.append(param)
            param_groups_split += [{'params': params_2D, 'lr': lr, 'weight_decay': weight_decay},
                                   {'params': params_1D, 'lr': lr}]

        if weight_decay > 0:
            assert adamw
        if adamw:
            print(f"\nUsing AdamW with weight_decay={weight_decay}!\n")
            return torch.optim.AdamW(param_groups_split, lr=lr, weight_decay=weight_decay, betas=betas)
        return torch.optim.Adam(param_groups_split, lr=lr, betas=betas)

    def get_lr_scheduler(
            self,
            optimizer,
            num_training_steps,
            schedule_mode="cos",
            gamma: float = 0.999996,
            num_warmup_steps=4000,
            lr_end=1e-7,
    ):
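        # supported schedules: exponential decay ("exp"), cosine with warmup ("cos"/"cosine"),
        # and linear (polynomial, power=1.0) decay with warmup ("linear")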
if schedule_mode in {"exp"}: | |
return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma) | |
if schedule_mode in {"cosine", "cos"}: | |
return transformers.get_cosine_schedule_with_warmup( | |
optimizer, | |
num_warmup_steps=num_warmup_steps, | |
num_training_steps=num_training_steps, | |
) | |
if schedule_mode in {"linear"}: | |
print("Linear schedule!") | |
return transformers.get_polynomial_decay_schedule_with_warmup( | |
optimizer, | |
num_warmup_steps=num_warmup_steps, | |
num_training_steps=num_training_steps, | |
power=1.0, | |
lr_end=lr_end, | |
) | |
raise RuntimeError(f"schedule_mode={schedule_mode} Unknown.") | |

    def configure_optimizers(self):
        """
        This is the way PyTorch Lightning requires optimizers and learning rate schedulers to be defined.
        The specified items are used automatically in the optimization loop (no need to call optimizer.step() yourself).
        :return: dict containing the optimizer and the learning rate scheduler
        """
        optimizer = self.get_optimizer(self.config.max_lr,
                                       lr_decay=self.config.lr_decay,
                                       transformer_lr=self.config.transformer_lr,
                                       transformer_frozen=self.config.transformer_frozen,
                                       adamw=not self.config.no_adamw,
                                       weight_decay=self.config.weight_decay)

        num_training_steps = self.trainer.estimated_stepping_batches
        scheduler = self.get_lr_scheduler(optimizer, num_training_steps,
                                          schedule_mode=self.config.schedule_mode,
                                          num_warmup_steps=self.config.warmup_steps,
                                          lr_end=self.config.lr_end)

        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1
        }
        return [optimizer], [lr_scheduler_config]

    def training_step(self, train_batch, batch_idx):
        """
        :param train_batch: contains one batch from the train dataloader
        :param batch_idx: index of the batch within the epoch
        :return: a dict containing at least the loss that is used to update the model parameters; it can also contain
            other items that can be processed in 'training_epoch_end' to log metrics other than the loss
        """
        audios, labels, fnames, timestamps = train_batch

        if self.config.transformer_frozen:
            # keep the frozen transformer and (optional) sequence model in eval mode
            self.model.model.eval()
            self.model.seq_model.eval()

        mel = self.model.mel_forward(audios)
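
        # spectrogram-level augmentations, each applied with its configured probability/strength:
        # time rolling, mixup, mixstyle, time masking, filter augmentation, and frequency warping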
        # time rolling
        if self.config.frame_shift_range > 0:
            mel, labels = frame_shift(
                mel,
                labels,
                shift_range=self.config.frame_shift_range
            )

        # mixup
        if self.config.mixup_p > random.random():
            mel, labels = mixup(
                mel,
                targets=labels
            )

        # mixstyle
        if self.config.mixstyle_p > random.random():
            mel = mixstyle(
                mel
            )

        # time masking
        if self.config.max_time_mask_size > 0:
            mel, labels, pseudo_labels = time_mask(
                mel,
                labels,
                max_mask_ratio=self.config.max_time_mask_size
            )

        # filter augmentation
        if self.config.filter_augment_p > random.random():
            mel, _ = filter_augmentation(
                mel
            )

        # frequency warping
        if self.config.freq_warp_p > random.random():
            mel = mel.squeeze(1)
            mel = self.freq_warp(mel)
            mel = mel.unsqueeze(1)

        # forward through the network; use the strong (frame-level) head
        y_hat_strong, _ = self.model(mel)
        loss = self.strong_loss(y_hat_strong, labels)

        # logging
        self.log('epoch', self.current_epoch)
        for i, param_group in enumerate(self.trainer.optimizers[0].param_groups):
            self.log(f'trainer/lr_optimizer_{i}', param_group['lr'])
        self.log("train/loss", loss.detach().cpu(), prog_bar=True)
        return loss
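
    # validation and test share the same scoring logic: _score_step buffers predictions per batch,
    # _score_epoch_end turns them into event-based and segment-based metrics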
    def _score_step(self, batch):
        audios, labels, fnames, timestamps = batch
        strong_preds = self.forward(audios)
        self.preds.append(strong_preds)
        self.tgts.append(labels)
        self.fnames.append(fnames)
        self.timestamps.append(timestamps)

    def _score_epoch_end(self, name="val"):
        preds = torch.cat(self.preds)
        tgts = torch.cat(self.tgts)
        fnames = [item for sublist in self.fnames for item in sublist]
        timestamps = torch.cat(self.timestamps)

        val_loss = self.strong_loss(preds, tgts)
        self.log(f"{name}/loss", val_loss, prog_bar=True)

        # the following function expects one prediction per timestamp (the sequence dimension must be flattened)
        seq_len = preds.size(-1)
        preds = rearrange(preds, 'bs c t -> (bs t) c').float()
        timestamps = rearrange(timestamps, 'bs t -> (bs t)').float()
        fnames = [fname for fname in fnames for _ in range(seq_len)]

        predicted_events_by_postprocessing = get_events_for_all_files(
            preds,
            fnames,
            timestamps,
            self.idx_to_label,
            self.postprocessing_grid
        )
        # we only have one postprocessing configuration (aligned with the HEAR challenge)
        key = list(predicted_events_by_postprocessing.keys())[0]
        predicted_events = predicted_events_by_postprocessing[key]

        # load ground truth for the corresponding fold
        task_path = Path(self.config.task_path)
        test_target_events = combine_target_events(["valid" if name == "val" else "test"], task_path)

        onset_fms = self.event_onset_200ms_fms(predicted_events, test_target_events)
        onset_fms_50 = self.event_onset_50ms_fms(predicted_events, test_target_events)
        segment_1s_er = self.segment_1s_er(predicted_events, test_target_events)
        self.log(f"{name}/onset_fms", onset_fms[0][1])
        self.log(f"{name}/onset_fms_50", onset_fms_50[0][1])
        self.log(f"{name}/segment_1s_er", segment_1s_er[0][1])

        # free buffers
        self.preds, self.tgts, self.fnames, self.timestamps = [], [], [], []

    def validation_step(self, batch, batch_idx):
        self._score_step(batch)

    def on_validation_epoch_end(self):
        self._score_epoch_end(name="val")

    def test_step(self, batch, batch_idx):
        self._score_step(batch)

    def on_test_epoch_end(self):
        self._score_epoch_end(name="test")


def train(config):
    # Example for fine-tuning pre-trained transformers on a downstream task.
    # Logging is done using Weights & Biases (wandb).
    wandb_logger = WandbLogger(
        project="PTSED",
        notes="Downstream Training on office sound event detection.",
        tags=["DCASE 2016 Task 2", "Sound Event Detection"],
        config=config,
        name=config.experiment_name
    )

    # datasets
    train_set = get_training_dataset(config.task_path, wavmix_p=config.wavmix_p)
    val_ds = get_validation_dataset(config.task_path)
    test_ds = get_test_dataset(config.task_path)

    # train dataloader
    train_dl = DataLoader(dataset=train_set,
                          worker_init_fn=worker_init_fn,
                          num_workers=config.num_workers,
                          batch_size=config.batch_size,
                          shuffle=True)

    # validation dataloader
    valid_dl = DataLoader(dataset=val_ds,
                          worker_init_fn=worker_init_fn,
                          num_workers=config.num_workers,
                          batch_size=config.batch_size,
                          shuffle=False,
                          drop_last=False)

    # test dataloader
    test_dl = DataLoader(dataset=test_ds,
                         worker_init_fn=worker_init_fn,
                         num_workers=config.num_workers,
                         batch_size=config.batch_size,
                         shuffle=False,
                         drop_last=False)

    # create the PyTorch Lightning module
    pl_module = PLModule(config)

    # create the PyTorch Lightning trainer, specifying the number of epochs to train, the logger,
    # the device(s) to train on, and possible callbacks
    trainer = pl.Trainer(max_epochs=config.n_epochs,
                         logger=wandb_logger,
                         accelerator='auto',
                         devices=config.num_devices,
                         precision=config.precision,
                         num_sanity_val_steps=0,
                         check_val_every_n_epoch=config.check_val_every_n_epoch
                         )

    # run training and validation for the specified number of epochs
    trainer.fit(
        pl_module,
        train_dataloaders=train_dl,
        val_dataloaders=valid_dl,
    )

    # evaluate the trained model on the test set
    test_results = trainer.test(pl_module, dataloaders=test_dl)
    print(test_results)

    wandb.finish()
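

# example invocation (script name and task path are placeholders):
#   python ex_dcase2016task2.py --task_path /path/to/dcase2016_task2 --model_name ATST-F --pretrained strong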
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Configuration Parser.')

    # general
    parser.add_argument('--task_path', type=str, required=True)
    parser.add_argument('--experiment_name', type=str, default="DCASE2016Task2")
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--num_workers', type=int, default=16)
    parser.add_argument('--num_devices', type=int, default=1)
    parser.add_argument('--precision', type=int, default=16)
    parser.add_argument('--check_val_every_n_epoch', type=int, default=10)

    # model
    parser.add_argument('--model_name', type=str,
                        choices=["ATST-F", "BEATs", "fpasst", "M2D", "ASIT"],
                        default="ATST-F")  # used also for training
    # "scratch" = no pretraining
    # "ssl" = SSL pre-trained
    # "weak" = AudioSet Weak pre-trained
    # "strong" = AudioSet Strong pre-trained
    parser.add_argument('--pretrained', type=str, choices=["scratch", "ssl", "weak", "strong"],
                        default="strong")
    parser.add_argument('--seq_model_type', type=str, choices=["rnn"],
                        default=None)
    parser.add_argument('--n_classes', type=int, default=11)

    # training
    parser.add_argument('--n_epochs', type=int, default=300)

    # augmentation
    parser.add_argument('--wavmix_p', type=float, default=0.5)
    parser.add_argument('--freq_warp_p', type=float, default=0.0)
    parser.add_argument('--filter_augment_p', type=float, default=0.0)
    parser.add_argument('--frame_shift_range', type=float, default=0.0)  # in seconds
    parser.add_argument('--mixup_p', type=float, default=0.5)
    parser.add_argument('--mixstyle_p', type=float, default=0.0)
    parser.add_argument('--max_time_mask_size', type=float, default=0.0)

    # optimizer
    parser.add_argument('--no_adamw', action='store_true', default=False)
    parser.add_argument('--weight_decay', type=float, default=0.001)
    parser.add_argument('--transformer_frozen', action='store_true', dest='transformer_frozen',
                        default=False,
                        help='Disable training for the transformer.')

    # lr schedule
    parser.add_argument('--schedule_mode', type=str, default="cos")
    parser.add_argument('--max_lr', type=float, default=1.06e-4)
    parser.add_argument('--transformer_lr', type=float, default=None)
    parser.add_argument('--lr_decay', type=float, default=1.0)
    parser.add_argument('--lr_end', type=float, default=1e-7)
    parser.add_argument('--warmup_steps', type=int, default=100)

    args = parser.parse_args()
    train(args)