amupd's picture
SpeechT5 upload
62e9ca6
# ----------------------------------------------------------------------------
# VatLM: Visual-Audio-Text Pre-Training with Unified Masked Prediction for Speech Representation Learning
# Github source: https://github.com/microsoft/SpeechT5/tree/main/VATLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq and av_hubert: https://github.com/facebookresearch/av_hubert
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import itertools
import logging
import os
import sys
import time
from typing import Any, List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset
# Debug mode is assumed when the process was started with no command-line
# arguments (sys.argv holds only the script name). In that mode the sibling
# ``utils`` module is imported absolutely and logging is configured here;
# otherwise the module is expected to live inside a package.
DBG = len(sys.argv) == 1

if not DBG:
    from . import utils as custom_utils
else:
    import utils as custom_utils

    # Log level can be overridden through the LOGLEVEL environment variable.
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "DEBUG").upper(),
        stream=sys.stdout,
    )

logger = logging.getLogger(__name__)
def load_text(manifest_path, max_keep, min_keep, frame_rate, label_paths, label_rates, tol=0.1):
    """Read per-entry sizes from a tab-separated manifest and filter them by length.

    Only the first tab-separated field of each manifest line is used; it is
    parsed as an integer size. Entries whose size is below ``min_keep`` or
    above ``max_keep`` are skipped and counted.

    Args:
        manifest_path: path to the manifest file.
        max_keep: inclusive upper bound on size; ``None`` disables the check.
        min_keep: inclusive lower bound on size; ``None`` disables the check.
        frame_rate, label_paths, label_rates, tol: accepted for interface
            compatibility with callers; not used by this function.

    Returns:
        ``(inds, sizes)``: zero-based manifest line indices of the kept entries
        (so they can be used to select aligned lines from other files) and the
        corresponding sizes.

    Raises:
        ValueError: if a line's first field is not an integer.
    """
    n_long, n_short = 0, 0
    inds, sizes = [], []
    with open(manifest_path) as f:
        for ind, line in enumerate(f):
            sz = int(line.strip().split("\t")[0])
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            else:
                inds.append(ind)
                sizes.append(sz)

    # max()/min() raise on an empty list, which happens when every entry
    # was filtered out; report that case instead of crashing in the logger.
    if sizes:
        size_info = f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
    else:
        size_info = "no entries loaded"
    logger.info(
        f"max_keep={max_keep}, min_keep={min_keep}, "
        f"loaded {len(inds)}, skipped {n_short} short and {n_long} long, "
        f"{size_info}"
    )
    return inds, sizes
def load_label(label_path, inds):
    """Return the lines of ``label_path`` selected by ``inds``.

    Trailing whitespace (including the newline) is stripped from every line,
    and the result follows the order of ``inds``.

    Raises:
        IndexError: if an index in ``inds`` is past the end of the file.
    """
    with open(label_path) as label_file:
        all_lines = list(map(str.rstrip, label_file))
    return [all_lines[idx] for idx in inds]
def load_phone_label(tsv, inds):
    """Return the right-stripped lines of ``tsv`` at the positions in ``inds``.

    Within this file, ``TextHubertDataset`` passes its manifest path here, so
    each returned string is a whole manifest line (all tab-separated fields).

    Raises:
        IndexError: if an index in ``inds`` is past the end of the file.
    """
    with open(tsv) as handle:
        stripped = [raw.rstrip() for raw in handle]
    return [stripped[position] for position in inds]
def load_label_offset(label_path, inds):
    """Compute ``(start, end)`` byte offsets of selected lines of a file.

    The file is read in binary mode so the offsets are true on-disk byte
    positions suitable for ``seek``. Reading in text mode would apply newline
    translation (``\\r\\n`` -> ``\\n``) and decode with the locale encoding,
    so the re-encoded line lengths could disagree with the file layout.

    Args:
        label_path: path to the label file.
        inds: zero-based line indices to return offsets for.

    Returns:
        One ``(start, end)`` tuple per entry of ``inds``; ``end`` is exclusive
        and includes the line terminator.

    Raises:
        IndexError: if an index in ``inds`` is past the last line.
    """
    with open(label_path, "rb") as f:
        code_lengths = [len(line) for line in f]
    # offsets[k] is the byte position where line k starts.
    offsets = list(itertools.accumulate([0] + code_lengths))
    return [(offsets[i], offsets[i + 1]) for i in inds]
