# NOTE: the lines "Spaces: / Sleeping / Sleeping" were hosting-page scrape
# residue, not part of this module; preserved here as a comment.
import json | |
from pathlib import Path | |
from typing import Dict, List, Tuple | |
import numpy as np | |
import pandas as pd | |
import soundfile as sf | |
import torch | |
from intervaltree import IntervalTree | |
from torch.utils.data import Dataset | |
class FixCropDataset(Dataset):
    """
    Dataset of fixed, non-overlapping 10-second pieces cut from 120-second clips.

    ``data`` maps audio filenames to their event annotations (lists of dicts
    with "start"/"end"/"label" keys, as consumed by get_labels_for_timestamps).
    At construction time every clip is split into 12 contiguous 10 s pieces,
    and a per-frame multi-hot label matrix plus the frame timestamps are
    precomputed for each piece.

    __getitem__ returns a tuple of:
        audio:      float32 samples of the 10 s piece (read via soundfile)
        labels:     FloatTensor of shape (nlabels, frames_per_piece)
        filename:   source clip filename
        timestamps: frame-center times of the piece within the clip
                    (milliseconds, since frame_len = 1000 / label_fps)
    """
    def __init__(self, data: Dict,
                 audio_dir: Path,
                 sample_rate: int,
                 label_fps: int,
                 label_to_idx: Dict,
                 nlabels: int):
        # Fixed geometry: every source clip is 120 s, cut into 10 s pieces.
        self.clip_len = 120
        self.target_len = 10
        self.pieces_per_clip = self.clip_len // self.target_len
        self.filenames = list(data.keys())
        self.audio_dir = audio_dir
        assert self.audio_dir.is_dir(), f"{audio_dir} is not a directory"
        self.sample_rate = sample_rate
        # all files are 120 seconds long, split them into 12 x 10 second pieces
        self.pieces = []      # one (filename, piece_index) entry per dataset item
        self.labels = []      # per piece: (frames_per_piece, nlabels) tensor
        self.timestamps = []  # per piece: frame-center times within the clip
        for filename in self.filenames:
            self.pieces += [(filename, i) for i in range(self.pieces_per_clip)]
            labels = data[filename]
            # Frame centers: half a frame past each frame start, in milliseconds.
            frame_len = 1000 / label_fps
            timestamps = np.arange(label_fps * self.clip_len) * frame_len + 0.5 * frame_len
            timestamp_labels = get_labels_for_timestamps(labels, timestamps)
            # Encode the labels active at each frame as a multi-hot vector.
            ys = []
            for timestamp_label in timestamp_labels:
                timestamp_label_idxs = [label_to_idx[str(event)] for event in timestamp_label]
                y_timestamp = label_to_binary_vector(timestamp_label_idxs, nlabels)
                ys.append(y_timestamp)
            ys = torch.stack(ys)
            # NOTE(review): despite the name, frames_per_clip is the number of
            # label frames per *piece* (whole-clip frames / pieces_per_clip).
            frames_per_clip = ys.size(0) // self.pieces_per_clip
            # Slice the whole-clip label matrix and timestamps into equal
            # per-piece chunks, in piece order.
            self.labels += [ys[frames_per_clip * i: frames_per_clip * (i + 1)] for i in range(self.pieces_per_clip)]
            self.timestamps += [timestamps[frames_per_clip * i: frames_per_clip * (i + 1)] for i in
                                range(self.pieces_per_clip)]
        assert len(self.labels) == len(self.pieces) == len(self.filenames) * self.pieces_per_clip

    def __len__(self):
        # One item per precomputed 10 s piece.
        return len(self.pieces)

    def __getitem__(self, idx):
        filename = self.pieces[idx][0]
        piece = self.pieces[idx][1]
        audio_path = self.audio_dir.joinpath(filename)
        # Reads the full 120 s file on every access; the piece is cut below.
        audio, sr = sf.read(str(audio_path), dtype=np.float32)
        assert sr == self.sample_rate
        # Sample range of this piece within the full recording.
        start = self.sample_rate * piece * self.target_len
        end = start + self.sample_rate * self.target_len
        audio = audio[start:end]
        # Labels are transposed from (frames, nlabels) to (nlabels, frames).
        return audio, self.labels[idx].transpose(0, 1), filename, self.timestamps[idx]
