OpenSound's picture
Upload 518 files
dd9600d verified
import torch
def create_tts_mask(
    seq_len: torch.Tensor,
    max_seq_len: int,
    mask_range: tuple[float, float],
) -> torch.Tensor:
    """Sample one contiguous random span per sequence and return it as a mask.

    For every batch item a fraction is drawn uniformly from ``mask_range``,
    turned into an integer span length relative to that item's sequence
    length, and placed at a random start offset inside the sequence.

    Args:
        seq_len: 1-D integer tensor with the length of each sequence in the
            batch. Random tensors are created on its device.
        max_seq_len: Width of the returned mask (number of columns).
        mask_range: ``(low, high)`` bounds for the sampled fraction of each
            sequence to cover. Fractions outside ``[0, 1]`` are clamped so a
            span can never be longer than its sequence.

    Returns:
        Boolean tensor of shape ``(batch, max_seq_len)`` that is ``True``
        inside each sampled span ``[start, end)`` and ``False`` elsewhere.
    """
    bs = seq_len.size(0)
    device = seq_len.device

    # 1. Sample a random fraction per sequence. Clamping to [0, 1] keeps the
    #    span inside the sequence even when mask_range exceeds 1.0; without
    #    it the span would spill into padding columns.
    frac_lengths = torch.zeros(bs, device=device).uniform_(*mask_range)
    frac_lengths = frac_lengths.clamp(min=0.0, max=1.0)

    # 2. Convert fractions to integer span lengths (truncating).
    lengths = (frac_lengths * seq_len).long()

    # 3. Largest start offset that still keeps the span inside the sequence.
    max_start = seq_len - lengths

    # 4. Sample the start offset. torch.rand draws from [0, 1), and the
    #    product is truncated, so start < max_start whenever max_start > 0.
    rand = torch.rand(bs, device=device)
    start = (max_start * rand).long().clamp(min=0)
    end = start + lengths

    # 5. Broadcast positions (1, T) against per-row bounds (B, 1).
    seq = torch.arange(max_seq_len, device=device, dtype=torch.long)
    start_mask = seq[None, :] >= start[:, None]
    end_mask = seq[None, :] < end[:, None]
    return start_mask & end_mask
if __name__ == "__main__":
    # Example: 3 sequences of lengths [5, 7, 6]
    lengths = torch.tensor([5, 7, 6])
    # Sample fractional span lengths between 70% and 100% of each sequence
    mask_range = (0.7, 1.0)
    # The mask is as wide as the longest sequence in the batch;
    # create_tts_mask requires this width as its second argument.
    max_seq_len = int(lengths.max())
    mask = create_tts_mask(lengths, max_seq_len, mask_range)
    print("Mask shape:", mask.shape)  # Should be [3, 7], since max_seq_len is 7
    print(mask)