# Source: DMOSpeech2 / duration_predictor.py
# Commit 597cecf ("pt 1") by mrfakename
import torch
import torch.nn as nn
# from tts_encode import tts_encode
def calculate_remaining_lengths(mel_lengths):
    """Build per-frame "frames remaining" targets for a batch of sequences.

    For a sequence of length ``L`` the target at frame ``t`` is ``L - 1 - t``
    (how many frames still follow frame ``t``). Positions at or past the end
    of a shorter sequence are clamped to 0.

    Args:
        mel_lengths: 1-D tensor of shape ``(B,)`` with each sequence's length.

    Returns:
        Tensor of shape ``(B, max(mel_lengths))`` on ``mel_lengths.device``.
        An empty batch yields a ``(0, 0)`` tensor.

    Example:
        lengths ``[3, 1]`` -> ``[[2, 1, 0], [0, 0, 0]]``.
    """
    if mel_lengths.numel() == 0:
        # ``max()`` of an empty tensor raises; an empty batch has no targets.
        return mel_lengths.new_zeros((0, 0))
    # arange needs a Python size, hence .item() (this syncs with the device).
    max_len = mel_lengths.max().item()
    frame_idx = torch.arange(max_len, device=mel_lengths.device)  # (max_len,)
    # (B, 1) broadcast against (max_len,) -> (B, max_len); no expand needed.
    return (mel_lengths[:, None] - 1 - frame_idx).clamp(min=0)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence of features.

    Even feature channels receive ``sin(pos * w_i)`` and odd channels
    ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / hidden_dim)``.

    Args:
        hidden_dim: Size of the last (feature) dimension of the inputs.
        max_len: Number of positions in the precomputed table; inputs longer
            than this fail to broadcast in ``forward``.
    """

    def __init__(self, hidden_dim, max_len=4096):
        super().__init__()
        pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / hidden_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd hidden_dim there is one fewer odd channel than even
        # channel, so only the first hidden_dim // 2 frequencies are used.
        pe[:, 1::2] = torch.cos(position * div_term[: hidden_dim // 2])
        # Non-persistent buffer: moves/casts with module.to()/.cuda()/.half()
        # but is not written to the state_dict (same as the previous plain
        # attribute), so existing checkpoints keep loading.
        # Shape: (1, max_len, hidden_dim)
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor broadcastable with ``(1, L, hidden_dim)``, typically
                ``(B, L, hidden_dim)`` with ``L <= max_len``.
        """
        # .to() is a no-op once the buffer already lives on x's device.
        return x + self.pe[:, : x.size(1)].to(x.device)
class SpeechLengthPredictor(nn.Module):
    """Per-frame predictor conditioned on text tokens and mel frames.

    Text ids are embedded (plus sinusoidal positions) and run through a
    Transformer encoder. Mel frames are linearly projected (plus positions)
    and run through a causally masked Transformer decoder that cross-attends
    to the text features in every layer. A linear head maps each decoder
    frame to ``output_dim`` values.

    Args:
        vocab_size: Number of text token ids; id ``vocab_size`` is reserved
            as the embedding's ``padding_idx``.
        n_mel: Number of mel channels per input frame.
        hidden_dim: Model width shared by the encoder and decoder.
        n_text_layer: Number of text encoder layers.
        n_cross_layer: Number of decoder layers (self + cross attention).
        n_head: Attention heads per layer (PyTorch requires it to divide
            ``hidden_dim``).
        output_dim: Size of the per-frame output head.
    """

    def __init__(
        self,
        vocab_size=2545,
        n_mel=100,
        hidden_dim=256,
        n_text_layer=4,
        n_cross_layer=4,
        n_head=8,
        output_dim=1,
    ):
        super().__init__()

        # Text encoder: embedding (+1 row for the padding id) + Transformer.
        self.text_embedder = nn.Embedding(
            vocab_size + 1, hidden_dim, padding_idx=vocab_size
        )
        self.text_pe = PositionalEncoding(hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_head,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
        )
        self.text_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=n_text_layer
        )

        # Mel frame embedder.
        self.mel_embedder = nn.Linear(n_mel, hidden_dim)
        self.mel_pe = PositionalEncoding(hidden_dim)

        # Decoder layers: causal self-attention + cross-attention to text.
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=n_head,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_cross_layer)

        # Per-frame output head.
        self.predictor = nn.Linear(hidden_dim, output_dim)

    def forward(self, text_ids, mel, text_padding_mask=None, mel_padding_mask=None):
        """Compute per-frame outputs.

        Args:
            text_ids: Integer tensor ``(B, L_text)`` of token ids.
            mel: Tensor ``(B, L_mel, n_mel)`` of mel frames.
            text_padding_mask: Optional bool tensor ``(B, L_text)``; ``True``
                marks padded text positions to be ignored by encoder
                self-attention and decoder cross-attention. ``None``
                (default) attends to every position.
            mel_padding_mask: Optional bool tensor ``(B, L_mel)``; ``True``
                marks padded mel frames to be ignored by decoder
                self-attention. ``None`` (default) attends to every frame.

        Returns:
            ``(B, L_mel)`` when ``output_dim == 1``, otherwise
            ``(B, L_mel, output_dim)``.
        """
        text_embedded = self.text_pe(self.text_embedder(text_ids))
        text_features = self.text_encoder(
            text_embedded, src_key_padding_mask=text_padding_mask
        )  # (B, L_text, D)

        mel_features = self.mel_pe(self.mel_embedder(mel))  # (B, L_mel, D)

        # Boolean causal mask: True strictly above the diagonal blocks each
        # frame from attending to later frames. Built directly on mel's device
        # instead of on the CPU followed by a copy on every call.
        seq_len = mel_features.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=mel.device), diagonal=1
        ).bool()

        decoder_out = self.decoder(
            mel_features,
            text_features,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=mel_padding_mask,
            memory_key_padding_mask=text_padding_mask,
        )

        # Drops the trailing dim only when output_dim == 1.
        return self.predictor(decoder_out).squeeze(-1)