class RandomCropDataset(Dataset):
    """
    Dataset of random 10-second crops from 120-second clips.

    ``data`` maps audio filenames to their event annotations. Per-frame
    multi-hot label matrices for the *full* clips are precomputed at
    construction time; __getitem__ then draws a random crop whose offset is
    aligned to label-frame boundaries so the audio slice and label slice
    stay in sync.

    __len__ is (number of files) * 12, so one epoch visits each file about
    12 times with independently drawn crops.
    """
    def __init__(self, data: Dict,
                 audio_dir: Path,
                 sample_rate: int,
                 label_fps: int,
                 label_to_idx: Dict,
                 nlabels: int):
        # Fixed geometry: every source clip is 120 s; crops are 10 s.
        self.clip_len = 120
        self.target_len = 10
        self.pieces_per_clip = self.clip_len // self.target_len
        self.filenames = list(data.keys())
        self.audio_dir = audio_dir
        assert self.audio_dir.is_dir(), f"{audio_dir} is not a directory"
        self.sample_rate = sample_rate
        self.label_fps = label_fps
        # all files are 120 seconds long, randomly crop 10 seconds snippets
        self.labels = []      # per file: (frames_per_clip, nlabels) tensor
        self.timestamps = []  # per file: frame-center times for the whole clip
        for filename in self.filenames:
            labels = data[filename]
            # Frame centers: half a frame past each frame start, in milliseconds.
            frame_len = 1000 / label_fps
            timestamps = np.arange(label_fps * self.clip_len) * frame_len + 0.5 * frame_len
            timestamp_labels = get_labels_for_timestamps(labels, timestamps)
            # Encode the labels active at each frame as a multi-hot vector.
            ys = []
            for timestamp_label in timestamp_labels:
                timestamp_label_idxs = [label_to_idx[str(event)] for event in timestamp_label]
                y_timestamp = label_to_binary_vector(timestamp_label_idxs, nlabels)
                ys.append(y_timestamp)
            ys = torch.stack(ys)
            self.labels.append(ys)
            self.timestamps.append(timestamps)
        assert len(self.labels) == len(self.filenames)

    def __len__(self):
        # 12 virtual items per file; the crop offset supplies the randomness.
        return len(self.filenames) * self.clip_len // self.target_len

    def __getitem__(self, idx):
        # Many virtual indices map onto each file.
        idx = idx % len(self.filenames)
        filename = self.filenames[idx]
        audio_path = self.audio_dir.joinpath(filename)
        audio, sr = sf.read(str(audio_path), dtype=np.float32)
        assert sr == self.sample_rate
        # crop random 10 seconds piece
        labels_to_pick = self.target_len * self.label_fps
        max_offset = len(self.labels[idx]) - labels_to_pick + 1
        # Offset is drawn in units of label frames, so audio and labels align.
        offset = torch.randint(max_offset, (1,)).item()
        labels = self.labels[idx][offset:offset + labels_to_pick]
        # Samples per label frame (assumes sample_rate is divisible by
        # label_fps — TODO confirm for all configurations).
        scale = self.sample_rate // self.label_fps
        audio = audio[offset * scale:offset * scale + labels_to_pick * scale]
        timestamps = self.timestamps[idx][offset:offset + labels_to_pick]
        # Labels are transposed from (frames, nlabels) to (nlabels, frames).
        return audio, labels.transpose(0, 1), filename, timestamps
