# Copyright (c) 2025, Tri Dao.
# As of 2025-04-23, we require triton >= 3.0

from typing import Optional, Union

import torch
import triton
import triton.language as tl


@triton.jit
def rotary_kernel(
    OUT,  # Pointers to matrices
    X,
    COS,
    SIN,
    CU_SEQLENS,
    SEQLEN_OFFSETS,  # this could be int or a pointer
    # Matrix dimensions
    seqlen,
    nheads,
    seqlen_ro,
    # strides
    stride_out_batch,
    stride_out_seqlen,
    stride_out_nheads,
    stride_out_headdim,
    stride_x_batch,
    stride_x_seqlen,
    stride_x_nheads,
    stride_x_headdim,
    # Meta-parameters
    # We want ROTARY_DIM to be constexpr, otherwise the triton compiler doesn't know that
    # the mask is constant every 8 elements, and it will generate LDG.16 instead of LDG.128
    ROTARY_DIM: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM)
    ROTARY_DIM_HALF = ROTARY_DIM // 2
    pid_head = tl.program_id(axis=0)
    pid_m = tl.program_id(axis=1)
    pid_batch = tl.program_id(axis=2)

    if not IS_VARLEN:
        X = X + pid_batch * stride_x_batch
        OUT = OUT + pid_batch * stride_out_batch
    else:
        # Variable-length: look up this sequence's start index and length in cu_seqlens
        start_idx = tl.load(CU_SEQLENS + pid_batch)
        seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
        X = X + start_idx * stride_x_seqlen
        OUT = OUT + start_idx * stride_out_seqlen

    if pid_m * BLOCK_M >= seqlen:
        return

    rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # rm_cs: absolute positions used to index into cos/sin, shifted by the
    # (scalar or per-sequence) seqlen offset
    if not IS_SEQLEN_OFFSETS_TENSOR:
        rm_cs = rm + SEQLEN_OFFSETS
    else:
        rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)

    rk_half = tl.arange(0, BLOCK_K // 2)
    COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :])
    SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :])
    mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF)
    cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32)
    sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32)
    if CONJUGATE:
        sin = -sin
    if not INTERLEAVED:
        # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
        X = X + (
            rh[:, None, None] * stride_x_nheads
            + rm[None, :, None] * stride_x_seqlen
            + rk_half[None, None, :] * stride_x_headdim
        )
        OUT = OUT + (
            rh[:, None, None] * stride_out_nheads
            + rm[None, :, None] * stride_out_seqlen
            + rk_half[None, None, :] * stride_out_headdim
        )
        mask = (
            (rh[:, None, None] < nheads)
            & (rm[None, :, None] < seqlen)
            & (rk_half[None, None, :] < ROTARY_DIM_HALF)
        )
        x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32)
        x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0).to(tl.float32)
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos
        tl.store(OUT, o0, mask=mask)
        tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask)
    else:
        # Interleaved (GPT-J style): rotary pairs are adjacent elements
        # (x[..., 0], x[..., 1]), (x[..., 2], x[..., 3]), ...
        rk = tl.arange(0, BLOCK_K)
        X = X + (
            rh[:, None, None] * stride_x_nheads
            + rm[None, :, None] * stride_x_seqlen
            + rk[None, None, :] * stride_x_headdim
        )
        OUT = OUT + (
            rh[:, None, None] * stride_out_nheads
            + rm[None, :, None] * stride_out_seqlen
            + rk[None, None, :] * stride_out_headdim
        )
        mask = (
            (rh[:, None, None] < nheads)
            & (rm[None, :, None] < seqlen)
            & (rk[None, None, :] < ROTARY_DIM)
        )
        x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
        x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2]))
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos
        o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K])
        tl.store(OUT, o, mask=mask)
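

# For clarity, a minimal pure-PyTorch reference for the rotation the kernel
# computes. This is an illustrative sketch, not part of this module's API:
# `_apply_rotary_ref` covers only the fixed-length, non-interleaved case with
# an integer position offset.
def _apply_rotary_ref(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seqlen_offset: int = 0
) -> torch.Tensor:
    # x: (batch, seqlen, nheads, headdim); cos, sin: (seqlen_ro, rotary_dim // 2)
    seqlen = x.shape[1]
    rotary_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x0, x1 = x_rot.chunk(2, dim=-1)
    # Index cos/sin by absolute position, broadcast over batch and nheads
    cos = cos[seqlen_offset : seqlen_offset + seqlen].unsqueeze(1).float()
    sin = sin[seqlen_offset : seqlen_offset + seqlen].unsqueeze(1).float()
    # Compute in float32 (as the kernel does), then cast back
    o0 = (x0.float() * cos - x1.float() * sin).to(x.dtype)
    o1 = (x0.float() * sin + x1.float() * cos).to(x.dtype)
    return torch.cat([o0, o1, x_pass], dim=-1)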


def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
    interleaved: bool = False,
    inplace: bool = False,
    conjugate: bool = False,
) -> torch.Tensor:
    """
    Arguments:
        x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim).
        cos: (seqlen_ro, rotary_dim / 2)
        sin: (seqlen_ro, rotary_dim / 2)
        seqlen_offsets: integer or integer tensor of size (batch,), added to the
            positions used to index into cos and sin (e.g. for incremental decoding).
        cu_seqlens: (batch + 1,) or None. Cumulative sequence lengths for the
            variable-length layout.
        max_seqlen: int. Required if cu_seqlens is passed.
        interleaved: if True, rotate adjacent pairs (GPT-J style) instead of the
            first and second halves (GPT-NeoX style) of the rotary dimensions.
        inplace: if True, write the result into x instead of a new tensor.
        conjugate: if True, negate sin, i.e. apply the inverse rotation.
    Returns:
        y: same shape as x. Only the first rotary_dim dimensions of headdim are
            rotated; the rest are copied through unchanged.
    """
    is_varlen = cu_seqlens is not None
    if not is_varlen:
        batch, seqlen, nheads, headdim = x.shape
    else:
        assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
        total_seqlen, nheads, headdim = x.shape
        batch_p_1 = cu_seqlens.shape[0]
        batch = batch_p_1 - 1
        seqlen = max_seqlen
    seqlen_ro, rotary_dim = cos.shape
    assert sin.shape == cos.shape
    rotary_dim *= 2
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"

    cos, sin = cos.contiguous(), sin.contiguous()
    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        seqlen_offsets = seqlen_offsets.contiguous()
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro

    output = torch.empty_like(x) if not inplace else x
    if rotary_dim < headdim and not inplace:
        # Pass the non-rotary part of headdim through unchanged
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    grid = lambda META: (triton.cdiv(nheads, META["BLOCK_H"]), triton.cdiv(seqlen, META["BLOCK_M"]), batch)  # noqa
    BLOCK_M = 8 if rotary_dim <= 128 else 4

    # Need this, otherwise Triton tries to launch from cuda:0 and we get
    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
    with torch.cuda.device(x.device.index):
        torch.library.wrap_triton(rotary_kernel)[grid](
            output,  # data ptrs
            x,
            cos,
            sin,
            cu_seqlens,
            seqlen_offsets,
            seqlen,  # shapes
            nheads,
            seqlen_ro,
            output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
            output.stride(-3),  # seqlen_stride or total_seqlen_stride
            output.stride(-2),  # nheads_stride
            output.stride(-1),  # headdim_stride
            x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
            x.stride(-3),  # seqlen stride or total_seqlen_stride
            x.stride(-2),  # nheads stride
            x.stride(-1),  # headdim stride
            rotary_dim,
            isinstance(seqlen_offsets, torch.Tensor),
            is_varlen,
            interleaved,
            conjugate,
            BLOCK_M=BLOCK_M,
            BLOCK_H=2,
        )
    return output
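

if __name__ == "__main__":
    # A small usage sketch (assumes a CUDA device and a standard RoPE frequency
    # table with base 10000; shapes follow the docstring of apply_rotary).
    batch, seqlen, nheads, headdim, rotary_dim = 2, 128, 8, 64, 32
    seqlen_ro = 256  # table covers positions beyond seqlen so seqlen_offsets works
    device = "cuda"
    x = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device=device)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2, device=device).float() / rotary_dim))
    freqs = torch.outer(torch.arange(seqlen_ro, device=device).float(), inv_freq)
    cos, sin = freqs.cos(), freqs.sin()  # (seqlen_ro, rotary_dim // 2)
    y = apply_rotary(x, cos, sin)  # rotates x[..., :rotary_dim], copies the rest
    # With a position offset, e.g. when appending tokens during decoding:
    y_off = apply_rotary(x, cos, sin, seqlen_offsets=16)
    # Sanity check against the pure-PyTorch reference above (fp16 tolerances)
    torch.testing.assert_close(y.float(), _apply_rotary_ref(x, cos, sin).float(), rtol=1e-3, atol=1e-2)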