|
|
|
|
|
|
|
from typing import Optional, Union |
|
|
|
import torch |
|
|
|
import triton |
|
import triton.language as tl |
|
|
|
|
|
@triton.jit
def rotary_kernel(
    # Pointers. SEQLEN_OFFSETS is either a plain integer or a pointer to one
    # offset per batch entry, as selected by IS_SEQLEN_OFFSETS_TENSOR.
    OUT,
    X,
    COS,
    SIN,
    CU_SEQLENS,  # only read when IS_VARLEN
    SEQLEN_OFFSETS,
    # Sizes. When IS_VARLEN, `seqlen` is overwritten per program with the
    # length of this batch entry read from CU_SEQLENS.
    seqlen,
    nheads,
    seqlen_ro,  # number of rows available in COS / SIN
    # Strides, in elements. The *_batch strides are not used when IS_VARLEN.
    stride_out_batch,
    stride_out_seqlen,
    stride_out_nheads,
    stride_out_headdim,
    stride_x_batch,
    stride_x_seqlen,
    stride_x_nheads,
    stride_x_headdim,
    # Compile-time parameters.
    ROTARY_DIM: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    """Rotate the first ROTARY_DIM channels of each head of X and store them in OUT.

    Launch grid (see the program_id reads below):
        axis 0 -> a block of BLOCK_H heads,
        axis 1 -> a block of BLOCK_M sequence positions,
        axis 2 -> one batch entry.

    For each position m and pair index k < ROTARY_DIM // 2, with
    c = COS[m + offset, k] and s = SIN[m + offset, k] (s negated if CONJUGATE):
        o0 = x0 * c - x1 * s
        o1 = x0 * s + x1 * c
    where (x0, x1) are channels (k, k + ROTARY_DIM // 2) when not INTERLEAVED,
    or channels (2k, 2k + 1) when INTERLEAVED. Arithmetic is done in float32.

    COS/SIN are addressed as row * (ROTARY_DIM // 2) + k, so they are assumed
    to be contiguous with ROTARY_DIM // 2 columns. Rows at or beyond seqlen_ro
    load cos=1, sin=0, so those positions are written to OUT unchanged.

    Channels >= ROTARY_DIM are never loaded or stored. The head dimension is
    not masked against headdim, so the caller must ensure ROTARY_DIM <= headdim.
    """
    # tl.arange needs a power-of-two extent; the masks below trim the padding
    # back to ROTARY_DIM / ROTARY_DIM_HALF.
    BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM)
    ROTARY_DIM_HALF = ROTARY_DIM // 2
    pid_head = tl.program_id(axis=0)
    pid_m = tl.program_id(axis=1)
    pid_batch = tl.program_id(axis=2)

    if not IS_VARLEN:
        # Padded layout: jump to this batch entry via the batch strides.
        X = X + pid_batch * stride_x_batch
        OUT = OUT + pid_batch * stride_out_batch
    else:
        # Packed layout: this entry occupies rows [CU_SEQLENS[b], CU_SEQLENS[b + 1]).
        start_idx = tl.load(CU_SEQLENS + pid_batch)
        seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
        X = X + start_idx * stride_x_seqlen
        OUT = OUT + start_idx * stride_out_seqlen

    # This block of positions starts past the end of the sequence: nothing to do.
    if pid_m * BLOCK_M >= seqlen:
        return

    rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Positions used to index COS/SIN: sequence position shifted by the offset.
    if not IS_SEQLEN_OFFSETS_TENSOR:
        rm_cs = rm + SEQLEN_OFFSETS
    else:
        rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)

    rk_half = tl.arange(0, BLOCK_K // 2)
    COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :])
    SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :])
    mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF)
    # Masked-out entries default to an identity rotation (cos=1, sin=0).
    cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32)
    sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32)
    if CONJUGATE:
        # Rotating by -theta: only sin changes sign.
        sin = -sin

    if not INTERLEAVED:
        # Pairs are (k, k + ROTARY_DIM_HALF): load both halves separately,
        # rotate, and store each half back to its own location.
        X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk_half[None, None, :] * stride_x_headdim)
        OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk_half[None, None, :] * stride_out_headdim)
        mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk_half[None, None, :] < ROTARY_DIM_HALF)
        x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32)
        x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0,).to(tl.float32)
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos
        # tl.store casts the float32 values to OUT's element type.
        tl.store(OUT, o0, mask=mask)
        tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask)
    else:
        # Pairs are (2k, 2k + 1): load all rotary channels at once and split
        # even/odd lanes by reshaping the last axis to (BLOCK_K // 2, 2).
        rk = tl.arange(0, BLOCK_K)
        X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk[None, None, :] * stride_x_headdim)
        OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk[None, None, :] * stride_out_headdim)
        mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk[None, None, :] < ROTARY_DIM)
        x = tl.load(X, mask=mask, other=0.0).to(tl.float32)
        x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2]))
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos
        # Re-interleave the rotated pairs back into channel order.
        o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K])
        tl.store(OUT, o, mask=mask)
