import os

import datasets
import h5py
import numpy as np
import pandas as pd
import torch
import torchaudio

from data_util.audioset_classes import as_strong_train_classes

## Transforms with a similar style to https://github.com/descriptinc/audiotools/blob/master/audiotools/data/transforms.py

logger = datasets.logging.get_logger(__name__)

def target_transform(sample):
    """Drop the clip-level "labels" and "label_ids" columns from the batch."""
    del sample["labels"]
    del sample["label_ids"]
    return sample

def strong_label_transform(sample, strong_label_encoder=None):
    """Encode the per-example event list into a frame-level ("strong") label matrix."""
    assert strong_label_encoder is not None
    events = pd.DataFrame(sample['events'][0])
    events = events[events['event_label'].isin(set(as_strong_train_classes))]
    strong = strong_label_encoder.encode_strong_df(events).T
    sample["strong"] = [strong]
    sample["event_count"] = [strong.sum(1)]
    # encode ground truth events as a string - we will use this for evaluation
    sample["gt_string"] = ["++".join([";;".join([str(e[0]), str(e[1]), e[2]]) for e in
                                      zip(sample['events'][0]['onset'], sample['events'][0]['offset'],
                                          sample['events'][0]['event_label'])])]
    del sample['events']
    return sample
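
# Usage sketch: the inverse of the "gt_string" encoding above, handy at evaluation time.
# Events are joined with "++" and their (onset, offset, event_label) fields with ";;";
# this helper is illustrative and assumes exactly that format.
def decode_gt_string(gt_string):
    """Parse a gt_string back into a list of (onset, offset, event_label) tuples."""
    if not gt_string:
        return []
    events = []
    for event in gt_string.split("++"):
        onset, offset, event_label = event.split(";;")
        events.append((float(onset), float(offset), event_label))
    return events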

class AddPseudoLabelsTransform:
    def __init__(self, pseudo_labels_file):
        self.pseudo_labels_file = pseudo_labels_file
        if self.pseudo_labels_file is not None:
            # fetch dict of positions for each example
            self.ex2pseudo_idx = {}
            with h5py.File(self.pseudo_labels_file, "r") as f:
                for i, fname in enumerate(f["filenames"]):
                    self.ex2pseudo_idx[fname.decode("UTF-8")] = i
        self._opened_pseudo_hdf5 = None

    @property
    def pseudo_hdf5_file(self):
        # open the HDF5 file lazily and cache the handle
        if self._opened_pseudo_hdf5 is None:
            self._opened_pseudo_hdf5 = h5py.File(self.pseudo_labels_file, "r")
        return self._opened_pseudo_hdf5
    def add_pseudo_label_transform(self, sample):
        # look up the pseudo-label row for each example; keys are stored without the ".mp3" extension
        indices = [self.ex2pseudo_idx[fn.removesuffix(".mp3")] for fn in sample['filename']]
        pseudo_strong = [torch.from_numpy(np.stack(self.pseudo_hdf5_file["strong_logits"][index])).float()
                         for index in indices]
        pseudo_strong = [torch.sigmoid(pseudo_strong[i]) for i in range(len(pseudo_strong))]
        sample['pseudo_strong'] = pseudo_strong
        return sample
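
# Usage sketch: turning the pseudo-label transform into a single-argument batch transform.
# The HDF5 path is a placeholder; the file is assumed to contain "filenames" and
# "strong_logits" datasets, matching what AddPseudoLabelsTransform reads above.
def make_pseudo_label_transform(pseudo_labels_file="pseudo_labels.h5"):
    """Return a batch transform that attaches sigmoid-activated pseudo strong labels."""
    pseudo = AddPseudoLabelsTransform(pseudo_labels_file)
    return pseudo.add_pseudo_label_transform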

class SequentialTransform:
    """Apply a sequence of transforms to a batch."""

    def __init__(self, transforms):
        """
        Args:
            transforms: list of transforms to apply
        """
        self.transforms = transforms

    def append(self, transform):
        self.transforms.append(transform)

    def __call__(self, batch):
        for t in self.transforms:
            batch = t(batch)
        return batch
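
# Usage sketch: SequentialTransform chains single-argument batch transforms. The encoder
# passed in is an assumption (any object exposing encode_strong_df(), e.g. a DESED-style
# ManyHotEncoder); functools.partial binds it so each step only takes the batch.
def make_label_pipeline(strong_label_encoder):
    """Build a pipeline that encodes strong labels and then drops the weak-label columns."""
    from functools import partial

    return SequentialTransform([
        partial(strong_label_transform, strong_label_encoder=strong_label_encoder),
        target_transform,
    ])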

class Mp3DecodeTransform:
    def __init__(
        self,
        mp3_bytes_key="mp3_bytes",
        audio_key="audio",
        sample_rate=32000,
        max_length=10.0,
        min_length=None,
        random_sample_crop=True,
        allow_resample=True,
        resampling_method="sinc_interp_kaiser",
        keep_mp3_bytes=False,
        debug_info_key=None,
    ):
"""Decode mp3 bytes to audio waveform | |
Args: | |
mp3_bytes_key (str, optional): The key to mp3 bytes in the input batch. Defaults to "mp3_bytes". | |
audio_key (str, optional): The key to save the decoded audio in the output batch. Defaults to "audio". | |
sample_rate (int, optional): The expected output audio_key. Defaults to 32000. | |
max_length (int, float, optional): the maximum output audio length in seconds if float, otherwise in samples. Defaults to 10. | |
min_length (int, optional): the minimum output audio length in seconds. Defaults to max_length. | |
random_sample_crop (bool, optional): Randomly crop the audio to max_length if its longer otherwise return the first crop. Defaults to True. | |
allow_resample (bool, optional): Resample the singal if the sampling rate don't match. Defaults to True. | |
resampling_method (str, optional): reampling method from torchaudio.transforms.Resample . Defaults to "sinc_interp_kaiser". | |
keep_mp3_bytes (bool, optional): keep the original bytes in the output dict. Defaults to False. | |
Raises: | |
Exception: if minimp3py is not installed | |
""" | |
        self.mp3_bytes_key = mp3_bytes_key
        self.audio_key = audio_key
        self.sample_rate = sample_rate
        self.max_length = max_length
        if min_length is None:
            min_length = max_length
        self.min_length = min_length
        self.random_sample_crop = random_sample_crop
        self.allow_resample = allow_resample
        self.resampling_method = resampling_method
        self.keep_mp3_bytes = keep_mp3_bytes
        self.debug_info_key = debug_info_key
        self.resamplers_cache = {}
        try:
            import minimp3py  # noqa: F401
        except ImportError as error:
            raise ImportError(
                "minimp3py is not installed, please install it using: "
                "`CFLAGS='-O3 -march=native' pip install https://github.com/f0k/minimp3py/archive/master.zip`"
            ) from error

    def __call__(self, batch):
        import minimp3py

        data_list = batch[self.mp3_bytes_key]
        if self.debug_info_key is not None:
            file_name_list = batch[self.debug_info_key]
        else:
            file_name_list = range(len(data_list))
        audio_list = []
        for data, file_name in zip(data_list, file_name_list):
            try:
                duration, ch, sr = minimp3py.probe(data)
                if isinstance(self.max_length, float):
                    max_length = int(self.max_length * sr)
                else:
                    max_length = int(self.max_length * sr // self.sample_rate)
                offset = 0
                if self.random_sample_crop and duration > max_length:
                    max_offset = max(int(duration - max_length), 0) + 1
                    offset = torch.randint(max_offset, (1,)).item()
                waveform, _ = minimp3py.read(data, start=offset, length=max_length)
                waveform = waveform[:, 0]  # keep the first channel only
                if waveform.dtype != "float32":
                    raise RuntimeError("Unexpected wave type")
                waveform = torch.from_numpy(waveform)
                if len(waveform) == 0:
                    logger.warning(
                        f"Empty waveform for {file_name}, duration {duration}, offset {offset}, max_length {max_length}, sr {sr}, ch {ch}"
                    )
                elif sr != self.sample_rate:
                    assert self.allow_resample, f"Unexpected sample rate {sr} instead of {self.sample_rate} at {file_name}"
                    if self.resamplers_cache.get(sr) is None:
                        self.resamplers_cache[sr] = torchaudio.transforms.Resample(
                            sr,
                            self.sample_rate,
                            resampling_method=self.resampling_method,
                        )
                    waveform = self.resamplers_cache[sr](waveform)
                min_length = self.min_length
                if isinstance(self.min_length, float):
                    min_length = int(self.min_length * self.sample_rate)
                if min_length is not None and len(waveform) < min_length:
                    # zero-pad short clips up to the minimum length
                    waveform = torch.concatenate(
                        (
                            waveform,
                            torch.zeros(
                                min_length - len(waveform),
                                dtype=waveform.dtype,
                                device=waveform.device,
                            ),
                        ),
                        dim=0,
                    )
                audio_list.append(waveform)
            except Exception as e:
                logger.error(f"Error decoding {file_name}: {e}")
                raise
        batch[self.audio_key] = audio_list
        batch["sampling_rate"] = [self.sample_rate] * len(audio_list)
        if not self.keep_mp3_bytes:
            del batch[self.mp3_bytes_key]
        return batch
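
# Usage sketch: wiring Mp3DecodeTransform into a Hugging Face dataset with
# Dataset.set_transform, so mp3 bytes are decoded on the fly when rows are accessed.
# The "mp3_bytes" and "filename" columns are assumptions about the dataset schema;
# the pseudo-label file is optional and a placeholder.
def make_decoding_pipeline(dataset, pseudo_labels_file=None):
    """Attach on-the-fly mp3 decoding (and optional pseudo labels) to a datasets.Dataset."""
    transforms = [
        Mp3DecodeTransform(sample_rate=32000, max_length=10.0, debug_info_key="filename"),
    ]
    if pseudo_labels_file is not None:
        transforms.append(AddPseudoLabelsTransform(pseudo_labels_file).add_pseudo_label_transform)
    dataset.set_transform(SequentialTransform(transforms))
    return dataset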