def get_training_dataset(
        task_path,
        sample_rate=16000,
        label_fps=25,
        wavmix_p=0.0,
        random_crop=True
):
    """Build the training dataset for a task directory.

    Reads ``train.json`` and the label vocabulary from ``task_path`` and the
    audio from ``<task_path>/<sample_rate>/train``.  Random 10 s crops are
    used when ``random_crop`` is set, fixed pieces otherwise; the result is
    optionally wrapped in waveform mixup applied with probability ``wavmix_p``.
    """
    root = Path(task_path)
    vocab, nlabels = label_vocab_nlabels(root)
    label_to_idx = label_vocab_as_dict(vocab, key="label", value="idx")
    fold_data = json.load(root.joinpath("train.json").open())
    audio_dir = root.joinpath(str(sample_rate), "train")
    dataset_cls = RandomCropDataset if random_crop else FixCropDataset
    dataset = dataset_cls(fold_data, audio_dir, sample_rate, label_fps, label_to_idx, nlabels)
    if wavmix_p > 0:
        dataset = MixupDataset(dataset, rate=wavmix_p)
    return dataset
def get_validation_dataset(
        task_path,
        sample_rate=16000,
        label_fps=25,
):
    """Build the fixed-crop validation dataset for a task directory.

    Reads ``valid.json`` and the label vocabulary from ``task_path`` and the
    audio from ``<task_path>/<sample_rate>/valid``.
    """
    root = Path(task_path)
    vocab, nlabels = label_vocab_nlabels(root)
    label_to_idx = label_vocab_as_dict(vocab, key="label", value="idx")
    fold_data = json.load(root.joinpath("valid.json").open())
    audio_dir = root.joinpath(str(sample_rate), "valid")
    return FixCropDataset(fold_data, audio_dir, sample_rate, label_fps, label_to_idx, nlabels)
def get_test_dataset(
        task_path,
        sample_rate=16000,
        label_fps=25,
):
    """Build the fixed-crop test dataset for a task directory.

    Reads ``test.json`` and the label vocabulary from ``task_path`` and the
    audio from ``<task_path>/<sample_rate>/test``.
    """
    root = Path(task_path)
    vocab, nlabels = label_vocab_nlabels(root)
    label_to_idx = label_vocab_as_dict(vocab, key="label", value="idx")
    fold_data = json.load(root.joinpath("test.json").open())
    audio_dir = root.joinpath(str(sample_rate), "test")
    return FixCropDataset(fold_data, audio_dir, sample_rate, label_fps, label_to_idx, nlabels)
def get_labels_for_timestamps(labels: List, timestamps: np.ndarray) -> List:
    """Return, for each timestamp, the list of event labels active at that time.

    Args:
        labels: list of event dicts with "start", "end" and "label" keys;
            times are in the same units as ``timestamps`` (milliseconds,
            per the callers in this file).
        timestamps: 1-D array of query times.

    Returns:
        A list of the same length as ``timestamps``; each element is the list
        of label strings whose event interval contains that timestamp.
    """
    tree = IntervalTree()
    # Add all events to the label tree.
    for event in labels:
        # IntervalTree intervals are half-open [start, end); the 0.0001 nudge
        # makes a timestamp landing exactly on the end still count as inside.
        tree.addi(event["start"], event["end"] + 0.0001, event["label"])
    # Point-query the tree once per timestamp.
    timestamp_labels = [[interval.data for interval in tree[t]] for t in timestamps]
    assert len(timestamp_labels) == len(timestamps)
    return timestamp_labels
def label_vocab_nlabels(task_path: Path) -> Tuple[pd.DataFrame, int]:
    """Load ``labelvocabulary.csv`` from a task directory.

    Args:
        task_path: directory containing ``labelvocabulary.csv`` with at least
            "label" and "idx" columns.

    Returns:
        Tuple of (vocabulary DataFrame, number of labels).

    Raises:
        ValueError: if the "idx" column is not dense 0..n-1 (a real exception
            rather than ``assert``, which is stripped under ``python -O``).
    """
    label_vocab = pd.read_csv(task_path.joinpath("labelvocabulary.csv"))
    nlabels = len(label_vocab)
    if nlabels != label_vocab["idx"].max() + 1:
        raise ValueError(
            f"labelvocabulary.csv must have a dense idx column 0..{nlabels - 1}"
        )
    return (label_vocab, nlabels)
