Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,528 Bytes
39d2f14 597cecf 39d2f14 597cecf 39d2f14 597cecf 39d2f14 597cecf 39d2f14 597cecf 39d2f14 597cecf 39d2f14 597cecf 39d2f14 597cecf 39d2f14 597cecf 39d2f14 597cecf 39d2f14 597cecf 39d2f14 597cecf 39d2f14 597cecf 39d2f14 597cecf 39d2f14 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import torch
import torch.nn as nn
# from tts_encode import tts_encode
def calculate_remaining_lengths(mel_lengths):
    """Build per-frame countdown targets from per-utterance mel lengths.

    For an utterance of length L, frame t is assigned the value
    max(L - 1 - t, 0): the number of frames still to come after t.
    Positions past an utterance's own length are filled with 0.

    Args:
        mel_lengths: int tensor of shape (B,) with each sample's mel length.

    Returns:
        Tensor of shape (B, max(mel_lengths)) of non-negative countdowns.
    """
    longest = int(mel_lengths.max())
    # One shared step index [0, 1, ..., longest-1] on the same device.
    steps = torch.arange(longest, device=mel_lengths.device)
    # Broadcast (B, 1) against (longest,) -> (B, longest) countdown grid.
    countdown = mel_lengths.unsqueeze(1) - 1 - steps.unsqueeze(0)
    # In-place clamp is safe: `countdown` is a fresh tensor we own.
    return countdown.clamp_(min=0)
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., "Attention
    Is All You Need").

    Precomputes a (1, max_len, hidden_dim) table of interleaved
    sin/cos values and adds the leading slice to the input on forward.

    Args:
        hidden_dim: embedding dimension; assumed even (the cos half is
            assigned to ``pe[:, 1::2]``, which requires equal split).
        max_len: maximum sequence length supported by the table.
    """

    def __init__(self, hidden_dim, max_len=4096):
        super().__init__()
        pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Geometric frequency ladder: 10000^(-2i/hidden_dim) for even dims i.
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / hidden_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # FIX: register as a buffer so the table follows the module across
        # .to()/.cuda() instead of staying on CPU as a plain attribute.
        # persistent=False keeps it out of state_dict, so checkpoints remain
        # byte-compatible with the previous plain-attribute version.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Add positional encodings to ``x`` of shape (B, T, hidden_dim).

        Raises an indexing/shape error if T > max_len.
        """
        # .to(x.device) is a no-op once the buffer lives on x's device;
        # kept as a safety net for exotic multi-device setups.
        return x + self.pe[:, : x.size(1)].to(x.device)
class SpeechLengthPredictor(nn.Module):
    """Predicts, for each mel frame, a remaining-length score conditioned
    on the input text.

    Architecture: a Transformer encoder over embedded text tokens, and a
    causally-masked Transformer decoder over embedded mel frames that
    cross-attends to the text features, followed by a linear head.

    Args:
        vocab_size: number of real text tokens; index ``vocab_size`` itself
            is reserved as the embedding padding index.
        n_mel: number of mel-spectrogram channels per frame.
        hidden_dim: model width shared by encoder and decoder.
        n_text_layer: number of text-encoder layers.
        n_cross_layer: number of decoder (cross-attention) layers.
        n_head: attention heads in every layer.
        output_dim: size of the per-frame prediction (1 -> squeezed away).
    """

    def __init__(
        self,
        vocab_size=2545,
        n_mel=100,
        hidden_dim=256,
        n_text_layer=4,
        n_cross_layer=4,
        n_head=8,
        output_dim=1,
    ):
        super().__init__()
        # Text encoder: embedding (+1 slot for the padding id) + transformer.
        self.text_embedder = nn.Embedding(
            vocab_size + 1, hidden_dim, padding_idx=vocab_size
        )
        self.text_pe = PositionalEncoding(hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_head,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
        )
        self.text_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=n_text_layer
        )
        # Mel-frame embedder: linear projection into the model width.
        self.mel_embedder = nn.Linear(n_mel, hidden_dim)
        self.mel_pe = PositionalEncoding(hidden_dim)
        # Decoder: every layer self-attends (causally) over mel frames and
        # cross-attends over the encoded text.
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=n_head,
            dim_feedforward=hidden_dim * 2,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_cross_layer)
        # Final per-frame prediction head.
        self.predictor = nn.Linear(hidden_dim, output_dim)

    def forward(self, text_ids, mel):
        """Score each mel frame given the text.

        Args:
            text_ids: long tensor (B, L_text) of token ids; padding positions
                should hold ``vocab_size`` (the embedding padding_idx).
            mel: float tensor (B, L_mel, n_mel) of mel frames.

        Returns:
            (B, L_mel) tensor when ``output_dim == 1`` (trailing dim squeezed),
            otherwise (B, L_mel, output_dim).
        """
        # Encode text.
        text_embedded = self.text_pe(self.text_embedder(text_ids))
        text_features = self.text_encoder(text_embedded)  # (B, L_text, D)
        # Encode mel spectrogram.
        mel_features = self.mel_pe(self.mel_embedder(mel))  # (B, L_mel, D)
        # Causal mask for decoder self-attention: True above the diagonal
        # blocks attention to future frames. FIX: allocate directly on the
        # target device as bool instead of building on CPU and moving.
        seq_len = mel_features.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=mel.device),
            diagonal=1,
        )
        # Decoder with cross-attention to text in each layer.
        decoder_out = self.decoder(mel_features, text_features, tgt_mask=causal_mask)
        # Per-frame length prediction.
        length_logits = self.predictor(decoder_out).squeeze(-1)
        return length_logits
|