# data_util/audioset_strong.py
import os
from time import perf_counter
import datasets
import numpy as np
import pandas as pd
import torch
from torch.utils.data import (
Dataset as TorchDataset,
DistributedSampler,
WeightedRandomSampler,
)
from data_util.audioset_classes import as_strong_train_classes
from data_util.transforms import (
Mp3DecodeTransform,
SequentialTransform,
AddPseudoLabelsTransform,
strong_label_transform,
target_transform
)
logger = datasets.logging.get_logger(__name__)
def init_hf_config(max_shard_size="2GB", verbose=True, in_mem_max=None):
datasets.config.MAX_SHARD_SIZE = max_shard_size
if verbose:
datasets.logging.set_verbosity_info()
if in_mem_max is not None:
datasets.config.IN_MEMORY_MAX_SIZE = in_mem_max
def get_hf_local_path(path, local_datasets_path=None):
    if local_datasets_path is None:
        # fall back to "<HF_DATASETS_CACHE>/../local"; this assumes
        # HF_DATASETS_CACHE is set whenever HF_DATASETS_LOCAL is not
        local_datasets_path = os.environ.get(
            "HF_DATASETS_LOCAL",
            os.path.join(os.environ.get("HF_DATASETS_CACHE"), "../local"),
        )
    return os.path.join(local_datasets_path, path)
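# Example (illustrative values): with HF_DATASETS_LOCAL=/data/hf_local,
# get_hf_local_path("audioset_strong") returns "/data/hf_local/audioset_strong".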
class catchtime:
    # context manager to measure elapsed wall-clock time, adapted from:
    # https://stackoverflow.com/questions/33987060/python-context-manager-that-measures-time
    def __init__(self, debug_print="Time", logger=logger):
        self.debug_print = debug_print
        self.logger = logger
    def __enter__(self):
        self.start = perf_counter()
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.time = perf_counter() - self.start
        self.logger.info(f"{self.debug_print}: {self.time:.3f} seconds")
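# Usage sketch (illustrative): logs e.g. "Loading shard: 1.234 seconds" via
# the module logger on exit.
#
#     with catchtime("Loading shard"):
#         as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))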
def merge_overlapping_events(sample):
    """Merge overlapping events of the same label into single (onset, offset) spans."""
    events = pd.DataFrame(sample["events"][0])
    events = events.sort_values(by="onset")
    sample["events"] = [None]
    for label in events["event_label"].unique():
        rows = []
        for _, row in events.loc[events["event_label"] == label].iterrows():
            if len(rows) == 0 or rows[-1]["offset"] < row["onset"]:
                # no overlap with the previous event of this label: start a new span
                rows.append(row)
            else:
                # overlap: extend the previous span to cover both events
                rows[-1]["onset"] = min(rows[-1]["onset"], row["onset"])
                rows[-1]["offset"] = max(rows[-1]["offset"], row["offset"])
        if sample["events"][0] is None:
            sample["events"][0] = pd.DataFrame(rows)
        else:
            sample["events"][0] = pd.concat([sample["events"][0], pd.DataFrame(rows)])
    return sample
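# A minimal sketch (hypothetical helper, not called anywhere in this module)
# showing how two overlapping "Speech" events are merged while a disjoint
# "Dog" event is kept as-is:
def _demo_merge_overlapping_events():
    sample = {
        "events": [{
            "event_label": ["Speech", "Speech", "Dog"],
            "onset": [0.0, 1.5, 4.0],
            "offset": [2.0, 3.0, 5.0],
        }]
    }
    merged = merge_overlapping_events(sample)
    # expected rows: Speech (0.0, 3.0) merged, Dog (4.0, 5.0) unchanged
    return merged["events"][0]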
def get_training_dataset(
label_encoder,
audio_length=10.0,
sample_rate=16000,
wavmix_p=0.0,
pseudo_labels_file=None,
):
init_hf_config()
decode_transform = Mp3DecodeTransform(
sample_rate=sample_rate, max_length=audio_length, debug_info_key="filename"
)
ds_list = []
with catchtime("Loading audioset_strong"):
as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))
# label encode transformation
if label_encoder is not None:
# set list of label names to be encoded
label_encoder.labels = as_strong_train_classes
encode_label_fun = lambda x: strong_label_transform(x, strong_label_encoder=label_encoder)
else:
encode_label_fun = lambda x: x
as_transforms = [
decode_transform,
merge_overlapping_events,
encode_label_fun,
target_transform,
]
    if pseudo_labels_file:
        as_transforms.append(
            AddPseudoLabelsTransform(
                pseudo_labels_file=pseudo_labels_file
            ).add_pseudo_label_transform
        )
as_ds.set_transform(SequentialTransform(as_transforms))
ds_list.append(as_ds["balanced_train"])
ds_list.append(as_ds["unbalanced_train"])
dataset = torch.utils.data.ConcatDataset(ds_list)
if wavmix_p > 0:
print("Using Wavmix!")
dataset = MixupDataset(dataset, rate=wavmix_p)
return dataset
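# Usage sketch (the pseudo-labels filename is hypothetical):
#
#     from helpers.encode import ManyHotEncoder
#     encoder = ManyHotEncoder([], 10., 160, net_pooling=4, fs=16_000)
#     train_ds = get_training_dataset(encoder, wavmix_p=0.5,
#                                     pseudo_labels_file="pseudo_labels.hdf5")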
def get_eval_dataset(
label_encoder,
audio_length=10.0,
sample_rate=16000
):
init_hf_config()
ds_list = []
decode_transform = Mp3DecodeTransform(
sample_rate=sample_rate, max_length=audio_length, debug_info_key="filename"
)
with catchtime(f"Loading audioset:"):
as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))
# label encode transformation
if label_encoder is not None:
label_encoder.labels = as_strong_train_classes
encode_label_fun = lambda x: strong_label_transform(x, strong_label_encoder=label_encoder)
else:
encode_label_fun = lambda x: x
as_transforms = [
decode_transform,
merge_overlapping_events,
encode_label_fun,
target_transform
]
as_ds.set_transform(SequentialTransform(as_transforms))
    ds_list.append(as_ds["eval"])
dataset = torch.utils.data.ConcatDataset(ds_list)
return dataset
def get_full_dataset(label_encoder, audio_length=10.0, sample_rate=16000):
init_hf_config()
decode_transform = Mp3DecodeTransform(
sample_rate=sample_rate, max_length=audio_length, debug_info_key="filename"
)
with catchtime(f"Loading audioset:"):
as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))
# label encode transformation
if label_encoder is not None:
label_encoder.labels = as_strong_train_classes
encode_label_fun = lambda x: strong_label_transform(x, strong_label_encoder=label_encoder)
else:
encode_label_fun = lambda x: x
as_transforms = [
decode_transform,
merge_overlapping_events,
encode_label_fun,
]
as_ds.set_transform(SequentialTransform(as_transforms))
ds_list = []
ds_list.append(as_ds["balanced_train"])
ds_list.append(as_ds["unbalanced_train"])
ds_list.append(as_ds["eval"])
dataset = torch.utils.data.ConcatDataset(ds_list)
return dataset
def get_uniform_sample_weights(dataset):
"""
    :return: float tensor of shape (len(dataset),) with a uniform weight of 1 per sample.
"""
return torch.ones(len(dataset)).float()
def get_temporal_count_balanced_sample_weights(dataset, sample_weight_offset=30,
                                               save_folder="/share/rk8/shared/as_strong"):
    """
    :return: float tensor of shape (len(dataset),) with one weight per sample,
        inversely proportional to the offset per-class event counts.
    """
    # the dataset order (balanced_train, then unbalanced_train) matters:
    # it must match the order used in get_training_dataset
os.makedirs(save_folder, exist_ok=True)
save_file = os.path.join(save_folder, f"weights_temporal_count_offset_{sample_weight_offset}.pt")
if os.path.exists(save_file):
return torch.load(save_file)
from tqdm import tqdm
all_y = []
for sample in tqdm(dataset, desc="Calculating sample weights."):
all_y.append(sample["event_count"])
all_y = torch.from_numpy(np.stack(all_y, axis=0))
per_class = all_y.long().sum(0).float().reshape(1, -1) # frequencies per class
per_class = sample_weight_offset + per_class # offset low freq classes
if sample_weight_offset > 0:
print(f"Warning: sample_weight_offset={sample_weight_offset} minnow={per_class.min()}")
per_class_weights = 1000. / per_class
all_weight = all_y * per_class_weights
all_weight = all_weight.sum(dim=1)
torch.save(all_weight, save_file)
return all_weight
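# Worked example (hypothetical counts): with sample_weight_offset=30 and raw
# per-class event counts [970, 70], the offset counts become [1000, 100] and
# the per-class weights 1000/count = [1.0, 10.0]. A sample with
# event_count=[2, 1] then gets weight 2 * 1.0 + 1 * 10.0 = 12.0, so clips
# containing rare classes are drawn more often by the weighted sampler.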
class MixupDataset(TorchDataset):
    """Mixes up pairs of waveforms (and their strong labels) with probability `rate`."""
def __init__(self, dataset, beta=2, rate=0.5):
self.beta = beta
self.rate = rate
self.dataset = dataset
print(f"Mixing up waveforms from dataset of len {len(dataset)}")
def __getitem__(self, index):
if torch.rand(1) < self.rate:
batch1 = self.dataset[index]
idx2 = torch.randint(len(self.dataset), (1,)).item()
batch2 = self.dataset[idx2]
x1, x2 = batch1['audio'], batch2['audio']
y1, y2 = batch1['strong'], batch2['strong']
if 'pseudo_strong' in batch1:
p1, p2 = batch1['pseudo_strong'], batch2['pseudo_strong']
            # sample the mixing coefficient and fold it into [0.5, 1.0] so the
            # indexed clip always dominates the mixture
            lam = np.random.beta(self.beta, self.beta)
            lam = max(lam, 1. - lam)
            x1 = x1 - x1.mean()
            x2 = x2 - x2.mean()
            x = x1 * lam + x2 * (1. - lam)
            x = x - x.mean()
            batch1['audio'] = x
            batch1['strong'] = y1 * lam + y2 * (1. - lam)
            if 'pseudo_strong' in batch1:
                batch1['pseudo_strong'] = p1 * lam + p2 * (1. - lam)
return batch1
return self.dataset[index]
def __len__(self):
return len(self.dataset)
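# Mixing sketch (illustrative): with beta=2, lam ~ Beta(2, 2) is folded into
# [0.5, 1.0] by max(lam, 1 - lam). For lam = 0.7:
#
#     audio  = 0.7 * x1 + 0.3 * x2   (both mean-centered, result re-centered)
#     strong = 0.7 * y1 + 0.3 * y2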
class DistributedSamplerWrapper(DistributedSampler):
def __init__(
self, sampler, dataset, num_replicas=None, rank=None, shuffle: bool = True
):
super(DistributedSamplerWrapper, self).__init__(
dataset, num_replicas, rank, shuffle
)
# source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238
self.sampler = sampler
def __iter__(self):
if self.sampler.generator is None:
self.sampler.generator = torch.Generator()
self.sampler.generator.manual_seed(self.seed + self.epoch)
indices = list(self.sampler)
if self.epoch < 2:
logger.info(
f"\n DistributedSamplerWrapper (rank {self.rank}) : {indices[:3]} \n\n"
)
indices = indices[self.rank : self.total_size : self.num_replicas]
return iter(indices)
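# Sharding sketch (illustrative): with num_replicas=4 and total_size=100_000,
# rank 0 draws indices[0::4], rank 1 draws indices[1::4], and so on, so each
# replica sees a disjoint quarter of the weighted sample per epoch.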
def get_weighted_sampler(
samples_weights,
epoch_len=100_000,
sampler_replace=False,
):
num_nodes = int(os.environ.get("WORLD_SIZE", 1))
ddp = int(os.environ.get("DDP", 1))
num_nodes = max(ddp, num_nodes)
rank = int(os.environ.get("NODE_RANK", 0))
return DistributedSamplerWrapper(
sampler=WeightedRandomSampler(
samples_weights, num_samples=epoch_len, replacement=sampler_replace
),
dataset=range(epoch_len),
num_replicas=num_nodes,
rank=rank,
)
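# Usage sketch (illustrative; batch size and worker count are arbitrary):
#
#     weights = get_temporal_count_balanced_sample_weights(train_ds)
#     sampler = get_weighted_sampler(weights, epoch_len=100_000)
#     loader = torch.utils.data.DataLoader(
#         train_ds, batch_size=32, sampler=sampler, num_workers=8
#     )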
if __name__ == "__main__":
from helpers.encode import ManyHotEncoder
encoder = ManyHotEncoder([], 10., 160, net_pooling=4, fs=16_000)
train_ds = get_training_dataset(
encoder, audio_length=10.0, sample_rate=16_000
)
valid_ds = get_eval_dataset(
encoder, audio_length=10.0, sample_rate=16_000
)
print("Len train dataset: ", len(train_ds))
print("Len valid dataset: ", len(valid_ds))