def label_vocab_as_dict(df: pd.DataFrame, key: str, value: str) -> Dict:
    """
    Returns a dictionary of the label vocabulary mapping the label column to
    the idx column (or vice versa).

    Args:
        df: vocabulary frame with "label" and "idx" columns.
        key: which column supplies the dict keys — "label" or "idx".
        value: ignored, kept for backward compatibility; the value column is
            always the counterpart of ``key``.

    Returns:
        Dict mapping key-column entries to value-column entries. When
        ``key == "label"``, keys are coerced to ``str``.

    Raises:
        ValueError: if ``key`` is neither "label" nor "idx".
    """
    if key == "label":
        # Coerce keys to str without mutating the caller's frame (the
        # original implementation rewrote df["label"] in place).
        keys = df["label"].astype(str)
        values = df["idx"]
    elif key == "idx":
        keys = df["idx"]
        values = df["label"]
    else:
        raise ValueError("key argument must be either 'label' or 'idx'")
    # Like DataFrame.set_index().to_dict(), later duplicate keys win.
    return dict(zip(keys, values))
def label_to_binary_vector(label: List, num_labels: int) -> torch.Tensor:
    """
    Convert a list of integer label indices into a multi-hot binary vector.

    Args:
        label: list of integer label indices (may be empty)
        num_labels: total number of labels, i.e. the vector length

    Returns:
        A float Tensor of shape (num_labels,) with 1.0 at each listed index.
    """
    # BCEWithLogitsLoss expects float (not long) targets, so the vector is
    # float even in the no-label case.
    binary_labels = torch.zeros((num_labels,), dtype=torch.float)
    if not label:
        return binary_labels
    binary_labels.scatter_(0, torch.tensor(label), 1.0)
    # Sanity check: exactly the requested positions are hot.
    assert set(torch.where(binary_labels == 1.0)[0].numpy()) == set(label)
    return binary_labels
class MixupDataset(Dataset):
    """Waveform mixup wrapper around another dataset.

    With probability ``rate``, an item's waveform and label vector are mixed
    with those of a second, uniformly drawn item.  The mixing weight comes
    from a Beta(beta, beta) draw folded into [0.5, 1], so the indexed item
    always dominates; waveforms are mean-centered before and after mixing.
    The filename and timestamps of the indexed item are returned unchanged.
    """

    def __init__(self, dataset, beta=0.2, rate=0.5):
        self.beta = beta
        self.rate = rate
        self.dataset = dataset
        print(f"Mixing up waveforms from dataset of len {len(dataset)}")

    def __getitem__(self, index):
        # Most of the time (1 - rate) the item passes through untouched.
        if not torch.rand(1) < self.rate:
            return self.dataset[index]
        wave_a, target_a, name, times = self.dataset[index]
        partner = torch.randint(len(self.dataset), (1,)).item()
        wave_b, target_b = self.dataset[partner][0], self.dataset[partner][1]
        # Folded Beta draw: the indexed item always gets weight >= 0.5.
        weight = np.random.beta(self.beta, self.beta)
        weight = max(weight, 1. - weight)
        # Mean-center each waveform, mix, and re-center the result.
        wave_a = wave_a - wave_a.mean()
        wave_b = wave_b - wave_b.mean()
        mixed = wave_a * weight + wave_b * (1. - weight)
        mixed = mixed - mixed.mean()
        target = target_a * weight + target_b * (1. - weight)
        return mixed, target, name, times

    def __len__(self):
        return len(self.dataset)