Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
def create_tts_mask(seq_len, max_seq_len, mask_range):
    """Build a boolean mask covering one random contiguous span per sequence.

    For each sequence ``i`` a fraction ``frac`` is drawn uniformly from
    ``mask_range`` and converted to an integer span length
    ``floor(frac * seq_len[i])``. The span is placed at a start index drawn
    uniformly from every position that keeps it inside the first
    ``seq_len[i]`` entries.

    Args:
        seq_len: 1-D integer tensor of per-sequence lengths, shape ``(bs,)``.
        max_seq_len: Width of the returned mask (e.g. the padded length).
            If ``None``, ``seq_len.max()`` is used (0 for an empty batch).
        mask_range: ``(low, high)`` bounds for the masked fraction of each
            sequence, forwarded to ``Tensor.uniform_``.

    Returns:
        Boolean tensor of shape ``(bs, max_seq_len)``; ``True`` inside each
        sampled span and ``False`` elsewhere.
    """
    bs = seq_len.size(0)
    device = seq_len.device
    if max_seq_len is None:
        max_seq_len = int(seq_len.max().item()) if bs > 0 else 0

    # 1. Sample a random masked fraction for each sequence.
    frac_lengths = torch.zeros(bs, device=device).uniform_(*mask_range)
    # 2. Convert fractions to integer span lengths (truncation toward zero).
    lengths = (frac_lengths * seq_len).long()

    # 3. Valid start indices are 0..max_start inclusive.
    max_start = seq_len - lengths
    # 4. torch.rand is in [0, 1), so scale by (max_start + 1) to make
    #    max_start itself reachable. The minimum() guards against float
    #    rounding up to max_start + 1, and clamp(min=0) covers
    #    max_start < 0 (fractions above 1.0).
    rand = torch.rand(bs, device=device)
    start = ((max_start + 1) * rand).long()
    start = torch.minimum(start, max_start).clamp(min=0)
    end = start + lengths

    # 5. Mark positions in [start, end) for each row.
    seq = torch.arange(max_seq_len, device=device).long()
    start_mask = seq[None, :] >= start[:, None]
    end_mask = seq[None, :] < end[:, None]
    return start_mask & end_mask
if __name__ == "__main__":
    # Example: 3 sequences of lengths [5, 7, 6]
    lengths = torch.tensor([5, 7, 6])
    # Mask between 70% and 100% of each sequence.
    mask_range = (0.7, 1.0)
    # Pad the mask to the longest sequence.
    max_seq_len = int(lengths.max())
    mask = create_tts_mask(lengths, max_seq_len, mask_range)
    print("Mask shape:", mask.shape)  # [3, 7], since max_seq_len is 7
    print(mask)