import os
import datasets
import h5py
import numpy as np
import pandas as pd
import torch
import torchaudio
from data_util.audioset_classes import as_strong_train_classes
## Transforms with a similar style to https://github.com/descriptinc/audiotools/blob/master/audiotools/data/transforms.py
logger = datasets.logging.get_logger(__name__)
def target_transform(sample):
    del sample["labels"]
    del sample["label_ids"]
    return sample
def strong_label_transform(sample, strong_label_encoder=None):
    assert strong_label_encoder is not None
    events = pd.DataFrame(sample['events'][0])
    events = events[events['event_label'].isin(set(as_strong_train_classes))]
    strong = strong_label_encoder.encode_strong_df(events).T
    sample["strong"] = [strong]
    sample["event_count"] = [strong.sum(1)]
    # encode the ground-truth events as a string; this is used later for evaluation
    sample["gt_string"] = ["++".join([";;".join([str(e[0]), str(e[1]), e[2]]) for e in
                                      zip(sample['events'][0]['onset'], sample['events'][0]['offset'],
                                          sample['events'][0]['event_label'])])]
    del sample['events']
    return sample
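
# Illustrative sketch (not part of the original pipeline): strong_label_transform expects the
# label encoder to be bound before it is used as a batch transform, e.g. with functools.partial.
# `my_encoder` below is a hypothetical encoder object providing encode_strong_df(); it is not
# defined in this file.
#
#     from functools import partial
#     encode_strong = partial(strong_label_transform, strong_label_encoder=my_encoder)
#     batch = encode_strong(batch)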
class AddPseudoLabelsTransform:
    """Attach pre-computed pseudo strong labels stored in an HDF5 file to each batch."""

    def __init__(self, pseudo_labels_file):
        self.pseudo_labels_file = pseudo_labels_file
        if self.pseudo_labels_file is not None:
            # build a mapping from example filename to its row index in the HDF5 file
            self.ex2pseudo_idx = {}
            with h5py.File(self.pseudo_labels_file, "r") as f:
                for i, fname in enumerate(f["filenames"]):
                    self.ex2pseudo_idx[fname.decode("UTF-8")] = i
        self._opened_pseudo_hdf5 = None

    @property
    def pseudo_hdf5_file(self):
        # open the HDF5 file lazily so the handle is created in the process that actually reads it
        if self._opened_pseudo_hdf5 is None:
            self._opened_pseudo_hdf5 = h5py.File(self.pseudo_labels_file, "r")
        return self._opened_pseudo_hdf5

    def add_pseudo_label_transform(self, sample):
        # strip the ".mp3" extension to match the keys stored in the HDF5 file
        # (removesuffix avoids rstrip's per-character stripping)
        indices = [self.ex2pseudo_idx[fn.removesuffix(".mp3")] for fn in sample['filename']]
        pseudo_strong = [torch.from_numpy(np.stack(self.pseudo_hdf5_file["strong_logits"][index])).float()
                         for index in indices]
        pseudo_strong = [torch.sigmoid(p) for p in pseudo_strong]
        sample['pseudo_strong'] = pseudo_strong
        return sample
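
# Illustrative usage sketch (an assumption, not from this file): the bound method acts as a
# regular batch transform; the HDF5 path below is hypothetical.
#
#     pseudo = AddPseudoLabelsTransform("as_strong_pseudo_labels.h5")
#     batch = pseudo.add_pseudo_label_transform(batch)   # adds batch['pseudo_strong']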
class SequentialTransform:
    """Apply a sequence of transforms to a batch."""

    def __init__(self, transforms):
        """
        Args:
            transforms: list of transforms to apply
        """
        self.transforms = transforms

    def append(self, transform):
        self.transforms.append(transform)

    def __call__(self, batch):
        for t in self.transforms:
            batch = t(batch)
        return batch
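
# Illustrative sketch (an assumption about how these pieces compose, not taken from this file):
# SequentialTransform chains batch-level callables, so the transforms defined in this module can
# be stacked in order; `encode_strong` refers to the hypothetical partial sketched above.
#
#     pipeline = SequentialTransform([target_transform])
#     pipeline.append(encode_strong)
#     batch = pipeline(batch)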
class Mp3DecodeTransform:
    def __init__(
        self,
        mp3_bytes_key="mp3_bytes",
        audio_key="audio",
        sample_rate=32000,
        max_length=10.0,
        min_length=None,
        random_sample_crop=True,
        allow_resample=True,
        resampling_method="sinc_interp_kaiser",
        keep_mp3_bytes=False,
        debug_info_key=None,
    ):
        """Decode mp3 bytes to an audio waveform.

        Args:
            mp3_bytes_key (str, optional): The key of the mp3 bytes in the input batch. Defaults to "mp3_bytes".
            audio_key (str, optional): The key under which the decoded audio is stored in the output batch. Defaults to "audio".
            sample_rate (int, optional): The expected output sample rate. Defaults to 32000.
            max_length (int, float, optional): The maximum output audio length, in seconds if float, otherwise in samples. Defaults to 10.0.
            min_length (int, float, optional): The minimum output audio length (same convention as max_length). Defaults to max_length.
            random_sample_crop (bool, optional): Randomly crop the audio to max_length if it is longer; otherwise return the first crop. Defaults to True.
            allow_resample (bool, optional): Resample the signal if its sampling rate does not match sample_rate. Defaults to True.
            resampling_method (str, optional): Resampling method passed to torchaudio.transforms.Resample. Defaults to "sinc_interp_kaiser".
            keep_mp3_bytes (bool, optional): Keep the original mp3 bytes in the output batch. Defaults to False.
            debug_info_key (str, optional): Batch key (e.g. the filename) used to identify examples in log messages. Defaults to None.

        Raises:
            ImportError: if minimp3py is not installed.
        """
        self.mp3_bytes_key = mp3_bytes_key
        self.audio_key = audio_key
        self.sample_rate = sample_rate
        self.max_length = max_length
        if min_length is None:
            min_length = max_length
        self.min_length = min_length
        self.random_sample_crop = random_sample_crop
        self.allow_resample = allow_resample
        self.resampling_method = resampling_method
        self.keep_mp3_bytes = keep_mp3_bytes
        self.debug_info_key = debug_info_key
        self.resamplers_cache = {}
        try:
            import minimp3py  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "minimp3py is not installed, please install it using: "
                "`CFLAGS='-O3 -march=native' pip install https://github.com/f0k/minimp3py/archive/master.zip`"
            ) from err
    def __call__(self, batch):
        import minimp3py

        data_list = batch[self.mp3_bytes_key]
        if self.debug_info_key is not None:
            file_name_list = batch[self.debug_info_key]
        else:
            file_name_list = range(len(data_list))
        audio_list = []
        for data, file_name in zip(data_list, file_name_list):
            try:
                duration, ch, sr = minimp3py.probe(data)
                if isinstance(self.max_length, float):
                    max_length = int(self.max_length * sr)
                else:
                    # max_length is given in samples at the target rate; convert to source-rate samples
                    max_length = int(self.max_length * sr // self.sample_rate)
                offset = 0
                if self.random_sample_crop and duration > max_length:
                    max_offset = max(int(duration - max_length), 0) + 1
                    offset = torch.randint(max_offset, (1,)).item()
                waveform, _ = minimp3py.read(data, start=offset, length=max_length)
                waveform = waveform[:, 0]  # keep only the first channel
                if waveform.dtype != "float32":
                    raise RuntimeError("Unexpected wave type")
                waveform = torch.from_numpy(waveform)
                if len(waveform) == 0:
                    logger.warning(
                        f"Empty waveform for {file_name}, duration {duration}, offset {offset}, "
                        f"max_length {max_length}, sr {sr}, ch {ch}"
                    )
                elif sr != self.sample_rate:
                    assert self.allow_resample, f"Unexpected sample rate {sr} instead of {self.sample_rate} at {file_name}"
                    if self.resamplers_cache.get(sr) is None:
                        self.resamplers_cache[sr] = torchaudio.transforms.Resample(
                            sr,
                            self.sample_rate,
                            resampling_method=self.resampling_method,
                        )
                    waveform = self.resamplers_cache[sr](waveform)
                min_length = self.min_length
                if isinstance(self.min_length, float):
                    min_length = int(self.min_length * self.sample_rate)
                if min_length is not None and len(waveform) < min_length:
                    # zero-pad short clips up to min_length
                    waveform = torch.concatenate(
                        (
                            waveform,
                            torch.zeros(
                                min_length - len(waveform),
                                dtype=waveform.dtype,
                                device=waveform.device,
                            ),
                        ),
                        dim=0,
                    )
                audio_list.append(waveform)
            except Exception as e:
                logger.error(f"Error decoding {file_name}: {e}")
                raise e
        batch[self.audio_key] = audio_list
        batch["sampling_rate"] = [self.sample_rate] * len(audio_list)
        if not self.keep_mp3_bytes:
            del batch[self.mp3_bytes_key]
        return batch
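
# Illustrative end-to-end sketch (an assumption, not from this file): batch transforms like
# Mp3DecodeTransform are typically attached to a Hugging Face datasets.Dataset; `audioset`
# below is a hypothetical dataset with "mp3_bytes" and "filename" columns.
#
#     decode = Mp3DecodeTransform(sample_rate=32000, max_length=10.0)
#     pipeline = SequentialTransform([decode, target_transform])
#     audioset = audioset.with_transform(pipeline)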