from typing import Callable, Dict, List, Optional, Sequence, Union

import torch

from lhotse import CutSet, validate
from lhotse.dataset import PrecomputedFeatures
from lhotse.dataset.collation import collate_audio
from lhotse.dataset.input_strategies import BatchIO
from lhotse.utils import ifnone


class SpeechSynthesisDataset(torch.utils.data.Dataset):
    """
    The PyTorch Dataset for the speech synthesis task.
    Each item in this dataset is a dict of:

    .. code-block::

        {
            'audio': (B x NumSamples) float tensor  # when return_audio=True
            'features': (B x NumFrames x NumFeatures) float tensor
            'audio_lens': (B, ) int tensor  # when return_audio=True
            'features_lens': (B, ) int tensor
            'text': List[str] of len B  # when return_text=True
            'tokens': List[List[str]]  # when return_tokens=True
            'speakers': List[str] of len B  # when return_spk_ids=True
            'cut': List of Cuts  # when return_cuts=True
        }
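
    A minimal usage sketch (the manifest path and batching parameters are
    illustrative assumptions, not part of this class's API):

    .. code-block::

        from torch.utils.data import DataLoader
        from lhotse import CutSet
        from lhotse.dataset import SimpleCutSampler

        cuts = CutSet.from_file("cuts.jsonl.gz")  # hypothetical manifest path
        dataset = SpeechSynthesisDataset(return_audio=True)
        sampler = SimpleCutSampler(cuts, max_duration=200)
        # The sampler yields whole CutSets, so ``batch_size`` must stay None.
        dloader = DataLoader(dataset, sampler=sampler, batch_size=None)
        batch = next(iter(dloader))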
    """

    def __init__(
        self,
        cut_transforms: Optional[List[Callable[[CutSet], CutSet]]] = None,
        feature_input_strategy: BatchIO = PrecomputedFeatures(),
        feature_transforms: Optional[Union[Sequence[Callable], Callable]] = None,
        return_text: bool = True,
        return_tokens: bool = False,
        return_spk_ids: bool = False,
        return_cuts: bool = False,
        return_audio: bool = False,
    ) -> None:
        super().__init__()

        self.cut_transforms = ifnone(cut_transforms, [])
        self.feature_input_strategy = feature_input_strategy

        self.return_text = return_text
        self.return_tokens = return_tokens
        self.return_spk_ids = return_spk_ids
        self.return_cuts = return_cuts
        self.return_audio = return_audio

        if feature_transforms is None:
            feature_transforms = []
        elif not isinstance(feature_transforms, Sequence):
            feature_transforms = [feature_transforms]

        assert all(
            callable(transform) for transform in feature_transforms
        ), "Feature transforms must be Callable"
        self.feature_transforms = feature_transforms

    def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List]]:
        validate_for_tts(cuts)

        for transform in self.cut_transforms:
            cuts = transform(cuts)

        features, features_lens = self.feature_input_strategy(cuts)

        for transform in self.feature_transforms:
            features = transform(features)

        batch = {
            "features": features,
            "features_lens": features_lens,
        }

        if self.return_audio:
            audio, audio_lens = collate_audio(cuts)
            batch["audio"] = audio
            batch["audio_lens"] = audio_lens

        if self.return_text:
            batch["text"] = [cut.supervisions[0].text for cut in cuts]

        if self.return_tokens:
            batch["tokens"] = [cut.supervisions[0].tokens for cut in cuts]

        if self.return_spk_ids:
            batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]

        if self.return_cuts:
            batch["cut"] = list(cuts)

        return batch
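

# A minimal sketch of a transform usable via ``feature_transforms``: any
# callable mapping the collated (B x NumFrames x NumFeatures) float tensor
# to a tensor of the same shape fits the contract. This helper is purely
# illustrative and not part of the lhotse API.
def per_utterance_mvn(features: torch.Tensor) -> torch.Tensor:
    """Mean-variance normalize each utterance's features independently.

    Note: padding frames are included in the statistics; masking with
    ``features_lens`` would be needed for exact per-utterance stats.
    """
    mean = features.mean(dim=1, keepdim=True)
    std = features.std(dim=1, keepdim=True).clamp(min=1e-10)
    return (features - mean) / std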


def validate_for_tts(cuts: CutSet) -> None:
    validate(cuts)
    for cut in cuts:
        assert (
            len(cut.supervisions) == 1
        ), "Only Cuts with a single supervision are supported."
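

# A matching sketch for ``cut_transforms``: any Callable[[CutSet], CutSet]
# is accepted. The example below only reorders the batch, so it stays
# compatible with PrecomputedFeatures; it is illustrative, not lhotse API.
def sort_cuts_by_duration(cuts: CutSet) -> CutSet:
    """Sort cuts longest-first, e.g. for padding-friendly batching."""
    return cuts.sort_by_duration(ascending=False)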