from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn, Tensor

from .perceiver import Perceiver
from .t3_config import T3Config


@dataclass
class T3Cond:
    """
    Dataclass container for most / all conditioning info.
    TODO: serialization methods aren't used; keeping them around for convenience.
    """

    speaker_emb: Tensor
    clap_emb: Optional[Tensor] = None
    cond_prompt_speech_tokens: Optional[Tensor] = None
    cond_prompt_speech_emb: Optional[Tensor] = None
    emotion_adv: Optional[Tensor] = 0.5

    def to(self, *, device=None, dtype=None):
        "Cast to a device and dtype. Dtype casting is ignored for long/int tensors."
        for k, v in self.__dict__.items():
            if torch.is_tensor(v):
                # Only cast floating-point tensors; token ids and other
                # integer tensors keep their original dtype.
                is_fp = v.is_floating_point()
                setattr(self, k, v.to(device=device, dtype=dtype if is_fp else None))
        return self

    def save(self, fpath):
        torch.save(self.__dict__, fpath)

    @staticmethod
    def load(fpath, map_location="cpu"):
        kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
        return T3Cond(**kwargs)
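
# Illustrative usage sketch (not part of the module API). The shapes and
# hyperparameter values below are assumptions for demonstration only:
#
#   cond = T3Cond(
#       speaker_emb=torch.randn(1, 256),                   # (B, speaker_embed_size)
#       cond_prompt_speech_tokens=torch.zeros(1, 150, dtype=torch.long),
#       cond_prompt_speech_emb=torch.randn(1, 150, 1024),  # (B, T, n_channels)
#       emotion_adv=torch.tensor(0.5),
#   )
#   cond = cond.to(device="cpu", dtype=torch.float16)  # token ids stay torch.long
#   cond.save("t3_cond.pt")
#   cond = T3Cond.load("t3_cond.pt")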


class T3CondEnc(nn.Module):
    """
    Handle all non-text conditioning, like speaker embeddings / prompts, CLAP, emotion, etc.
    """

    def __init__(self, hp: T3Config):
        super().__init__()
        self.hp = hp
        if hp.encoder_type == "voice_encoder":
            self.spkr_enc = nn.Linear(hp.speaker_embed_size, hp.n_channels)
        else:
            raise NotImplementedError(str(hp.encoder_type))

        # emotion adv
        self.emotion_adv_fc = None
        if hp.emotion_adv:
            self.emotion_adv_fc = nn.Linear(1, hp.n_channels, bias=False)

        # perceiver resampler
        self.perceiver = None
        if hp.use_perceiver_resampler:
            self.perceiver = Perceiver()
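
    # Conditioning layout sketch (illustrative; the prompt length T' is an
    # assumption). forward() concatenates, along the sequence dimension:
    #
    #   speaker:  (B, 1,  n_channels)  projected by spkr_enc
    #   clap:     (B, 0,  n_channels)  placeholder until CLAP is implemented
    #   prompt:   (B, T', n_channels)  resampled by Perceiver when enabled
    #   emotion:  (B, 1,  n_channels)  projected by emotion_adv_fc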
    def forward(self, cond: T3Cond):
        # Validate: prompt tokens and prompt embeddings must be provided together.
        assert (cond.cond_prompt_speech_tokens is None) == (cond.cond_prompt_speech_emb is None), \
            "cond_prompt_speech_tokens and cond_prompt_speech_emb must be provided together"

        # Speaker embedding projection
        cond_spkr = self.spkr_enc(cond.speaker_emb.view(-1, self.hp.speaker_embed_size))[:, None]  # (B, 1, dim)
        empty = torch.zeros_like(cond_spkr[:, :0])  # (B, 0, dim)

        # TODO CLAP
        assert cond.clap_emb is None, "clap_embed not implemented"
        cond_clap = empty  # (B, 0, dim)

        # Cond prompt
        cond_prompt_speech_emb = cond.cond_prompt_speech_emb
        if cond_prompt_speech_emb is None:
            cond_prompt_speech_emb = empty  # (B, 0, dim)
        elif self.hp.use_perceiver_resampler:
            cond_prompt_speech_emb = self.perceiver(cond_prompt_speech_emb)

        # Emotion Adv: a value must be provided if this model uses emotion conditioning
        cond_emotion_adv = empty  # (B, 0, dim)
        if self.hp.emotion_adv:
            assert cond.emotion_adv is not None
            cond_emotion_adv = self.emotion_adv_fc(cond.emotion_adv.view(-1, 1, 1))

        # Concat along the sequence dimension and return
        cond_embeds = torch.cat((
            cond_spkr,
            cond_clap,
            cond_prompt_speech_emb,
            cond_emotion_adv,
        ), dim=1)
        return cond_embeds
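
# Minimal end-to-end sketch (assumes a zero-arg T3Config whose defaults set
# encoder_type="voice_encoder"; the actual constructor arguments may differ):
#
#   hp = T3Config()
#   enc = T3CondEnc(hp)
#   cond_embeds = enc(cond)  # (B, 1 + 0 + T_prompt + 1, hp.n_channels)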