# ----------------------------------------------------------------------------
# VatLM: Visual-Audio-Text Pre-Training with Unified Masked Prediction for Speech Representation Learning
# Github source: https://github.com/microsoft/SpeechT5/tree/main/VATLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq and av_hubert: https://github.com/facebookresearch/av_hubert
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import itertools
import logging
import os
import sys
import time
from typing import Any, List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset

# Running the module directly with no CLI arguments switches to debug-friendly
# imports and verbose stdout logging.
DBG = len(sys.argv) == 1

if DBG:
    import utils as custom_utils

    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "DEBUG").upper(),
        stream=sys.stdout,
    )
else:
    from . import utils as custom_utils

logger = logging.getLogger(__name__)
def load_text(manifest_path, max_keep, min_keep, frame_rate, label_paths, label_rates, tol=0.1):
    """Read the text manifest and keep rows whose length lies in [min_keep, max_keep].

    ``frame_rate``, ``label_paths``, ``label_rates`` and ``tol`` are accepted for
    interface compatibility but are not used for filtering here.
    """
    n_long, n_short = 0, 0
    inds, sizes = [], []
    with open(manifest_path) as f:
        for ind, line in enumerate(f):
            items = line.strip().split("\t")
            # The first tab-separated column holds the sequence length.
            sz = int(items[0])
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            else:
                inds.append(ind)
                sizes.append(sz)
    logger.info(
        (
            f"max_keep={max_keep}, min_keep={min_keep}, "
            f"loaded {len(inds)}, skipped {n_short} short and {n_long} long, "
            f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
        )
    )
    return inds, sizes
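
# A sketch of the expected manifest layout (an assumption inferred from how
# load_text and load_phone_label read the file, not a documented format): each
# row is tab-separated with the sequence length in the first column, e.g.
#
#   250\t<rest of the row, reused later as the phone sequence>
#
#   inds, sizes = load_text(
#       "text_train.tsv", max_keep=None, min_keep=None,
#       frame_rate=16000, label_paths=["text_train.km"], label_rates=[50],
#   )
#
# File names and values above are illustrative only.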
def load_label(label_path, inds):
    """Load all label lines and keep only the rows selected by ``inds``."""
    with open(label_path) as f:
        labels = [line.rstrip() for line in f]
        labels = [labels[i] for i in inds]
    return labels


def load_phone_label(tsv, inds):
    """Load phone-sequence lines from the manifest and keep only the rows in ``inds``."""
    with open(tsv) as f:
        labels = [line.rstrip() for line in f.readlines()]
        labels = [labels[i] for i in inds]
    return labels


def load_label_offset(label_path, inds):
    """Compute (start, end) byte offsets of each kept label line for lazy reads."""
    with open(label_path) as f:
        code_lengths = [len(line.encode("utf-8")) for line in f]
        offsets = list(itertools.accumulate([0] + code_lengths))
        offsets = [(offsets[i], offsets[i + 1]) for i in inds]
    return offsets
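
# Worked example (illustrative): if a label file contains the lines "a b\n"
# (4 bytes) and "c d e\n" (6 bytes), the cumulative offsets are [0, 4, 10], so
# inds=[1] yields [(4, 10)] and get_label() can seek to byte 4 and read 6 bytes
# to recover just that line without loading the whole file.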
class TextHubertDataset(FairseqDataset):
    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_paths: List[str],
        label_rates: Union[List[float], float],  # -1 for sequence labels
        pad_list: List[str],
        eos_list: List[str],
        label_processors: Optional[List[Any]] = None,
        phone_sequence_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size: Optional[int] = None,
        shuffle: bool = True,
        pad_audio: bool = False,
        normalize: bool = False,
        store_labels: bool = True,
        single_target: bool = False,
        stack_order_audio: int = 1,
        skip_verify: bool = False,
        is_s2s=False,
    ):
        # Broadcast a scalar label rate to one entry per label stream.
        self.label_rates = (
            [label_rates for _ in range(len(label_paths))]
            if isinstance(label_rates, (int, float))
            else label_rates
        )
        inds, self.sizes = load_text(
            manifest_path,
            max_keep_sample_size,
            min_keep_sample_size,
            frame_rate=sample_rate,
            label_paths=label_paths,
            label_rates=self.label_rates,
        )
        self.sample_rate = sample_rate
        self.stack_order_audio = stack_order_audio
        self.shuffle = shuffle
        self.num_labels = len(label_paths)
        self.pad_list = pad_list
        self.eos_list = eos_list
        self.label_processors = label_processors
        self.phone_processors = phone_sequence_processors
        self.single_target = single_target
        self.store_labels = store_labels
        self.is_s2s = is_s2s

        if store_labels:
            self.label_list = [load_label(p, inds) for p in label_paths]
            # The phone sequences are read from the manifest itself (a single stream).
            self.phone_list = [load_phone_label(p, inds) for p in [manifest_path]]
        else:
            self.label_paths = label_paths
            self.label_offsets_list = [
                load_label_offset(p, inds) for p in label_paths
            ]
        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.pad_audio = pad_audio
        self.normalize = normalize
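
    # A minimal construction sketch (argument values are illustrative, not taken
    # from the repository's recipes):
    #
    #   dataset = TextHubertDataset(
    #       "text_train.tsv",
    #       sample_rate=16000,
    #       label_paths=["text_train.km"],
    #       label_rates=[50],
    #       pad_list=[dictionary.pad()],
    #       eos_list=[dictionary.eos()],
    #       label_processors=[label_encoder],
    #       phone_sequence_processors=[phone_encoder],
    #   )
    #
    # Here ``dictionary``, ``label_encoder`` and ``phone_encoder`` are hypothetical
    # stand-ins for the task's label dictionary and tokenizers.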
    def get_label(self, index, label_idx):
        if self.store_labels:
            label = self.label_list[label_idx][index]
        else:
            # Lazily read the label from disk using the precomputed byte offsets.
            with open(self.label_paths[label_idx]) as f:
                offset_s, offset_e = self.label_offsets_list[label_idx][index]
                f.seek(offset_s)
                label = f.read(offset_e - offset_s)

        if self.label_processors is not None:
            label = self.label_processors[label_idx](label)
        return label

    def get_labels(self, index):
        return [self.get_label(index, i) for i in range(self.num_labels)]

    def get_phone(self, index, label_idx):
        label = self.phone_list[label_idx][index]
        if self.phone_processors is not None:
            label = self.phone_processors[label_idx](label)
        return label

    def get_phones(self, index):
        # There is exactly one phone stream, loaded from the manifest.
        return [self.get_phone(index, i) for i in range(1)]

    def __getitem__(self, index):
        labels = self.get_labels(index)
        phone_sequence_list = self.get_phones(index)
        return {"id": index, "label_list": labels, "phone_sequence_list": phone_sequence_list}
    def __len__(self):
        return len(self.sizes)

    def collater(self, samples):
        samples = [s for s in samples if s["id"] is not None]
        if len(samples) == 0:
            return {}

        targets_by_label = [
            [s["label_list"][i] for s in samples]
            for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(
            targets_by_label
        )

        phone_sequence_list = [s["phone_sequence_list"] for s in samples]
        if phone_sequence_list[0] is None:
            phone_sequence_list = None

        # Note: each sample carries a single phone sequence (see get_phones), so
        # indexing with range(self.num_labels) assumes num_labels == 1.
        targets_by_phone_label = [
            [s["phone_sequence_list"][i] for s in samples]
            for i in range(self.num_labels)
        ]
        targets_phone_list, lengths_phone_list, ntokens_phone_list = self.collater_phone_label(
            targets_by_phone_label
        )

        net_input = {"source": None}
        batch = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": net_input,
        }

        if self.single_target:
            batch["target_lengths"] = lengths_list[0]
            batch["ntokens"] = ntokens_list[0]
            if self.is_s2s:
                batch["target"], net_input["prev_output_tokens"] = (
                    targets_list[0][0],
                    targets_list[0][1],
                )
            else:
                batch["target"] = targets_list[0]
        else:
            batch["target_lengths_list"] = lengths_list
            batch["ntokens_list"] = ntokens_list
            batch["target_list"] = targets_list
            batch["extra_text_phone_list"] = targets_phone_list
        return batch
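
    # With single_target=False, collater() returns a dict shaped roughly like the
    # sketch below (B = batch size, T_max / P_max = longest label / phone sequence):
    #
    #   {
    #       "id": LongTensor(B),
    #       "net_input": {"source": None},
    #       "target_lengths_list": [LongTensor(B), ...],
    #       "ntokens_list": [int, ...],
    #       "target_list": [LongTensor(B, T_max), ...],
    #       "extra_text_phone_list": [LongTensor(B, P_max)],
    #   }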
    def collater_frm_label(self, targets, label_rate, pad):
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        # Right-pad every target sequence to the length of the longest one.
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_frm_phone_label(self, targets, pad):
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_label(self, targets_by_label):
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, self.label_rates, self.pad_list)
        for targets, label_rate, pad in itr:
            targets, lengths, ntokens = self.collater_frm_label(targets, label_rate, pad)
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def collater_phone_label(self, targets_by_label):
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, self.label_rates, self.pad_list)
        for targets, label_rate, pad in itr:
            targets, lengths, ntokens = self.collater_frm_phone_label(targets, pad)
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list
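
    # ``zip`` stops at the shortest of its inputs, so both collaters consume one
    # (label_rate, pad) pair per target stream; collater_phone_label simply reuses
    # self.label_rates and self.pad_list even though the phone stream has no rate.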
    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        if self.pad_audio:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        order.append(self.sizes)
        return np.lexsort(order)[::-1]
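
    # np.lexsort uses its last key as the primary sort key, so the reversed result
    # orders indices by descending size, with the random permutation (or arange)
    # breaking ties. Worked example (illustrative): sizes [5, 9, 5] with permutation
    # [2, 0, 1] gives lexsort [2, 0, 1], hence ordered indices [1, 0, 2].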