import os
from time import perf_counter

import datasets
import numpy as np
import pandas as pd
import torch
from torch.utils.data import (
    Dataset as TorchDataset,
    DistributedSampler,
    WeightedRandomSampler,
)

from data_util.audioset_classes import as_strong_train_classes
from data_util.transforms import (
    Mp3DecodeTransform,
    SequentialTransform,
    AddPseudoLabelsTransform,
    strong_label_transform,
    target_transform,
)

logger = datasets.logging.get_logger(__name__)


def init_hf_config(max_shard_size="2GB", verbose=True, in_mem_max=None):
    datasets.config.MAX_SHARD_SIZE = max_shard_size
    if verbose:
        datasets.logging.set_verbosity_info()
    if in_mem_max is not None:
        datasets.config.IN_MEMORY_MAX_SIZE = in_mem_max


def get_hf_local_path(path, local_datasets_path=None):
    if local_datasets_path is None:
        local_datasets_path = os.environ.get(
            "HF_DATASETS_LOCAL",
            os.path.join(os.environ.get("HF_DATASETS_CACHE"), "../local"),
        )
    path = os.path.join(local_datasets_path, path)
    return path


class catchtime:
    # context manager to measure loading time:
    # https://stackoverflow.com/questions/33987060/python-context-manager-that-measures-time
    def __init__(self, debug_print="Time", logger=logger):
        self.debug_print = debug_print
        self.logger = logger

    def __enter__(self):
        self.start = perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.time = perf_counter() - self.start
        readout = f"{self.debug_print}: {self.time:.3f} seconds"
        self.logger.info(readout)


def merge_overlapping_events(sample):
    """Merge overlapping (or touching) events of the same class into one event.

    Operates on the first element of the batch; the transform assumes
    batches of size one, as produced by `set_transform` with single-index access.
    """
    events = pd.DataFrame(sample["events"][0])
    events = events.sort_values(by="onset")
    sample["events"] = [None]
    for label in events["event_label"].unique():
        rows = []
        for _, r in events.loc[events["event_label"] == label].iterrows():
            if len(rows) == 0 or rows[-1]["offset"] < r["onset"]:
                # no overlap with the previous event: start a new one
                rows.append(r)
            else:
                # overlap: extend the previous event to cover both
                rows[-1]["onset"] = min(rows[-1]["onset"], r["onset"])
                rows[-1]["offset"] = max(rows[-1]["offset"], r["offset"])
        if sample["events"][0] is None:
            sample["events"][0] = pd.DataFrame(rows)
        else:
            sample["events"][0] = pd.concat([sample["events"][0], pd.DataFrame(rows)])
    return sample


def get_training_dataset(
    label_encoder,
    audio_length=10.0,
    sample_rate=16000,
    wavmix_p=0.0,
    pseudo_labels_file=None,
):
    """Training set: the AudioSet strong 'balanced_train' and 'unbalanced_train'
    splits, optionally with pseudo labels and waveform mixup."""
    init_hf_config()
    decode_transform = Mp3DecodeTransform(
        sample_rate=sample_rate, max_length=audio_length, debug_info_key="filename"
    )
    ds_list = []
    with catchtime("Loading audioset_strong"):
        as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))
    # label encode transformation
    if label_encoder is not None:
        # set list of label names to be encoded
        label_encoder.labels = as_strong_train_classes
        encode_label_fun = lambda x: strong_label_transform(x, strong_label_encoder=label_encoder)
    else:
        encode_label_fun = lambda x: x
    as_transforms = [
        decode_transform,
        merge_overlapping_events,
        encode_label_fun,
        target_transform,
    ]
    if pseudo_labels_file:
        as_transforms.append(
            AddPseudoLabelsTransform(pseudo_labels_file=pseudo_labels_file).add_pseudo_label_transform
        )
    as_ds.set_transform(SequentialTransform(as_transforms))
    ds_list.append(as_ds["balanced_train"])
    ds_list.append(as_ds["unbalanced_train"])
    dataset = torch.utils.data.ConcatDataset(ds_list)
    if wavmix_p > 0:
        print("Using Wavmix!")
        dataset = MixupDataset(dataset, rate=wavmix_p)
    return dataset
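# Usage sketch for get_training_dataset. The wavmix probability and the
# pseudo-labels path below are illustrative assumptions, not values
# prescribed by this module; `encoder` is a ManyHotEncoder as constructed
# in the __main__ block at the bottom of this file.
#
#     train_ds = get_training_dataset(
#         encoder,
#         wavmix_p=0.5,                              # wrap in MixupDataset
#         pseudo_labels_file="teacher_preds.hdf5",   # hypothetical path
#     )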
def get_eval_dataset(
    label_encoder,
    audio_length=10.0,
    sample_rate=16000,
):
    """Evaluation set: the AudioSet strong 'eval' split."""
    init_hf_config()
    ds_list = []
    decode_transform = Mp3DecodeTransform(
        sample_rate=sample_rate, max_length=audio_length, debug_info_key="filename"
    )
    with catchtime("Loading audioset"):
        as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))
    # label encode transformation
    if label_encoder is not None:
        label_encoder.labels = as_strong_train_classes
        encode_label_fun = lambda x: strong_label_transform(x, strong_label_encoder=label_encoder)
    else:
        encode_label_fun = lambda x: x
    as_transforms = [
        decode_transform,
        merge_overlapping_events,
        encode_label_fun,
        target_transform,
    ]
    as_ds.set_transform(SequentialTransform(as_transforms))
    as_ds_eval = as_ds["eval"]
    ds_list.append(as_ds_eval)
    dataset = torch.utils.data.ConcatDataset(ds_list)
    return dataset


def get_full_dataset(label_encoder, audio_length=10.0, sample_rate=16000):
    """Concatenation of all AudioSet strong splits (train + eval),
    without target_transform."""
    init_hf_config()
    decode_transform = Mp3DecodeTransform(
        sample_rate=sample_rate, max_length=audio_length, debug_info_key="filename"
    )
    with catchtime("Loading audioset"):
        as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))
    # label encode transformation
    if label_encoder is not None:
        label_encoder.labels = as_strong_train_classes
        encode_label_fun = lambda x: strong_label_transform(x, strong_label_encoder=label_encoder)
    else:
        encode_label_fun = lambda x: x
    as_transforms = [
        decode_transform,
        merge_overlapping_events,
        encode_label_fun,
    ]
    as_ds.set_transform(SequentialTransform(as_transforms))
    ds_list = []
    ds_list.append(as_ds["balanced_train"])
    ds_list.append(as_ds["unbalanced_train"])
    ds_list.append(as_ds["eval"])
    dataset = torch.utils.data.ConcatDataset(ds_list)
    return dataset


def get_uniform_sample_weights(dataset):
    """
    :return: float tensor of shape len(full_training_set)
             representing the weights of each sample.
    """
    return torch.ones(len(dataset)).float()


def get_temporal_count_balanced_sample_weights(
    dataset, sample_weight_offset=30, save_folder="/share/rk8/shared/as_strong"
):
    """
    :return: float tensor of shape len(full_training_set)
             representing the weights of each sample.
    """
    # the order of balanced_train_hdf5, unbalanced_train_hdf5 is important.
    # should match get_full_training_set
    os.makedirs(save_folder, exist_ok=True)
    save_file = os.path.join(
        save_folder, f"weights_temporal_count_offset_{sample_weight_offset}.pt"
    )
    # weights are expensive to compute, so they are cached on disk
    if os.path.exists(save_file):
        return torch.load(save_file)
    from tqdm import tqdm

    all_y = []
    for sample in tqdm(dataset, desc="Calculating sample weights."):
        all_y.append(sample["event_count"])
    all_y = torch.from_numpy(np.stack(all_y, axis=0))
    per_class = all_y.long().sum(0).float().reshape(1, -1)  # frequencies per class
    per_class = sample_weight_offset + per_class  # offset low-frequency classes
    if sample_weight_offset > 0:
        print(
            f"Warning: sample_weight_offset={sample_weight_offset} "
            f"min now={per_class.min()}"
        )
    per_class_weights = 1000.0 / per_class
    all_weight = all_y * per_class_weights
    all_weight = all_weight.sum(dim=1)
    torch.save(all_weight, save_file)
    return all_weight
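# Sketch: feeding the weights computed above into a DataLoader through the
# distributed weighted sampler defined below. Batch size and worker count
# are illustrative assumptions.
#
#     weights = get_temporal_count_balanced_sample_weights(train_ds)
#     sampler = get_weighted_sampler(weights, epoch_len=100_000)
#     loader = torch.utils.data.DataLoader(
#         train_ds, batch_size=32, sampler=sampler, num_workers=8
#     )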
class MixupDataset(TorchDataset):
    """Mixing up waveforms."""

    def __init__(self, dataset, beta=2, rate=0.5):
        self.beta = beta
        self.rate = rate
        self.dataset = dataset
        print(f"Mixing up waveforms from dataset of len {len(dataset)}")

    def __getitem__(self, index):
        if torch.rand(1) < self.rate:
            batch1 = self.dataset[index]
            idx2 = torch.randint(len(self.dataset), (1,)).item()
            batch2 = self.dataset[idx2]
            x1, x2 = batch1["audio"], batch2["audio"]
            y1, y2 = batch1["strong"], batch2["strong"]
            if "pseudo_strong" in batch1:
                p1, p2 = batch1["pseudo_strong"], batch2["pseudo_strong"]
            # mixing coefficient from a symmetric beta distribution,
            # biased towards the first sample
            lam = np.random.beta(self.beta, self.beta)
            lam = max(lam, 1.0 - lam)
            x1 = x1 - x1.mean()
            x2 = x2 - x2.mean()
            x = x1 * lam + x2 * (1.0 - lam)
            x = x - x.mean()
            batch1["audio"] = x
            batch1["strong"] = y1 * lam + y2 * (1.0 - lam)
            if "pseudo_strong" in batch1:
                batch1["pseudo_strong"] = p1 * lam + p2 * (1.0 - lam)
            return batch1
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)


class DistributedSamplerWrapper(DistributedSampler):
    def __init__(self, sampler, dataset, num_replicas=None, rank=None, shuffle: bool = True):
        super().__init__(dataset, num_replicas, rank, shuffle)
        # source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238
        self.sampler = sampler

    def __iter__(self):
        if self.sampler.generator is None:
            self.sampler.generator = torch.Generator()
        self.sampler.generator.manual_seed(self.seed + self.epoch)
        indices = list(self.sampler)
        if self.epoch < 2:
            logger.info(
                f"\n DistributedSamplerWrapper (rank {self.rank}): {indices[:3]} \n\n"
            )
        # each rank takes every num_replicas-th index from the shared ordering
        indices = indices[self.rank : self.total_size : self.num_replicas]
        return iter(indices)


def get_weighted_sampler(
    samples_weights,
    epoch_len=100_000,
    sampler_replace=False,
):
    num_nodes = int(os.environ.get("WORLD_SIZE", 1))
    ddp = int(os.environ.get("DDP", 1))
    num_nodes = max(ddp, num_nodes)
    rank = int(os.environ.get("NODE_RANK", 0))
    return DistributedSamplerWrapper(
        sampler=WeightedRandomSampler(
            samples_weights, num_samples=epoch_len, replacement=sampler_replace
        ),
        dataset=range(epoch_len),
        num_replicas=num_nodes,
        rank=rank,
    )


if __name__ == "__main__":
    from helpers.encode import ManyHotEncoder

    encoder = ManyHotEncoder([], 10.0, 160, net_pooling=4, fs=16_000)
    train_ds = get_training_dataset(encoder, audio_length=10.0, sample_rate=16_000)
    valid_ds = get_eval_dataset(encoder, audio_length=10.0, sample_rate=16_000)
    print("Len train dataset: ", len(train_ds))
    print("Len valid dataset: ", len(valid_ds))
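    # Sketch: pull one sample to sanity-check the transform pipeline. The
    # printed keys (e.g. "audio", "strong") reflect the transforms above and
    # are an assumption about target_transform's output, not a guarantee.
    sample = train_ds[0]
    print("Sample keys: ", sorted(sample.keys()))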