Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
# from tts_encode import tts_encode | |
def calculate_remaining_lengths(mel_lengths):
    """Return, for every frame position, how many frames remain after it.

    For an item of length ``L`` the value at frame ``t`` is
    ``max(L - 1 - t, 0)``; positions at or past the end of the item are 0.

    Args:
        mel_lengths: 1-D tensor of shape (B,) holding per-item lengths.

    Returns:
        Tensor of shape (B, max_L) on ``mel_lengths.device``, where ``max_L``
        is the largest length in the batch (0 for an empty batch).
    """
    # ``max()`` raises on an empty tensor, so an empty batch uses max_L = 0
    # and yields a (0, 0) result instead of an error.
    max_L = mel_lengths.max().item() if mel_lengths.numel() > 0 else 0
    # Frame indices [0, 1, ..., max_L - 1]; broadcasting them against the
    # (B, 1) lengths builds the (B, max_L) grid without an explicit expand.
    frame_idx = torch.arange(max_L, device=mel_lengths.device)
    remain_lengths = (mel_lengths[:, None] - 1 - frame_idx).clamp(min=0)
    return remain_lengths
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence.

    Args:
        hidden_dim: feature size of the inputs (odd sizes are supported).
        max_len: longest sequence length the precomputed table covers.
    """

    def __init__(self, hidden_dim, max_len=4096):
        super().__init__()
        pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / hidden_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd hidden_dim there is one fewer odd column than even
        # column, so only the first hidden_dim // 2 frequencies fit here.
        pe[:, 1::2] = torch.cos(position * div_term[: hidden_dim // 2])
        # Non-persistent buffer: it moves with module.to()/.half() like the
        # parameters do, yet stays out of state_dict, so checkpoints saved
        # before this was a buffer (which never contained it) load strictly.
        # Shape: (1, max_len, hidden_dim)
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: tensor of shape (B, L, hidden_dim) with L <= max_len.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # .to() is a no-op once the module lives on x's device; it keeps
        # working for callers that never move the module explicitly.
        return x + self.pe[:, : x.size(1)].to(x.device)
class SpeechLengthPredictor(nn.Module):
    """Text-conditioned, causally masked per-frame predictor over mel frames.

    Text token ids pass through an embedding, sinusoidal positions and a
    Transformer encoder. Mel frames pass through a linear projection,
    sinusoidal positions and a causally masked Transformer decoder that
    cross-attends to the encoded text in every layer. A final linear layer
    maps each decoder frame to ``output_dim`` values.

    Args:
        vocab_size: number of text token ids; id ``vocab_size`` itself is
            reserved as the embedding's padding index.
        n_mel: number of mel bins per frame.
        hidden_dim: model width shared by all sub-modules.
        n_text_layer: number of text-encoder layers.
        n_cross_layer: number of decoder (self- + cross-attention) layers.
        n_head: attention heads per layer.
        output_dim: number of outputs per mel frame.
    """

    def __init__(
        self,
        vocab_size=2545,
        n_mel=100,
        hidden_dim=256,
        n_text_layer=4,
        n_cross_layer=4,
        n_head=8,
        output_dim=1,
    ):
        super().__init__()

        # Text encoder: embedding + positional encoding + Transformer layers.
        # One extra embedding row serves as the padding token.
        self.text_embedder = nn.Embedding(
            vocab_size + 1, hidden_dim, padding_idx=vocab_size
        )
        self.text_pe = PositionalEncoding(hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_head,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
        )
        self.text_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=n_text_layer
        )

        # Mel spectrogram embedder.
        self.mel_embedder = nn.Linear(n_mel, hidden_dim)
        self.mel_pe = PositionalEncoding(hidden_dim)

        # Transformer decoder: cross-attention to the text in every layer.
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=n_head,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_cross_layer)

        # Final per-frame projection.
        self.predictor = nn.Linear(hidden_dim, output_dim)

    def forward(self, text_ids, mel, text_padding_mask=None):
        """Compute per-frame predictions for ``mel`` conditioned on ``text_ids``.

        Both sequence lengths must fit the positional-encoding tables
        (4096 positions with the defaults used here).

        Args:
            text_ids: integer token ids, shape (B, L_text).
            mel: mel frames, shape (B, L_mel, n_mel).
            text_padding_mask: optional bool tensor of shape (B, L_text),
                True at padded text positions. When given, those positions are
                ignored by the text encoder's self-attention and by the
                decoder's cross-attention. The default None applies no text
                masking, which is the original behaviour.

        Returns:
            ``length_logits`` of shape (B, L_mel) when ``output_dim == 1``,
            otherwise (B, L_mel, output_dim).
        """
        # Encode text.
        text_embedded = self.text_pe(self.text_embedder(text_ids))
        text_features = self.text_encoder(
            text_embedded, src_key_padding_mask=text_padding_mask
        )  # (B, L_text, D)

        # Encode the mel spectrogram.
        mel_features = self.mel_pe(self.mel_embedder(mel))  # (B, L_mel, D)

        # Causal mask: True above the diagonal blocks frame i from attending
        # to later frames. It is built directly on the target device to avoid
        # a host allocation and copy on every forward pass.
        seq_len = mel_features.size(1)
        causal_mask = torch.ones(
            seq_len, seq_len, dtype=torch.bool, device=mel.device
        ).triu(diagonal=1)

        # Decoder with cross-attention to the text in each layer.
        decoder_out = self.decoder(
            mel_features,
            text_features,
            tgt_mask=causal_mask,
            memory_key_padding_mask=text_padding_mask,
        )

        # Per-frame prediction; squeeze(-1) drops the axis when output_dim == 1.
        length_logits = self.predictor(decoder_out).squeeze(-1)
        return length_logits