File size: 1,259 Bytes
dd9600d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch


def create_tts_mask(seq_len, max_seq_len, mask_range):
    """Build a boolean mask covering one random contiguous span per sequence.

    For each sequence ``i`` a fraction ``f_i`` is drawn uniformly from
    ``[low, high)`` (``mask_range``), and a span of ``int(f_i * seq_len[i])``
    positions is placed at a random offset so that it ends at or before
    ``seq_len[i]``.

    Args:
        seq_len: 1-D tensor holding the valid length of each sequence.
        max_seq_len: Number of columns in the returned mask. ``None`` means
            "use ``seq_len.max()``" (0 for an empty batch).
        mask_range: ``(low, high)`` pair of fractions with ``low >= 0`` and
            ``high <= 1``.

    Returns:
        Boolean tensor of shape ``(len(seq_len), max_seq_len)`` on
        ``seq_len.device``; ``True`` marks masked positions.

    Raises:
        ValueError: if ``mask_range`` extends below 0 or above 1. A fraction
            above 1 would make the span longer than the sequence and spill
            into padding.
    """
    low, high = mask_range
    if low < 0.0 or high > 1.0:
        raise ValueError(f"mask_range must lie within [0, 1], got {mask_range!r}")

    bs = seq_len.size(0)
    device = seq_len.device

    if max_seq_len is None:
        max_seq_len = int(seq_len.max()) if bs > 0 else 0

    # 1. Sample a random span fraction for each sequence.
    frac_lengths = torch.empty(bs, device=device).uniform_(low, high)

    # 2. Convert fractions to integer span lengths (truncated by .long()).
    lengths = (frac_lengths * seq_len).long()

    # 3. Largest start offset that keeps the span inside the valid length.
    max_start = seq_len - lengths

    # 4. Sample start offsets; clamp is a defensive guard against negatives.
    rand = torch.rand(bs, device=device)
    start = (max_start * rand).long().clamp(min=0)
    end = start + lengths

    # 5. A column is masked when it falls in [start, end) for its row.
    seq = torch.arange(max_seq_len, device=device).long()
    mask = (seq[None, :] >= start[:, None]) & (seq[None, :] < end[:, None])

    return mask


if __name__ == "__main__":
    # Example: 3 sequences of lengths [5, 7, 6]
    lengths = torch.tensor([5, 7, 6])
    # Mask width is the longest sequence (7); the function requires it.
    max_seq_len = int(lengths.max())
    mask_range = (0.7, 1.0)  # Sample fractional lengths between 70% and 100% of each seq

    mask = create_tts_mask(lengths, max_seq_len, mask_range)
    print("Mask shape:", mask.shape)  # Should be [3, 7], since max_seq_len is 7
    print(mask)