|
|
|
|
|
def apply_rotary( |
|
x: torch.Tensor, |
|
cos: torch.Tensor, |
|
sin: torch.Tensor, |
|
seqlen_offsets: Union[int, torch.Tensor] = 0, |
|
cu_seqlens: Optional[torch.Tensor] = None, |
|
max_seqlen: Optional[int] = None, |
|
interleaved=False, |
|
inplace=False, |
|
conjugate=False, |
|
) -> torch.Tensor: |
|
""" |
|
Arguments: |
|
x: (batch, seqlen, nheads, headdim) if cu_seqlens is None |
|
else (total_seqlen, nheads, headdim). |
|
cos: (seqlen_ro, rotary_dim / 2) |
|
sin: (seqlen_ro, rotary_dim / 2) |
|
seqlen_offsets: integer or integer tensor of size (batch,) |
|
cu_seqlens: (batch + 1,) or None |
|
max_seqlen: int |
|
Returns: |
|
y: (batch, seqlen, nheads, headdim) |
|
""" |
|
is_varlen = cu_seqlens is not None |
|
if not is_varlen: |
|
batch, seqlen, nheads, headdim = x.shape |
|
else: |
|
assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" |
|
total_seqlen, nheads, headdim = x.shape |
|
batch_p_1 = cu_seqlens.shape[0] |
|
batch = batch_p_1 - 1 |
|
seqlen = max_seqlen |
|
seqlen_ro, rotary_dim = cos.shape |
|
assert sin.shape == cos.shape |
|
rotary_dim *= 2 |
|
assert rotary_dim <= headdim, "rotary_dim must be <= headdim" |
|
assert headdim <= 256, "Only support headdim <= 256" |
|
assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" |
|
|
|
cos, sin = cos.contiguous(), sin.contiguous() |
|
if isinstance(seqlen_offsets, torch.Tensor): |
|
assert seqlen_offsets.shape == (batch,) |
|
assert seqlen_offsets.dtype in [torch.int32, torch.int64] |
|
seqlen_offsets = seqlen_offsets.contiguous() |
|
else: |
|
assert seqlen_offsets + seqlen <= seqlen_ro |
|
|
|
output = torch.empty_like(x) if not inplace else x |
|
if rotary_dim < headdim and not inplace: |
|
output[..., rotary_dim:].copy_(x[..., rotary_dim:]) |
|
|
|
grid = lambda META: (triton.cdiv(nheads, META["BLOCK_H"]), triton.cdiv(seqlen, META["BLOCK_M"]), batch) |
|
BLOCK_M = 8 if rotary_dim <= 128 else 4 |
|
|
|
|
|
|
|
with torch.cuda.device(x.device.index): |
|
torch.library.wrap_triton(rotary_kernel)[grid]( |
|
output, |
|
x, |
|
cos, |
|
sin, |
|
cu_seqlens, |
|
seqlen_offsets, |
|
seqlen, |
|
nheads, |
|
seqlen_ro, |
|
output.stride(0) if not is_varlen else 0, |
|
output.stride(-3), |
|
output.stride(-2), |
|
output.stride(-1), |
|
x.stride(0) if not is_varlen else 0, |
|
x.stride(-3), |
|
x.stride(-2), |
|
x.stride(-1), |
|
rotary_dim, |
|
isinstance(seqlen_offsets, torch.Tensor), |
|
is_varlen, |
|
interleaved, |
|
conjugate, |
|
BLOCK_M=BLOCK_M, |
|
BLOCK_H=2, |
|
) |
|
return output |
|
|