"""We want triton==2.1.0 or 2.2.0 for this |
|
""" |
|
|
|
import math |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
import triton |
|
import triton.language as tl |
|
|
|
from einops import rearrange, repeat |
|
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}),
        triton.Config({'BLOCK_SIZE': 128}),
        triton.Config({'BLOCK_SIZE': 256}),
        triton.Config({'BLOCK_SIZE': 512}),
        triton.Config({'BLOCK_SIZE': 1024}),
        triton.Config({'BLOCK_SIZE': 2048}),
    ],
    key=['dim'],
)
@triton.jit
def _state_passing_fwd_kernel(
    # Pointers to matrices
    states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,
    # Matrix dimensions
    dim, nchunks, seqlen, chunk_size,
    # Strides
    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,
    stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
    stride_final_states_batch, stride_final_states_head, stride_final_states_dim,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
    stride_initstates_batch, stride_initstates_head, stride_initstates_dim,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    # Meta-parameters
    HAS_INITSTATES: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    pid_m = tl.program_id(axis=0)
    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
    final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
    if HAS_INITSTATES:
        initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch

    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    states_ptrs = states_ptr + offs_m * stride_states_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim
    final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim

    # The running state starts at zero, or at the provided initial states.
    if not HAS_INITSTATES:
        states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
    else:
        initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim
        states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    # out[:, c] holds the state *entering* chunk c, so chunk 0 gets the initial state.
    tl.store(out_ptrs, states, mask=offs_m < dim)
    out_ptrs += stride_out_chunk
    seq_idx = 0
    for c in range(nchunks):
        new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        scale = tl.exp(dA_cs)
        if HAS_SEQ_IDX:
            # Zero the carried state across sequence boundaries (packed variable-length batches).
            seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
            scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
            seq_idx = seq_idx_new
        states = scale * states + new_states
        if c < nchunks - 1:
            tl.store(out_ptrs, states, mask=offs_m < dim)
        else:
            tl.store(final_states_ptrs, states, mask=offs_m < dim)
        states_ptrs += stride_states_chunk
        dA_cs_ptr += stride_dA_cs_chunk
        out_ptrs += stride_out_chunk
|
|
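# For reference, a minimal pure-PyTorch sketch of the recurrence the kernel
# above implements (seq_idx resets and autotuning omitted). The helper name is
# ours; it is illustrative only and is not used by the rest of this file.
def _state_passing_fwd_loop_ref(states, dA_chunk_cumsum, initial_states=None):
    batch, nchunks, nheads, dim = states.shape
    cur = (initial_states.float() if initial_states is not None
           else torch.zeros(batch, nheads, dim, device=states.device))
    out = torch.empty(batch, nchunks, nheads, dim, dtype=torch.float32, device=states.device)
    for c in range(nchunks):
        out[:, c] = cur  # out[:, c] is the state *entering* chunk c
        scale = torch.exp(dA_chunk_cumsum[:, :, c])  # (batch, nheads)
        cur = scale[..., None] * cur + states[:, c].float()
    return out, cur  # cur is the final state, (batch, nheads, dim)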
|
|
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}),
        triton.Config({'BLOCK_SIZE': 128}),
        triton.Config({'BLOCK_SIZE': 256}),
        triton.Config({'BLOCK_SIZE': 512}),
        triton.Config({'BLOCK_SIZE': 1024}),
        triton.Config({'BLOCK_SIZE': 2048}),
    ],
    key=['dim'],
)
@triton.jit
def _state_passing_bwd_kernel(
    # Pointers to matrices
    dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,
    dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,
    # Matrix dimensions
    dim, nchunks, seqlen, chunk_size,
    # Strides
    stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,
    stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
    stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,
    stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,
    # Meta-parameters
    CONVERT_STATES: tl.constexpr,
    HAS_DFINAL_STATES: tl.constexpr,
    HAS_DINITSTATES: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    pid_m = tl.program_id(axis=0)
    # Walk the chunks in reverse: start all pointers at the last chunk.
    dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk
    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk
    ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
    dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk
    if CONVERT_STATES:
        states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
    if HAS_DFINAL_STATES:
        dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head
    if HAS_DINITSTATES:
        dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch

    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim
    dout_ptrs = dout_ptr + offs_m * stride_dout_dim
    if CONVERT_STATES:
        states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim

    if HAS_DFINAL_STATES:
        dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)
    else:
        dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
    tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
    if HAS_SEQ_IDX:
        seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)
    dstates_ptrs -= stride_dstates_chunk
    for c in range(nchunks - 1):
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        scale = tl.exp(dA_cs)
        if HAS_SEQ_IDX:
            seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))
            scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
            seq_idx = seq_idx_new
        out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if CONVERT_STATES:
            tl.store(states_converted_ptrs, out, mask=offs_m < dim)
        # Partial (per-block along dim) gradient wrt this chunk's cumulative dA.
        ddA = tl.sum(out * dstates) * scale
        tl.store(ddA_cs_ptr, ddA)
        dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        dstates = scale * dstates + dout
        tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
        dout_ptrs -= stride_dout_chunk
        dstates_ptrs -= stride_dstates_chunk
        dA_cs_ptr -= stride_dA_cs_chunk
        ddA_cs_ptr -= stride_ddA_cs_chunk
        out_ptrs -= stride_out_chunk
        if CONVERT_STATES:
            states_converted_ptrs -= stride_out_chunk
    if CONVERT_STATES:
        out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        tl.store(states_converted_ptrs, out, mask=offs_m < dim)
    if not HAS_DINITSTATES:
        tl.store(ddA_cs_ptr, 0.0)
    else:
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        scale = tl.exp(dA_cs)
        if HAS_SEQ_IDX:
            scale = tl.where(seq_idx == 0, scale, 0.0)
        out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        ddA = tl.sum(out * dstates) * scale
        tl.store(ddA_cs_ptr, ddA)
        dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        dstates = scale * dstates + dout
        tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)
|
|
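# Likewise, an illustrative pure-PyTorch sketch of the reverse-time recurrence
# the backward kernel implements (seq_idx resets omitted; helper name is ours
# and it is not used elsewhere). Given out[:, c] = s_c and the forward step
# s_{c+1} = exp(dA_cs[c]) * s_c + x_c, the gradients are
#   dx_c = ds_{c+1},  ddA_cs[c] = <s_c, ds_{c+1}> * exp(dA_cs[c]),
#   ds_c = dout[:, c] + exp(dA_cs[c]) * ds_{c+1}.
def _state_passing_bwd_loop_ref(out, dA_chunk_cumsum, dout, dfinal_states=None):
    batch, nchunks, nheads, dim = out.shape
    ds = (dfinal_states.float() if dfinal_states is not None
          else torch.zeros(batch, nheads, dim, device=out.device))
    dstates = torch.empty(batch, nchunks, nheads, dim, dtype=torch.float32, device=out.device)
    ddA = torch.empty(batch, nheads, nchunks, dtype=torch.float32, device=out.device)
    for c in range(nchunks - 1, -1, -1):
        dstates[:, c] = ds  # gradient wrt the chunk-local states x_c
        scale = torch.exp(dA_chunk_cumsum[:, :, c])
        ddA[:, :, c] = (out[:, c].float() * ds).sum(dim=-1) * scale
        ds = scale[..., None] * ds + dout[:, c].float()
    return dstates, ddA, ds  # ds is the gradient wrt the initial states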
|
|
def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,
                       out_dtype=None):
    batch, nchunks, nheads, dim = states.shape
    assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, dim)
    if seq_idx is not None:
        assert chunk_size is not None
        seqlen = seq_idx.shape[-1]
        assert seq_idx.shape == (batch, seqlen)
    out_dtype = states.dtype if out_dtype is None else out_dtype
    out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)
    final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)
    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
    with torch.cuda.device(states.device.index):
        _state_passing_fwd_kernel[grid](
            states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,
            dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
            states.stride(0), states.stride(1), states.stride(2), states.stride(3),
            out.stride(0), out.stride(1), out.stride(2), out.stride(3),
            final_states.stride(0), final_states.stride(1), final_states.stride(2),
            # dA_chunk_cumsum is (batch, nheads, nchunks): the chunk stride is stride(2),
            # the head stride is stride(1), matching the kernel's parameter order.
            dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
            *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
              if initial_states is not None else (0, 0, 0)),
            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
            HAS_INITSTATES=initial_states is not None,
            HAS_SEQ_IDX=seq_idx is not None,
        )
    return out, final_states
|
|
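# Illustrative call of _state_passing_fwd with two variable-length sequences
# packed along the length axis (all shapes here are hypothetical; seq_idx gives
# each position its sequence index so the carried state is reset at the boundary):
#
#     seqlen, chunk_size = 512, 64                      # so nchunks = 8
#     seq_idx = torch.cat([torch.zeros(300, dtype=torch.int32),
#                          torch.ones(212, dtype=torch.int32)]).cuda()[None]  # (1, 512)
#     out, final_states = _state_passing_fwd(states, dA_chunk_cumsum,
#                                            seq_idx=seq_idx, chunk_size=chunk_size)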
|
|
def _state_passing_bwd(
    states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,
    dstates_dtype=None, states_dtype=None, chunk_size=None
):
    """
    states contains the initial_states at index 0. The final states are not included in states.
    """
    batch, nchunks, nheads, dim = states.shape
    assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
    assert dout.shape == (batch, nchunks, nheads, dim)
    if seq_idx is not None:
        assert chunk_size is not None
        seqlen = seq_idx.shape[-1]
        assert seq_idx.shape == (batch, seqlen)
    dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
    if states_dtype is not None and states_dtype != states.dtype:
        states_converted = torch.empty_like(states, dtype=states_dtype)
        assert states_converted.stride() == states.stride()
    else:
        states_converted = None
    if has_initial_states:
        dinitstates = torch.empty_like(dstates[:, 0])
    else:
        dinitstates = None
    if dfinal_states is not None:
        assert dfinal_states.shape == (batch, nheads, dim)
    # Each grid block along dim writes one partial ddA sum; allocate enough
    # slots for the smallest autotuned BLOCK_SIZE and reduce afterwards.
    BLOCK_SIZE_min = 64
    n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min
    ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,
                                   dtype=torch.float32, device=dA_chunk_cumsum.device)
    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
    with torch.cuda.device(dout.device.index):
        _state_passing_bwd_kernel[grid](
            dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,
            dstates, ddA_chunk_cumsum, dinitstates, states_converted,
            dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
            dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
            states.stride(0), states.stride(1), states.stride(2), states.stride(3),
            dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
            *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))
              if dfinal_states is not None else (0, 0, 0)),
            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
            dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),
            ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),
            *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))
              if dinitstates is not None else (0, 0, 0)),
            CONVERT_STATES=states_converted is not None,
            HAS_DFINAL_STATES=dfinal_states is not None,
            HAS_DINITSTATES=dinitstates is not None,
            HAS_SEQ_IDX=seq_idx is not None,
        )
    # Only the blocks actually launched (per the autotuned BLOCK_SIZE) hold
    # valid partial sums; reduce those into the final ddA_chunk_cumsum.
    BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"]
    n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
    ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)
    if states_dtype is not None and states_dtype == states.dtype:
        states_converted = states
    return ((dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None
            else (dstates, ddA_chunk_cumsum, dinitstates, states_converted))
|
|
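# Sanity-check sketch (hypothetical shapes, CUDA required): the kernel's
# gradients should match the loop reference above up to float tolerance.
#
#     dstates, ddA_cs, dinit = _state_passing_bwd(
#         out, dA_chunk_cumsum, dout, dfinal_states=dfinal, has_initial_states=True)
#     dstates_ref, ddA_ref, dinit_ref = _state_passing_bwd_loop_ref(
#         out, dA_chunk_cumsum, dout, dfinal_states=dfinal)
#     torch.testing.assert_close(ddA_cs.float(), ddA_ref, rtol=1e-3, atol=1e-3)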
|
|
class StatePassingFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, states, dA_chunk_cumsum, initial_states=None):
        batch, nchunks, nheads, dim = states.shape
        assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
        if states.stride(-1) != 1:
            states = states.contiguous()
        out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states)
        ctx.save_for_backward(out, dA_chunk_cumsum)
        ctx.has_initial_states = initial_states is not None
        return out, final_states

    @staticmethod
    def backward(ctx, dout, dfinal_states):
        out, dA_chunk_cumsum = ctx.saved_tensors
        batch, nchunks, nheads, dim = out.shape
        assert dout.shape == (batch, nchunks, nheads, dim)
        assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
        assert dfinal_states.shape == (batch, nheads, dim)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd(
            out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states, has_initial_states=ctx.has_initial_states
        )
        return dstates, ddA_chunk_cumsum, dinitstates
|
|
def state_passing(states, dA_chunk_cumsum, initial_states=None):
    """
    Argument:
        states: (batch, nchunks, nheads, dim)
        dA_chunk_cumsum: (batch, nheads, nchunks)
        initial_states: (batch, nheads, dim)
    Return:
        out: (batch, nchunks, nheads, dim)
        final_states: (batch, nheads, dim)
    """
    return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states)
|
|
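# Example usage (illustrative shapes; the Triton kernels require a CUDA device).
# Both outputs participate in autograd, so gradients flow through the kernel path:
#
#     batch, nchunks, nheads, dim = 2, 8, 4, 64
#     states = torch.randn(batch, nchunks, nheads, dim, device="cuda", requires_grad=True)
#     dA_cs = torch.randn(batch, nheads, nchunks, device="cuda", requires_grad=True)
#     out, final_states = state_passing(states, dA_cs)
#     (out.sum() + final_states.sum()).backward()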
|
|
def state_passing_ref(states, dA_chunk_cumsum, initial_states=None):
    """
    Argument:
        states: (batch, nchunks, nheads, dim)
        dA_chunk_cumsum: (batch, nheads, nchunks)
        initial_states: (batch, nheads, dim)
    Return:
        out: (batch, nchunks, nheads, dim)
        final_states: (batch, nheads, dim)
    """
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, 0])
    states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], dim=1)
    dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0))
    dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1)
    nchunks = dA_chunk_cumsum.shape[-1]
    # (batch, nheads, nchunks, nchunks)
    dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :]
    # (batch, nheads, nchunks, nchunks)
    decay_chunk = torch.exp(dt_chunk_segment_sum)
    causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0)
    decay_chunk = decay_chunk.masked_fill(~causal_mask, 0)
    out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states)
    return out[:, :-1], out[:, -1]
|
|
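if __name__ == "__main__":
    # Quick numerical check of the Triton path against the einsum reference.
    # Sizes are illustrative; a CUDA device is required since the kernels are Triton-based.
    torch.manual_seed(0)
    batch, nchunks, nheads, dim = 2, 8, 4, 64
    states = torch.randn(batch, nchunks, nheads, dim, device="cuda")
    dA_cs = torch.randn(batch, nheads, nchunks, device="cuda")
    init = torch.randn(batch, nheads, dim, device="cuda")
    out, final = state_passing(states, dA_cs, init)
    out_ref, final_ref = state_passing_ref(states, dA_cs, init)
    torch.testing.assert_close(out, out_ref, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(final.to(final_ref.dtype), final_ref, rtol=1e-3, atol=1e-3)
    print("state_passing matches state_passing_ref")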