from dataclasses import dataclass

import torch
from torch import nn

class HomogeneousSequential(nn.Sequential):
    """
    HomogeneousSequential is a sequential container that requires all child modules
    to be of the same type and have matching input/output shapes. In turn, it may be
    compiled with the `scan` higher order operator to save compile time.
    """

    repeated_layer: type
    """The type of the layer being looped over."""

    def __init__(self, *args: nn.Module) -> None:
        super().__init__(*args)
        types = set(type(module) for module in args)
        assert len(types) == 1, f"All modules must be of the same type. Got {types}"
        self.repeated_layer = types.pop()

    def forward(self, *input, **broadcasted_inputs):
        """
        Much like `torch.nn.Sequential`, this takes `input` and forwards it to the
        first module it contains. It then "chains" outputs to inputs sequentially for
        each subsequent module, finally returning the output of the last module.
        Unlike `torch.nn.Sequential`, you may additionally specify `broadcasted_inputs`
        via keyword arguments. The same keyword arguments are passed to every layer
        unchanged (i.e. "broadcasted").
        """
        for module in self:
            input = module(*splat(input), **broadcasted_inputs)
        return input


def splat(input):
    """Wrap a lone value in a tuple so it can be star-expanded like a sequence."""
    if not isinstance(input, list | tuple):
        input = (input,)
    return input
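

# A minimal usage sketch: stack identical layers so the whole stack can later be
# compiled with a single `scan` over `repeated_layer`. The layer type and sizes
# here are illustrative placeholders, not values taken from this module.
def _example_homogeneous_stack() -> HomogeneousSequential:
    stack = HomogeneousSequential(*(nn.Linear(64, 64) for _ in range(8)))
    assert stack.repeated_layer is nn.Linear
    # Positional outputs are chained layer to layer; keyword arguments (none
    # needed for `nn.Linear`) would be broadcast unchanged to every layer.
    _ = stack(torch.randn(2, 64))
    return stack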


@dataclass(kw_only=True)
class RopeScaling:
    """
    RoPE scaling parameters. The defaults match the values used in Llama 3.1.
    """
    factor: float = 8.0
    low_freq_factor: float = 1.0
    high_freq_factor: float = 4.0
    original_context_len: int = 8192


def default_rope_frequencies(
    head_dim: int,
    theta: float = 10000.0,
) -> torch.Tensor:
    """
    Computes the original RoPE frequencies used in e.g. Llama 2.

    Args:
        head_dim: the size of a single attention head.
        theta: a hyperparameter controlling how fast the embeddings rotate.

    Returns:
        The frequencies for the RoPE embeddings.
    """
    return 1.0 / (
        theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float() / head_dim)
    )
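

# A minimal sketch of how `RopeScaling` is typically consumed, following the
# published Llama 3.1 recipe: long-wavelength (low) frequencies are divided by
# `factor`, short wavelengths are kept, and the band in between is smoothly
# interpolated. This helper and its default `theta` are assumptions for
# illustration and are not guaranteed to match the rest of this codebase.
def _example_llama3_scaled_frequencies(
    head_dim: int,
    theta: float = 500000.0,
    scaling: RopeScaling | None = None,
) -> torch.Tensor:
    import math

    scaling = scaling if scaling is not None else RopeScaling()
    freqs = default_rope_frequencies(head_dim, theta)
    low_freq_wavelen = scaling.original_context_len / scaling.low_freq_factor
    high_freq_wavelen = scaling.original_context_len / scaling.high_freq_factor
    wavelen = 2 * math.pi / freqs

    # Wavelengths longer than `low_freq_wavelen` are scaled down by `factor`.
    scaled = torch.where(wavelen > low_freq_wavelen, freqs / scaling.factor, freqs)
    # Wavelengths between the two thresholds interpolate between scaled and unscaled.
    smooth = (scaling.original_context_len / wavelen - scaling.low_freq_factor) / (
        scaling.high_freq_factor - scaling.low_freq_factor
    )
    smoothed = (1 - smooth) / scaling.factor * freqs + smooth * freqs
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return torch.where(is_medium, smoothed, scaled)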

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
      q (`torch.Tensor`): The query tensor.
      k (`torch.Tensor`): The key tensor.
      cos (`torch.Tensor`): The cosine part of the rotary embedding.
      sin (`torch.Tensor`): The sine part of the rotary embedding.
      position_ids (`torch.Tensor`, *optional*):
        Deprecated and unused.
      unsqueeze_dim (`int`, *optional*, defaults to 1):
        The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
        sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
        that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
        k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
        cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
        the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
      `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
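

# A minimal usage sketch for `apply_rotary_pos_emb`: build cos/sin tables from
# `default_rope_frequencies` and rotate query/key tensors laid out as
# [batch_size, heads, seq_len, head_dim]. All sizes are illustrative placeholders.
def _example_apply_rope(batch: int = 2, heads: int = 4, seq_len: int = 16, head_dim: int = 64):
    freqs = default_rope_frequencies(head_dim)  # [head_dim // 2]
    angles = torch.outer(torch.arange(seq_len).float(), freqs)  # [seq_len, head_dim // 2]
    emb = torch.cat((angles, angles), dim=-1)  # [seq_len, head_dim]
    cos = emb.cos().unsqueeze(0).expand(batch, -1, -1)  # [batch, seq_len, head_dim]
    sin = emb.sin().unsqueeze(0).expand(batch, -1, -1)
    q = torch.randn(batch, heads, seq_len, head_dim)
    k = torch.randn(batch, heads, seq_len, head_dim)
    # unsqueeze_dim=1 (the default) matches the [batch, heads, seq, head_dim] layout.
    return apply_rotary_pos_emb(q, k, cos, sin)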



def transition(x_0, sigma, maskable_mask, mask_token_id, mask_block_size: int = 1):
    """Apply masking to input tokens. If mask_block_size > 1, use block masking for all rows."""

    if mask_block_size == 1:
        # Original behavior
        # weiran: diffullama
        move_indices = (
            torch.rand(*x_0.shape, device=x_0.device) < sigma
        ) & maskable_mask
        x_t = torch.where(move_indices, mask_token_id, x_0)
        return x_t

    # Block masking for entire batch
    return block_masking(x_0, sigma, maskable_mask, mask_token_id, mask_block_size)
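

# A minimal usage sketch for `transition`: all token ids and shapes below are
# illustrative placeholders. With the default `mask_block_size=1` each maskable
# token is masked independently with probability `sigma`; with a larger block
# size, whole windows of tokens are masked instead.
def _example_transition(mask_token_id: int = 0, pad_token_id: int = 1):
    x_0 = torch.randint(2, 100, (4, 32))  # [batch_size, seq_len] of token ids
    maskable_mask = x_0 != pad_token_id   # never mask padding
    sigma = torch.rand(4, 1)              # per-row masking rate in [0, 1)
    x_t = transition(x_0, sigma, maskable_mask, mask_token_id)
    x_t_blocks = transition(x_0, sigma, maskable_mask, mask_token_id, mask_block_size=4)
    return x_t, x_t_blocks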


def block_masking(x_0, sigma, maskable_mask, mask_token_id, mask_block_size):
    """
    XLA-compatible block masking applied uniformly to all rows in the batch.
    Uses efficient tensor operations to avoid dynamic loops.
    """
    batch_size, seq_len = x_0.shape

    if seq_len < mask_block_size:
        return x_0

    # Calculate number of possible block positions
    num_windows = seq_len - mask_block_size + 1

    # Create all possible block positions: [num_windows, mask_block_size]
    window_starts = torch.arange(num_windows, device=x_0.device)
    block_offsets = torch.arange(mask_block_size, device=x_0.device)
    all_positions = window_starts.unsqueeze(1) + block_offsets.unsqueeze(0)

    # Check which blocks are fully maskable: [batch_size, num_windows]
    maskable_blocks = (
        maskable_mask.unsqueeze(1)
        .expand(-1, num_windows, -1)
        .gather(2, all_positions.unsqueeze(0).expand(batch_size, -1, -1))
    )
    fully_maskable = maskable_blocks.all(dim=2)

    # Determine which blocks should be masked: [batch_size, num_windows]
    # NOTE: away from the edges, each position is covered by `mask_block_size`
    # windows, so the per-window rate is reduced so that the marginal masking
    # probability per position stays ~sigma.
    effective_sigma = 1 - (1 - sigma) ** (1 / mask_block_size)
    should_mask = (
        torch.rand(batch_size, num_windows, device=x_0.device) < effective_sigma
    ) & fully_maskable

    # Create final mask using simple broadcasting (fully XLA-compatible)
    # For each position in the sequence, check if it's part of any masked block
    position_indices = torch.arange(seq_len, device=x_0.device)  # [seq_len]

    # Check for each position if it falls within any masked block
    # position_indices: [seq_len] -> [1, 1, seq_len]
    # all_positions: [num_windows, mask_block_size] -> [1, num_windows, mask_block_size]
    # should_mask: [batch_size, num_windows] -> [batch_size, num_windows, 1]

    position_indices = position_indices.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len]
    all_positions = all_positions.unsqueeze(0)  # [1, num_windows, mask_block_size]
    should_mask = should_mask.unsqueeze(2)  # [batch_size, num_windows, 1]

    # Check if each position falls inside any block:
    # [1, 1, 1, seq_len] == [1, num_windows, mask_block_size, 1]
    #   -> [1, num_windows, mask_block_size, seq_len], reduced over the block dim
    position_matches = (position_indices == all_positions.unsqueeze(3)).any(
        dim=2
    )  # [1, num_windows, seq_len]

    # Apply should_mask to get final positions to mask
    # [batch_size, num_windows, 1] & [1, num_windows, seq_len] -> [batch_size, num_windows, seq_len]
    should_mask_positions = should_mask & position_matches

    # Reduce over windows: if any window masks this position, mask it
    final_mask = should_mask_positions.any(dim=1)  # [batch_size, seq_len]

    # Apply the mask
    result = torch.where(final_mask, mask_token_id, x_0)

    return result


def prefix_input_ids(input_ids, maskable_mask, apply_prefix):
    """Apply prefix to input_ids based on configured probability. Return a masksable mask such that the prefix is not masked."""
    batch_size, seq_len = input_ids.shape
    # Generate random prefix lengths for all batch items
    prefix_lengths = torch.randint(1, seq_len, (batch_size,), device=input_ids.device)
    # Create position indices: [1, seq_len]
    position_indices = torch.arange(seq_len, device=input_ids.device).unsqueeze(
        0
    )  # [1, seq_len]
    # Create prefix mask: True where position < prefix_length
    prefix_mask = position_indices < prefix_lengths.unsqueeze(
        1
    )  # [batch_size, seq_len]
    # Apply prefix masking: set to False where we should apply prefix masking
    maskable_mask = maskable_mask & ~(apply_prefix.unsqueeze(1) & prefix_mask)
    return maskable_mask


def truncate_input_ids(input_ids, apply_truncate, pad_token_id):
    """Truncate input_ids at random position and fill with pad token. Return the input_ids with suffix truncated and filled with pad token."""
    batch_size, seq_len = input_ids.shape
    # Generate random truncation positions for all batch items
    truncate_positions = torch.randint(
        1, seq_len, (batch_size,), device=input_ids.device
    )
    # Create position indices: [1, seq_len]
    position_indices = torch.arange(seq_len, device=input_ids.device).unsqueeze(
        0
    )  # [1, seq_len]
    # Create truncate mask: True where position >= truncate_position
    truncate_mask = position_indices >= truncate_positions.unsqueeze(
        1
    )  # [batch_size, seq_len]
    # Apply truncation: fill with pad token where we should truncate
    input_ids = torch.where(
        apply_truncate.unsqueeze(1) & truncate_mask, pad_token_id, input_ids
    )
    return input_ids
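

# A minimal end-to-end sketch combining the helpers above into a typical
# data-preparation step for masked (diffusion-style) training. All token ids,
# probabilities, and shapes are illustrative assumptions, not values from this module.
def _example_prepare_batch(
    input_ids: torch.Tensor,
    mask_token_id: int,
    pad_token_id: int,
    prefix_prob: float = 0.5,
    truncate_prob: float = 0.1,
):
    batch_size = input_ids.shape[0]
    device = input_ids.device

    # Optionally replace a random-length suffix of some rows with pad tokens.
    apply_truncate = torch.rand(batch_size, device=device) < truncate_prob
    input_ids = truncate_input_ids(input_ids, apply_truncate, pad_token_id)

    # Padding is never maskable; optionally also protect a random-length prefix.
    maskable_mask = input_ids != pad_token_id
    apply_prefix = torch.rand(batch_size, device=device) < prefix_prob
    maskable_mask = prefix_input_ids(input_ids, maskable_mask, apply_prefix)

    # Sample a per-row noise level and mask the remaining maskable positions.
    sigma = torch.rand(batch_size, 1, device=device)
    x_t = transition(input_ids, sigma, maskable_mask, mask_token_id)
    return x_t, maskable_mask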