# sohamc10's picture
# gradio app
# 9b0d6c2
import os
import datasets
import h5py
import numpy as np
import pandas as pd
import torch
import torchaudio
from data_util.audioset_classes import as_strong_train_classes
## Transforms with a similar style to https://github.com/descriptinc/audiotools/blob/master/audiotools/data/transforms.py
# Module-level logger obtained through the `datasets` logging helper.
logger = datasets.logging.get_logger(__name__)
def target_transform(sample):
    """Drop the "labels" and "label_ids" entries from ``sample``.

    The mapping is modified in place and the same object is returned.
    A ``KeyError`` is raised if one of the two keys is missing.
    """
    for unwanted_key in ("labels", "label_ids"):
        del sample[unwanted_key]
    return sample
def strong_label_transform(sample, strong_label_encoder=None):
    """Replace the raw event annotations of a batch with encoded strong labels.

    Only the first entry of ``sample['events']`` is read, so the batch is
    presumably expected to hold a single example (TODO confirm with caller).

    Args:
        sample: batch mapping with an 'events' entry whose first element
            provides 'onset', 'offset' and 'event_label' columns.
        strong_label_encoder: object exposing ``encode_strong_df``; required.

    Returns:
        The same ``sample`` with "strong", "event_count" and "gt_string"
        added and 'events' removed.
    """
    assert strong_label_encoder is not None

    raw_events = sample['events'][0]

    # Keep only the events whose label belongs to the strong training classes.
    frame = pd.DataFrame(raw_events)
    keep_mask = frame['event_label'].isin(set(as_strong_train_classes))
    frame = frame[keep_mask]

    strong = strong_label_encoder.encode_strong_df(frame).T
    sample["strong"] = [strong]
    sample["event_count"] = [strong.sum(1)]

    # encode ALL ground truth events (unfiltered) as a string - used for evaluation
    triples = zip(raw_events['onset'], raw_events['offset'], raw_events['event_label'])
    encoded_events = [
        ";;".join([str(onset), str(offset), label])
        for onset, offset, label in triples
    ]
    sample["gt_string"] = ["++".join(encoded_events)]

    del sample['events']
    return sample
class AddPseudoLabelsTransform:
    """Attach pre-computed pseudo labels, stored in an HDF5 file, to a batch.

    The HDF5 file is read for two datasets: "filenames" (byte strings, one per
    example) and "strong_logits" (one entry per example, same order).
    """

    _MP3_SUFFIX = ".mp3"

    def __init__(self, pseudo_labels_file):
        """
        Args:
            pseudo_labels_file: path to the HDF5 file with the pseudo labels,
                or None to build a transform without any label index.
        """
        self.pseudo_labels_file = pseudo_labels_file
        # maps example name -> row position inside the HDF5 datasets
        self.ex2pseudo_idx = {}
        if self.pseudo_labels_file is not None:
            # Only the filename index is needed here, so close the handle
            # right away instead of leaking it.
            with h5py.File(self.pseudo_labels_file, "r") as f:
                self.ex2pseudo_idx = {
                    fname.decode("UTF-8"): i
                    for i, fname in enumerate(f["filenames"])
                }
        # The handle used for reading logits is opened lazily on first access.
        self._opened_pseudo_hdf5 = None

    @property
    def pseudo_hdf5_file(self):
        """Read-only HDF5 handle, opened on first access and then reused."""
        if self._opened_pseudo_hdf5 is None:
            self._opened_pseudo_hdf5 = h5py.File(self.pseudo_labels_file, "r")
        return self._opened_pseudo_hdf5

    @classmethod
    def _strip_mp3_suffix(cls, filename):
        """Return ``filename`` without a trailing ".mp3" (if present).

        ``str.rstrip(".mp3")`` must not be used for this: it removes any
        trailing run of the characters '.', 'm', 'p', '3', which corrupts
        names such as "clip3.mp3" (-> "cli").
        """
        if filename.endswith(cls._MP3_SUFFIX):
            return filename[: -len(cls._MP3_SUFFIX)]
        return filename

    def add_pseudo_label_transform(self, sample):
        """Add sigmoid-activated pseudo labels under ``sample['pseudo_strong']``.

        Args:
            sample: batch mapping whose 'filename' entry lists the example
                file names (optionally ending in ".mp3").

        Returns:
            The same ``sample`` with a 'pseudo_strong' list holding one float
            tensor per file name.

        Raises:
            KeyError: if a file name has no entry in the pseudo label index.
        """
        indices = [
            self.ex2pseudo_idx[self._strip_mp3_suffix(fn)]
            for fn in sample['filename']
        ]
        strong_logits = self.pseudo_hdf5_file["strong_logits"]
        sample['pseudo_strong'] = [
            torch.sigmoid(torch.from_numpy(np.stack(strong_logits[index])).float())
            for index in indices
        ]
        return sample
class SequentialTransform:
    """Chain several batch transforms and run them one after another."""

    def __init__(self, transforms):
        """
        Args:
            transforms: list of callables; each one takes a batch and returns
                the (possibly modified) batch.
        """
        self.transforms = transforms

    def append(self, transform):
        """Add ``transform`` to the end of the chain."""
        self.transforms.append(transform)

    def __call__(self, batch):
        """Feed ``batch`` through every transform in order and return the result."""
        result = batch
        for transform in self.transforms:
            result = transform(result)
        return result
