import pandas as pd
import os
import random
import ast
import numpy as np
import torch
from einops import repeat, rearrange
import librosa
from torch.utils.data import Dataset
import torchaudio


class DreamData(Dataset):
    """Dataset pairing stored speaker tensors with a randomly chosen prompt.

    Each item is ``(spk, prompt)``:

    * ``spk`` is the object saved at ``data_dir + row['spk_path']``, loaded
      onto the CPU with ``.squeeze(0)`` applied (dim 0 is dropped only if it
      has size 1).
    * ``prompt`` is one entry of the prompt table's ``prompt`` column, drawn
      uniformly at random (numpy global RNG) among the rows whose
      ``speaker_id`` matches the row's ``speaker``.

    Args:
        data_dir: Prefix joined to each ``spk_path`` by plain string
            concatenation, so it normally needs a trailing path separator.
        meta_dir: Path to the metadata CSV. Columns read: ``subset``,
            ``spk_path`` and ``speaker``.
        subset: Only metadata rows whose ``subset`` equals this value are kept.
        prompt_dir: Path to the prompt CSV. Columns read: ``speaker_id`` and
            ``prompt``.
    """

    def __init__(self, data_dir, meta_dir, subset, prompt_dir):
        self.datadir = data_dir
        meta = pd.read_csv(meta_dir)
        self.meta = meta[meta['subset'] == subset]
        self.subset = subset
        self.prompts = pd.read_csv(prompt_dir)

        # Group prompts by speaker once, instead of filtering the whole prompt
        # table on every __getitem__ call. Keys are normalised to str because
        # read_csv infers purely numeric ids as int64, and an int64 column
        # never equals the str(speaker) lookup key; that mismatch used to
        # make every lookup come back empty.
        speaker_keys = self.prompts['speaker_id'].astype(str)
        self._prompts_by_speaker = {
            key: group.tolist()
            for key, group in self.prompts['prompt'].groupby(speaker_keys)
        }

    def __getitem__(self, index):
        """Return ``(spk, prompt)`` for the ``index``-th row of the subset.

        Raises:
            ValueError: if the row's speaker has no prompts in the prompt CSV.
        """
        # Positional lookup: self.meta is a filtered frame whose index labels
        # are no longer 0..len-1.
        row = self.meta.iloc[index]

        # Speaker tensor. The path is kept as a raw string prefix to preserve
        # the existing data_dir convention.
        spk_path = self.datadir + row['spk_path']
        spk = torch.load(spk_path, map_location='cpu').squeeze(0)

        # Random prompt belonging to this speaker.
        speaker = row['speaker']
        candidates = self._prompts_by_speaker.get(str(speaker))
        if not candidates:
            raise ValueError(
                f"no prompts found for speaker {speaker!r} "
                f"(looked up as speaker_id == {str(speaker)!r})"
            )
        # numpy's global RNG is used, as DataFrame.sample did before, so
        # np.random.seed still controls which prompt is drawn.
        prompt = candidates[np.random.randint(len(candidates))]
        return spk, prompt

    def __len__(self):
        """Number of metadata rows in the selected subset."""
        return len(self.meta)