class TextHubertDataset(FairseqDataset):
    """Dataset yielding label sequences plus one manifest-derived "phone" sequence per item.

    Labels come from ``label_paths``; the phone sequence for each item is the
    corresponding (right-stripped) line of the manifest, optionally passed
    through ``phone_sequence_processors``. No audio is loaded by this class:
    ``collater`` sets ``net_input["source"]`` to ``None``.
    """

    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_paths: List[str],
        label_rates: Union[List[float], float],  # -1 for sequence labels
        pad_list: List[str],
        eos_list: List[str],
        label_processors: Optional[List[Any]] = None,
        phone_sequence_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size: Optional[int] = None,
        shuffle: bool = True,
        pad_audio: bool = False,
        normalize: bool = False,
        store_labels: bool = True,
        single_target: bool = False,
        stack_order_audio: int = 1,
        skip_verify: bool = False,
        is_s2s=False,
    ):
        # A scalar rate applies to every label stream. Floats must be broadcast
        # too (the annotation allows them); otherwise ``zip`` over the scalar in
        # the collaters raises TypeError.
        self.label_rates = (
            [label_rates for _ in range(len(label_paths))]
            if isinstance(label_rates, (int, float))
            else label_rates
        )
        inds, self.sizes = load_text(
            manifest_path,
            max_keep_sample_size,
            min_keep_sample_size,
            frame_rate=sample_rate,
            label_paths=label_paths,
            label_rates=self.label_rates,
        )
        self.sample_rate = sample_rate
        self.stack_order_audio = stack_order_audio
        self.shuffle = shuffle
        self.num_labels = len(label_paths)
        self.pad_list = pad_list
        self.eos_list = eos_list
        self.label_processors = label_processors
        self.phone_processors = phone_sequence_processors
        self.single_target = single_target
        self.store_labels = store_labels
        self.is_s2s = is_s2s
        if store_labels:
            self.label_list = [load_label(p, inds) for p in label_paths]
        else:
            # Labels are read lazily from disk using precomputed byte offsets.
            self.label_paths = label_paths
            self.label_offsets_list = [load_label_offset(p, inds) for p in label_paths]
        # __getitem__ always needs phone sequences, so they are loaded in both
        # storage modes (they are read from the manifest, not the label files).
        self.phone_list = [load_phone_label(p, inds) for p in [manifest_path]]
        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.pad_audio = pad_audio
        self.normalize = normalize

    def get_label(self, index, label_idx):
        """Return label stream ``label_idx`` of item ``index``, processed if a processor is set.

        In the lazy (``store_labels=False``) path the returned text includes
        the line terminator, as sliced from the file.
        """
        if self.store_labels:
            label = self.label_list[label_idx][index]
        else:
            offset_s, offset_e = self.label_offsets_list[label_idx][index]
            # The offsets are byte positions, so read bytes and decode; a
            # text-mode ``read(n)`` counts characters and over-reads lines
            # containing multi-byte UTF-8 text.
            with open(self.label_paths[label_idx], "rb") as f:
                f.seek(offset_s)
                label = f.read(offset_e - offset_s).decode("utf-8")
        if self.label_processors is not None:
            label = self.label_processors[label_idx](label)
        return label

    def get_labels(self, index):
        """Return every label stream of item ``index``."""
        return [self.get_label(index, i) for i in range(self.num_labels)]

    def get_phone(self, index, label_idx):
        """Return phone stream ``label_idx`` of item ``index``, processed if a processor is set."""
        label = self.phone_list[label_idx][index]
        if self.phone_processors is not None:
            label = self.phone_processors[label_idx](label)
        return label

    def get_phones(self, index):
        """Return every phone stream of item ``index`` (one stream, from the manifest)."""
        return [self.get_phone(index, i) for i in range(len(self.phone_list))]

    def __getitem__(self, index):
        labels = self.get_labels(index)
        phone_sequence_list = self.get_phones(index)
        return {"id": index, "label_list": labels, "phone_sequence_list": phone_sequence_list}

    def __len__(self):
        return len(self.sizes)

    def collater(self, samples):
        """Collate items into a batch dict of padded label and phone tensors.

        Returns ``{}`` when no sample has a non-``None`` id.
        """
        samples = [s for s in samples if s["id"] is not None]
        if len(samples) == 0:
            return {}

        targets_by_label = [
            [s["label_list"][i] for s in samples]
            for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(
            targets_by_label,
        )

        # Iterate over the phone streams each sample actually carries; using
        # ``num_labels`` here indexes past the end when there is more than one
        # label stream.
        num_phone_streams = len(samples[0]["phone_sequence_list"])
        targets_by_phone_label = [
            [s["phone_sequence_list"][i] for s in samples]
            for i in range(num_phone_streams)
        ]
        targets_phone_list, _, _ = self.collater_phone_label(
            targets_by_phone_label,
        )

        net_input = {"source": None}
        batch = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": net_input,
        }

        if self.single_target:
            batch["target_lengths"] = lengths_list[0]
            batch["ntokens"] = ntokens_list[0]
            if self.is_s2s:
                # Presumably collater_label yields (target, prev_output_tokens)
                # pairs in seq2seq mode — TODO confirm: the collate helpers in
                # this class return a single tensor per stream.
                batch["target"], net_input["prev_output_tokens"] = (
                    targets_list[0][0],
                    targets_list[0][1],
                )
            else:
                batch["target"] = targets_list[0]
        else:
            batch["target_lengths_list"] = lengths_list
            batch["ntokens_list"] = ntokens_list
            batch["target_list"] = targets_list
        batch["extra_text_phone_list"] = targets_phone_list
        return batch

    def collater_frm_label(
        self, targets, label_rate, pad
    ):
        """Right-pad ``targets`` with ``pad``; return (padded, lengths, total token count).

        ``label_rate`` is accepted for interface compatibility and not used.
        """
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(
            targets, pad_idx=pad, left_pad=False
        )
        return targets, lengths, ntokens

    def collater_frm_phone_label(
        self, targets, pad
    ):
        """Right-pad phone ``targets`` with ``pad``; return (padded, lengths, total token count)."""
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(
            targets, pad_idx=pad, left_pad=False
        )
        return targets, lengths, ntokens

    def collater_label(self, targets_by_label):
        """Collate each label stream; return parallel lists of (targets, lengths, ntokens)."""
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, self.label_rates, self.pad_list)
        for targets, label_rate, pad in itr:
            targets, lengths, ntokens = self.collater_frm_label(
                targets, label_rate, pad
            )
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def collater_phone_label(self, targets_by_label):
        """Collate each phone stream, padding stream ``k`` with ``pad_list[k]``.

        ``zip`` truncates to the shortest of the streams, ``label_rates`` and
        ``pad_list``.
        """
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, self.label_rates, self.pad_list)
        for targets, _label_rate, pad in itr:
            targets, lengths, ntokens = self.collater_frm_phone_label(
                targets, pad
            )
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        """Return the item's size, capped at ``max_sample_size`` unless ``pad_audio`` is set."""
        if self.pad_audio:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        """Return indices ordered by descending size.

        ``np.lexsort`` treats the last key (sizes) as primary; ties are broken
        by a random permutation when shuffling, otherwise by position.
        """
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        order.append(self.sizes)
        return np.lexsort(order)[::-1]