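"""
Prepare filelists for the CapSpeech dataset: load the metadata splits from a
Hugging Face dataset repo, resolve each entry's WAV path against the local
corpus directories, filter utterances by duration, and write one JSON filelist
per split to <save_dir>/jsons/.
"""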
import argparse
import json
import logging
import os
from pathlib import Path
from typing import Dict, List

from datasets import load_dataset
from tqdm import tqdm
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Prepare the CapSpeech dataset")
    parser.add_argument('--hub', type=str, required=True, help='Hugging Face repo to load the dataset from')
    parser.add_argument('--save_dir', type=str, required=True, help='Directory to save the JSON files')
    parser.add_argument('--cache_dir', type=str, required=True, help='Cache directory for datasets')
    parser.add_argument('--libriR_wav_dir', type=str, required=True, help='Directory containing the LibriTTS-R WAV files')
    parser.add_argument('--other_wav_dir', type=str, required=True, help='Directory containing the WAV files of the other sources (VoxCeleb, Expresso, EARS, VCTK)')
    parser.add_argument('--audio_min_length', type=float, default=3.0, help='Minimum audio duration in seconds')
    parser.add_argument('--audio_max_length', type=float, default=18.0, help='Maximum audio duration in seconds')
    parser.add_argument('--splits', type=str, nargs='+',
                        default=['train', 'val'],
                        help='List of splits to process')
    parser.add_argument('--debug', action='store_true', help='Enable debug mode with limited data processing')
    return parser.parse_args()
def setup_logging() -> None:
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[
            logging.StreamHandler()
        ]
    )
def process_dataset_split(split: str, dataset_split, args: argparse.Namespace) -> List[Dict]:
    """
    Process a single dataset split and extract relevant records.

    Args:
        split: The name of the split (e.g., 'train').
        dataset_split: The dataset split object.
        args: Parsed command-line arguments.

    Returns:
        A list of dictionaries containing the processed records.
    """
    logging.info(f"Processing split: {split}")
    filelist: List[Dict] = []
    total_duration: float = 0.0
    # In debug mode only the first 500 samples are considered.
    num_samples: int = len(dataset_split) if not args.debug else 500
    # Map each corpus name to the local directory that holds its WAV files.
    source_path = {
        'libritts-r': args.libriR_wav_dir,
        'voxceleb': args.other_wav_dir,
        'expresso': args.other_wav_dir,
        'ears': args.other_wav_dir,
        'vctk': args.other_wav_dir,
    }
    for idx in tqdm(range(num_samples), desc=f"Processing {split}"):
        try:
            data = dataset_split[idx]
        except IndexError:
            logging.warning(f"Index {idx} out of range for split '{split}'. Skipping.")
            continue

        audio_path: str = data.get("audio_path", "")
        duration: float = data.get("speech_duration", 0.0)
        source: str = data.get("source", "")

        # Validate the metadata before resolving the local WAV path.
        if not audio_path:
            logging.warning(f"Missing audio_path at index {idx} in split '{split}'. Skipping.")
            continue
        if source not in source_path:
            logging.warning(f"Unknown source '{source}' at index {idx} in split '{split}'. Skipping.")
            continue

        audio_path = os.path.join(source_path[source], audio_path)
        if not os.path.exists(audio_path):
            logging.warning(f"WAV file does not exist: {audio_path}")
            continue

        # Keep only utterances within the configured duration range.
        if not (args.audio_min_length <= duration <= args.audio_max_length):
            continue

        record: Dict = {
            "segment_id": audio_path.split('/')[-1].split('.')[0],
            "audio_path": audio_path,
            "text": data.get('text', ''),
            "caption": data.get('caption', ''),
            "duration": duration,
            "source": source
        }
        filelist.append(record)
        total_duration += duration

    logging.info(f"Total duration for split '{split}': {total_duration / 3600:.2f} hrs.")
    logging.info(f"Total records for split '{split}': {len(filelist)}")
    return filelist
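# For reference, each record written to the JSON filelists has the shape below.
# The values are illustrative placeholders, not taken from the actual dataset:
# {
#     "segment_id": "1034_121119_000001_000001",
#     "audio_path": "/data/libritts_r/1034_121119_000001_000001.wav",
#     "text": "the transcript of the utterance",
#     "caption": "a style caption describing the speaker and delivery",
#     "duration": 5.2,
#     "source": "libritts-r"
# }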
def save_json(filelist: List[Dict], output_path: Path) -> None:
    """
    Save the list of records to a JSON file.

    Args:
        filelist: List of dictionaries containing the records.
        output_path: Path to the output JSON file.
    """
    try:
        with output_path.open('w', encoding='utf-8') as json_file:
            json.dump(filelist, json_file, ensure_ascii=False, indent=4)
        logging.info(f"Saved {len(filelist)} records to '{output_path}'")
    except Exception as e:
        logging.error(f"Failed to save JSON to '{output_path}': {e}")
def main() -> None:
    args = parse_args()
    setup_logging()

    save_dir: Path = Path(args.save_dir)
    jsons_dir: Path = save_dir / 'jsons'
    jsons_dir.mkdir(parents=True, exist_ok=True)
    logging.info(f"JSON files will be saved to '{jsons_dir}'")

    logging.info("Loading dataset...")
    try:
        ds = load_dataset(args.hub)
        # ds = load_dataset(args.hub, cache_dir=args.cache_dir)
    except Exception as e:
        logging.error(f"Failed to load dataset: {e}")
        return

    splits_to_process = args.splits
    available_splits = set(ds.keys())
    selected_splits = [split for split in splits_to_process if split in available_splits]
    missing_splits = set(splits_to_process) - available_splits
    if missing_splits:
        logging.warning(f"The following splits were not found in the dataset and will be skipped: {missing_splits}")

    for split in selected_splits:
        dataset_split = ds[split]
        filelist = process_dataset_split(split, dataset_split, args)
        output_file: Path = jsons_dir / f"{split}.json"
        save_json(filelist, output_file)
if __name__ == "__main__":
    main()
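# Example invocation. The script name and every path below are illustrative
# placeholders, not the dataset's actual locations:
#
#   python prepare_capspeech.py \
#       --hub <user>/<capspeech-repo> \
#       --save_dir ./capspeech_filelists \
#       --cache_dir ./hf_cache \
#       --libriR_wav_dir /data/libritts_r \
#       --other_wav_dir /data/other_corpora \
#       --audio_min_length 3.0 --audio_max_length 18.0 \
#       --splits train val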