import torch
import torch.nn as nn

# from tts_encode import tts_encode

def calculate_remaining_lengths(mel_lengths):
    """For each mel frame, compute how many frames remain after it; padded positions get 0.
    Input: (B,) frame counts. Output: (B, max_L) per-frame targets."""
    B = mel_lengths.shape[0]
    max_L = mel_lengths.max().item()  # Maximum mel length in the batch

    # Frame indices [0, 1, ..., max_L-1], broadcast to shape (B, max_L)
    range_tensor = torch.arange(max_L, device=mel_lengths.device).expand(B, max_L)

    # Remaining frames at each position: (length - 1) - index, clamped to 0 beyond the utterance
    remain_lengths = (mel_lengths[:, None] - 1 - range_tensor).clamp(min=0)

    return remain_lengths
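
# Example (illustrative only; the lengths below are made up, not taken from any dataset):
#   calculate_remaining_lengths(torch.tensor([3, 5]))
#   -> tensor([[2, 1, 0, 0, 0],
#              [4, 3, 2, 1, 0]])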


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding (Vaswani et al., 2017), added to the input."""

    def __init__(self, hidden_dim, max_len=4096):
        super().__init__()
        pe = torch.zeros(max_len, hidden_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / hidden_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even channels: sin(pos / 10000^(2i/d))
        pe[:, 1::2] = torch.cos(position * div_term)  # odd channels:  cos(pos / 10000^(2i/d))
        # Register as a non-persistent buffer so it moves with the module across devices
        # without being stored in checkpoints.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # (1, max_len, hidden_dim)

    def forward(self, x):
        # x: (B, L, hidden_dim); add the encodings for the first L positions.
        return x + self.pe[:, :x.size(1)]
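
# Example (illustrative): adding positional information to a zero batch of shape (B, L, D).
#   enc = PositionalEncoding(hidden_dim=256)
#   out = enc(torch.zeros(2, 10, 256))   # out[0, t] equals the encoding of position t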


class SpeechLengthPredictor(nn.Module):
    """Encodes text with a Transformer encoder and decodes mel frames with a causal
    Transformer decoder that cross-attends to the text, emitting `output_dim` values
    per mel frame."""

    def __init__(self, 
        vocab_size=2545, n_mel=100, hidden_dim=256, 
        n_text_layer=4, n_cross_layer=4, n_head=8,
        output_dim=1,
    ):
        super().__init__()
        
        # Text Encoder: Embedding + Transformer Layers
        self.text_embedder = nn.Embedding(vocab_size+1, hidden_dim, padding_idx=vocab_size)
        self.text_pe = PositionalEncoding(hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=n_head, dim_feedforward=hidden_dim*2, batch_first=True
        )
        self.text_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_text_layer)
        
        # Mel Spectrogram Embedder
        self.mel_embedder = nn.Linear(n_mel, hidden_dim)
        self.mel_pe = PositionalEncoding(hidden_dim)

        # Transformer Decoder Layers with Cross-Attention in Every Layer
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim, nhead=n_head, dim_feedforward=hidden_dim*2, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=n_cross_layer)
        
        # Final Classification Layer
        self.predictor = nn.Linear(hidden_dim, output_dim)

    def forward(self, text_ids, mel):
        # Encode text
        text_embedded = self.text_pe(self.text_embedder(text_ids))
        text_features = self.text_encoder(text_embedded) # (B, L_text, D)
        
        # Encode Mel spectrogram
        mel_features = self.mel_pe(self.mel_embedder(mel))  # (B, L_mel, D)
        
        # Causal mask for the decoder: True marks positions that must not be attended
        # to (strictly future mel frames).
        seq_len = mel_features.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=mel.device), diagonal=1
        )
        # Equivalent additive variant:
        # causal_mask = torch.triu(
        #     torch.full((seq_len, seq_len), float('-inf'), device=mel.device), diagonal=1
        # )

        # Transformer Decoder with Cross-Attention in Each Layer
        decoder_out = self.decoder(mel_features, text_features, tgt_mask=causal_mask)
        
        # Length Prediction
        length_logits = self.predictor(decoder_out).squeeze(-1)
        return length_logits
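

if __name__ == "__main__":
    # Minimal smoke test. This block is an illustrative sketch only: the batch size,
    # sequence lengths, and random inputs below are assumptions, not values used by
    # any training script; the model's default hyperparameters are reused.
    torch.manual_seed(0)
    model = SpeechLengthPredictor()
    model.eval()

    B, L_text, L_mel, n_mel = 2, 12, 50, 100
    text_ids = torch.randint(0, 2545, (B, L_text))  # token ids below the padding index (2545)
    mel = torch.randn(B, L_mel, n_mel)              # (B, L_mel, n_mel) mel-spectrogram frames

    with torch.no_grad():
        length_logits = model(text_ids, mel)        # (B, L_mel): one prediction per mel frame
    print(length_logits.shape)                      # torch.Size([2, 50])

    # Matching per-frame targets for utterances of 50 and 37 frames.
    targets = calculate_remaining_lengths(torch.tensor([L_mel, 37]))
    print(targets.shape)                            # torch.Size([2, 50])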