# @ hwang258@jhu.edu
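# CapSpeech dataset: loads phoneme ID sequences, waveforms, T5 text
# embeddings, and CLAP tag embeddings, and collates them into padded batches.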
import os
import json
import torch
import random
import logging
import shutil
import typing as tp
import numpy as np
import torchaudio
import sys
from torch.utils.data import Dataset
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
def read_json(path):
    with open(path, 'r') as f:
        return json.load(f)
class CapSpeech(Dataset):
    def __init__(
        self,
        dataset_dir: str = None,
        clap_emb_dir: str = None,
        t5_folder_name: str = "t5",
        phn_folder_name: str = "g2p",
        manifest_name: str = "manifest",
        json_name: str = "jsons",
        dynamic_batching: bool = True,
        text_pad_token: int = -1,
        audio_pad_token: float = 0.0,
        split: str = "val",
        sr: int = 24000,
        norm_audio: bool = False,
        vocab_file: str = None,
    ):
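        # Expected layout (inferred from the paths built below):
        #   <dataset_dir>/<manifest_name>/<split>.txt      tab-separated segment_id / tag
        #   <dataset_dir>/<json_name>/<split>.json         segment_id -> audio_path
        #   <dataset_dir>/<phn_folder_name>/<seg_id>.txt   phoneme sequence
        #   <dataset_dir>/<t5_folder_name>/<seg_id>.npz    T5 text embedding
        #   <clap_emb_dir>/<tag>.npz                       CLAP tag embedding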
        super().__init__()
        self.dataset_dir = dataset_dir
        self.clap_emb_dir = clap_emb_dir
        self.t5_folder_name = t5_folder_name
        self.phn_folder_name = phn_folder_name
        self.manifest_name = manifest_name
        self.json_name = json_name
        self.dynamic_batching = dynamic_batching
        self.text_pad_token = text_pad_token
        # keep the pad value as a plain float so pad_sequence accepts it directly
        self.audio_pad_token = float(audio_pad_token)
        self.split = split
        self.sr = sr
        self.norm_audio = norm_audio
        assert self.split in ['train', 'train_small', 'val', 'test']
        manifest_fn = os.path.join(self.dataset_dir, self.manifest_name, self.split + ".txt")
        meta = read_json(os.path.join(self.dataset_dir, self.json_name, self.split + ".json"))
        self.meta = {item["segment_id"]: item["audio_path"] for item in meta}
        with open(manifest_fn, "r") as rf:
            data = [l.strip().split("\t") for l in rf.readlines()]
        # data = [item for item in data if item[2] == 'none'] # remove sound effects
        self.data = [item[0] for item in data]
        self.tag_list = [item[1] for item in data]
        logging.info(f"number of data points for {self.split} split: {len(self.data)}")
        # phoneme vocabulary; each line is "<id> <phoneme>"
        if vocab_file is None:
            vocab_fn = os.path.join(self.dataset_dir, "vocab.txt")
        else:
            vocab_fn = vocab_file
        with open(vocab_fn, "r") as f:
            # strip before the length check so blank lines are actually skipped
            temp = [l.strip().split(" ") for l in f.readlines() if len(l.strip()) != 0]
        self.phn2num = {item[1]: int(item[0]) for item in temp}
    def __len__(self):
        return len(self.data)
    def _load_audio(self, audio_path):
        try:
            y, sr = torchaudio.load(audio_path)
            if y.shape[0] > 1:  # downmix multi-channel audio to mono
                y = y.mean(dim=0, keepdim=True)
            if sr != self.sr:
                resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)
                y = resampler(y)
            if self.norm_audio:
                eps = 1e-9
                max_val = torch.max(torch.abs(y))
                y = y / (max_val + eps)
            if torch.isnan(y.mean()):
                return None
            return y
        except Exception:
            # signal a load failure to the caller instead of crashing the worker
            return None
    def _load_phn_enc(self, index):
        try:
            seg_id = self.data[index]
            pf = os.path.join(self.dataset_dir, self.phn_folder_name, seg_id + ".txt")
            audio_path = self.meta[seg_id]
            cf = os.path.join(self.dataset_dir, self.t5_folder_name, seg_id + ".npz")
            tagf = os.path.join(self.clap_emb_dir, self.tag_list[index] + '.npz')
            with open(pf, "r") as p:
                phns = [l.strip() for l in p.readlines()]
                assert len(phns) == 1, phns
                x = [self.phn2num[item] for item in phns[0].split(" ")]
            c = np.load(cf)['arr_0']  # T5 text embedding
            c = torch.tensor(c).squeeze()
            tag = np.load(tagf)['arr_0']  # CLAP tag embedding
            tag = torch.tensor(tag).squeeze()
            y = self._load_audio(audio_path)
            if y is not None:
                return x, y, c, tag
            return None, None, None, None
        except Exception:
            return None, None, None, None
    def __getitem__(self, index):
        x, y, c, tag = self._load_phn_enc(index)
        if x is None:
            return {
                "x": None,
                "x_len": None,
                "y": None,
                "y_len": None,
                "c": None,
                "c_len": None,
                "tag": None
            }
        x_len, y_len, c_len = len(x), len(y[0]), len(c)
        y_len = y_len / self.sr  # audio duration in seconds
        # skip items whose audio has fewer frames than phonemes
        # (256 is presumably the downstream feature hop size)
        if y_len * self.sr / 256 <= x_len:
            return {
                "x": None,
                "x_len": None,
                "y": None,
                "y_len": None,
                "c": None,
                "c_len": None,
                "tag": None
            }
        x = torch.LongTensor(x)
        return {
            "x": x,
            "x_len": x_len,
            "y": y,
            "y_len": y_len,
            "c": c,
            "c_len": c_len,
            "tag": tag
        }
    def collate(self, batch):
        out = {key: [] for key in batch[0]}
        for item in batch:
            if item['x'] is None:  # deal with load failure
                continue
            if item['c'].ndim != 2:
                continue
            for key, val in item.items():
                out[key].append(val)
        res = {}
        res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.text_pad_token)
        res["x_lens"] = torch.LongTensor(out["x_len"])
        if self.dynamic_batching:
            res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1, 0) for item in out['y']], padding_value=self.audio_pad_token)
            res['y'] = res['y'].permute(1, 2, 0)  # T B K -> B K T
        else:
            res['y'] = torch.stack(out['y'], dim=0)
        res["y_lens"] = torch.Tensor(out["y_len"])
        res['c'] = torch.nn.utils.rnn.pad_sequence(out['c'], batch_first=True)
        res["c_lens"] = torch.LongTensor(out["c_len"])
        res["tag"] = torch.stack(out['tag'], dim=0)
        return res
if __name__ == "__main__":
    # debug
    from torch.utils.data import DataLoader
    dataset = CapSpeech(
        dataset_dir="./data/capspeech",
        clap_emb_dir="./data/clap_embs/",
        split="val"
    )
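    # Minimal smoke test (a sketch, assuming the data paths above exist
    # locally): push a few items through the custom collate and print the
    # batched tensor shapes.
    loader = DataLoader(dataset, batch_size=4, shuffle=False, collate_fn=dataset.collate)
    batch = next(iter(loader))
    for key, val in batch.items():
        print(key, tuple(val.shape) if torch.is_tensor(val) else val)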