class Mp3DecodeTransform:
    """Decode a batch of mp3 byte strings into single-channel waveform tensors."""

    def __init__(
        self,
        mp3_bytes_key="mp3_bytes",
        audio_key="audio",
        sample_rate=32000,
        max_length=10.0,
        min_length=None,
        random_sample_crop=True,
        allow_resample=True,
        resampling_method="sinc_interp_kaiser",
        keep_mp3_bytes=False,
        debug_info_key=None,
    ):
        """Decode mp3 bytes to audio waveform

        Args:
            mp3_bytes_key (str, optional): The key to mp3 bytes in the input batch. Defaults to "mp3_bytes".
            audio_key (str, optional): The key to save the decoded audio in the output batch. Defaults to "audio".
            sample_rate (int, optional): The expected output sampling rate. Defaults to 32000.
            max_length (int, float, optional): the maximum output audio length, in seconds if float, otherwise in samples (at ``sample_rate``). Defaults to 10.0.
            min_length (int, float, optional): the minimum output audio length, in seconds if float, otherwise in samples (at ``sample_rate``); shorter audio is zero-padded. Defaults to max_length.
            random_sample_crop (bool, optional): Randomly crop the audio to max_length if it is longer, otherwise return the first crop. Defaults to True.
            allow_resample (bool, optional): Resample the signal if the sampling rates don't match. Defaults to True.
            resampling_method (str, optional): resampling method passed to torchaudio.transforms.Resample. Defaults to "sinc_interp_kaiser".
            keep_mp3_bytes (bool, optional): keep the original bytes in the output dict. Defaults to False.
            debug_info_key (str, optional): batch key whose values identify each item in warnings and error messages; the item position is used when None. Defaults to None.

        Raises:
            ImportError: if minimp3py is not installed
        """
        self.mp3_bytes_key = mp3_bytes_key
        self.audio_key = audio_key
        self.sample_rate = sample_rate
        self.max_length = max_length
        if min_length is None:
            min_length = max_length
        self.min_length = min_length
        self.random_sample_crop = random_sample_crop
        self.allow_resample = allow_resample
        self.resampling_method = resampling_method
        self.keep_mp3_bytes = keep_mp3_bytes
        self.debug_info_key = debug_info_key
        # one Resample module per source sampling rate, created on demand
        self.resamplers_cache = {}
        # Fail early (at construction time) when the decoder is unavailable.
        try:
            import minimp3py  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "minimp3py is not installed, please install it using: `CFLAGS='-O3 -march=native' pip install https://github.com/f0k/minimp3py/archive/master.zip`"
            ) from err

    def __call__(self, batch):
        """Decode every item of ``batch[mp3_bytes_key]``.

        Writes the list of waveforms to ``batch[audio_key]``, sets
        ``batch["sampling_rate"]`` to ``sample_rate`` for every item and, unless
        ``keep_mp3_bytes`` is set, removes the raw bytes from the batch.

        Returns:
            The same ``batch`` mapping, modified in place.
        """
        data_list = batch[self.mp3_bytes_key]
        if self.debug_info_key is not None:
            file_name_list = batch[self.debug_info_key]
        else:
            file_name_list = range(len(data_list))

        audio_list = []
        for data, file_name in zip(data_list, file_name_list):
            try:
                audio_list.append(self._decode_one(data, file_name))
            except Exception as e:
                # Report which item failed, then propagate the original error.
                print(f"Error decoding {file_name}: {e}")
                raise

        batch[self.audio_key] = audio_list
        batch["sampling_rate"] = [self.sample_rate] * len(audio_list)
        if not self.keep_mp3_bytes:
            del batch[self.mp3_bytes_key]
        return batch

    def _decode_one(self, data, file_name):
        """Decode, crop, resample and pad a single mp3 byte string."""
        import minimp3py

        # `duration` is compared against sample counts below, i.e. it is
        # treated as the number of samples at the file's native rate `sr`.
        duration, ch, sr = minimp3py.probe(data)

        # Express the crop length in samples at the native rate: a float is
        # seconds, anything else is samples at the target rate.
        if isinstance(self.max_length, float):
            max_length = int(self.max_length * sr)
        else:
            max_length = int(self.max_length * sr // self.sample_rate)

        offset = 0
        if self.random_sample_crop and duration > max_length:
            max_offset = max(int(duration - max_length), 0) + 1
            offset = torch.randint(max_offset, (1,)).item()

        waveform, _ = minimp3py.read(data, start=offset, length=max_length)
        waveform = waveform[:, 0]  # 0 for the first channel only
        if waveform.dtype != "float32":
            raise RuntimeError("Unexpected wave type")
        waveform = torch.from_numpy(waveform)

        if len(waveform) == 0:
            logger.warning(
                f"Empty waveform for {file_name}, duration {duration}, offset {offset}, max_length {max_length}, sr {sr}, ch {ch}"
            )
        elif sr != self.sample_rate:
            waveform = self._resample(waveform, sr, file_name)

        return self._pad_to_min_length(waveform)

    def _resample(self, waveform, sr, file_name):
        """Resample ``waveform`` from ``sr`` to ``sample_rate`` using a cached resampler."""
        if not self.allow_resample:
            # Explicit raise (rather than `assert`) so the check also runs
            # under `python -O`; the exception type is kept for callers.
            raise AssertionError(
                f"Unexpected sample rate {sr} instead of {self.sample_rate} at {file_name}"
            )
        resampler = self.resamplers_cache.get(sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(
                sr,
                self.sample_rate,
                resampling_method=self.resampling_method,
            )
            self.resamplers_cache[sr] = resampler
        return resampler(waveform)

    def _pad_to_min_length(self, waveform):
        """Zero-pad ``waveform`` at the end up to ``min_length`` (no-op if long enough)."""
        min_length = self.min_length
        if isinstance(min_length, float):
            # float means seconds -> convert to samples at the output rate
            min_length = int(min_length * self.sample_rate)
        if min_length is None or len(waveform) >= min_length:
            return waveform
        padding = torch.zeros(
            min_length - len(waveform),
            dtype=waveform.dtype,
            device=waveform.device,
        )
        return torch.cat((waveform, padding), dim=0)