# Copyright (c) 2024, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or 2.2.0 for this
"""

import math
from packaging import version

import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange, repeat

from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd

TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')


def init_to_zero(names):
    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
    ],
    key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],
)
@triton.jit
def _chunk_scan_fwd_kernel(
    # Pointers to matrices
    cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr,
    # Matrix dimensions
    chunk_size, hdim, dstate,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,
    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
    stride_D_head,
    # Meta-parameters
    IS_CAUSAL: tl.constexpr,
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    HAS_Z: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr,
    IS_TRITON_22: tl.constexpr,
):
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head
    prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)

    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    if HAS_SEQ_IDX:
        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Without the if (pid_c > -1), with Triton 2.1.0, I get
    # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mma layout conversion"' failed.
    # With Triton 2.2.0, this works
    if IS_TRITON_22 or pid_c > -1:
        # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
        offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
        C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate)
        prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate)
        if not HAS_SEQ_IDX:
            scale_m = tl.exp(dA_cs_m)
        else:
            scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
        if BLOCK_SIZE_DSTATE <= 128:
            C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0)
            prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
            prev_states = prev_states.to(C_ptr.dtype.element_ty)
            acc = tl.dot(C, prev_states) * scale_m[:, None]
        else:
            for k in range(0, dstate, BLOCK_SIZE_K):
                C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0)
                # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
                prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
                prev_states = prev_states.to(C_ptr.dtype.element_ty)
                acc += tl.dot(C, prev_states)
                C_ptrs += BLOCK_SIZE_K
                prev_states_ptrs += BLOCK_SIZE_K
            acc *= scale_m[:, None]

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
    x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
    for k in range(0, K_MAX, BLOCK_SIZE_K):
        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
        # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
        # So we don't need masking wrt seq_idx here.
        cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
        cb *= dt_k
        if IS_CAUSAL:
            mask = offs_m[:, None] >= k + offs_k[None, :]
            cb = tl.where(mask, cb, 0.0)
        cb = cb.to(x_ptr.dtype.element_ty)
        x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0)
        acc += tl.dot(cb, x)
        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    if HAS_D:
        if D_HAS_HDIM:
            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
        else:
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
                             mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
        acc += x_residual * D

    if HAS_Z:
        out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
        out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :])
        tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))

        z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head
        z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :])
        z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32)
        acc *= z * tl.sigmoid(z)

    out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
    out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim)
    tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))


@triton.autotune(
    configs=[
        # triton.Config({'BLOCK_SIZE_N': 256}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_N': 128}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=8),
    ],
    key=['chunk_size', 'hdim', 'dstate'],
)
@triton.jit
def _chunk_scan_fwd_kernel_wip(
    # Pointers to matrices
    cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, B_ptr, prev_states_ptr, D_ptr,
    # Matrix dimensions
    chunk_size, hdim, dstate,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,
    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
    stride_B_batch, stride_B_seqlen, stride_B_head, stride_B_dstate,
    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
    stride_D_head,
    # Meta-parameters
    HAS_D:
tl.constexpr, D_HAS_HDIM: tl.constexpr, HAS_Z: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) pid_n = tl.program_id(axis=0) cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head B_ptr += pid_b * stride_B_batch + pid_c * chunk_size * stride_B_seqlen + (pid_h // nheads_ngroups_ratio) * stride_B_head prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head offs_m = tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE) C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) B_ptrs = B_ptr + (offs_m[None, :] * stride_B_seqlen + offs_k_dstate[:, None] * stride_B_dstate) prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate) num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_m[None, :] * stride_cb_csize_k) x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) dt_ptrs = dt_ptr + offs_m * stride_dt_csize out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) # if pid_c == 0: # if pid_b == 0: # if pid_h == 0: # tl.device_print("", prev_states) chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) # scale_m = tl.exp(dA_cs_m) # C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) # acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_m[None, :] < chunk_size), other=0.0).to(tl.float32) # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) # cb *= dt_m # mask = offs_m[:, None] >= offs_m[None, :] # cb = tl.where(mask, cb, 0.0) # cb = cb.to(x_ptr.dtype.element_ty) # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0) # acc += tl.dot(cb, x) # if HAS_D: # if D_HAS_HDIM: # D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) # else: # D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) # acc += x.to(tl.float32) * D # tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) for start_m in range(0, 
chunk_size_limit, BLOCK_SIZE_M): start_m = tl.multiple_of(start_m, BLOCK_SIZE_M) dA_cs_m = tl.load(dA_cumsum_ptr + (start_m + offs_m) * stride_dA_cs_csize, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32) if HAS_SEQ_IDX: seq_idx_prev = tl.load(seq_idx_ptr + start_m - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) seq_idx_m = tl.load(seq_idx_ptr + (start_m + offs_m) * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit - start_m, other=-1) if not HAS_SEQ_IDX: scale_m = tl.exp(dA_cs_m) else: scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_k_dstate[None, :] < dstate), other=0.0) acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size - start_m) & (offs_m[None, :] < chunk_size - start_m), other=0.0).to(tl.float32) # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32) # cb *= dt_m # mask = offs_m[:, None] >= offs_m[None, :] # cb = tl.where(mask, cb, 0.0) # cb = cb.to(x_ptr.dtype.element_ty) x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim), other=0.0) # acc += tl.dot(cb, x) if HAS_D: if D_HAS_HDIM: D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) else: D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) acc += x.to(tl.float32) * D # if HAS_Z: # out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head # out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) # tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) # z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head # z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) # z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) # acc *= z * tl.sigmoid(z) tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim)) # TODO: this is not correct, and quite a bit slower if start_m + BLOCK_SIZE_M < chunk_size_limit: # B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0).to(tl.float32) B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0) dA_cs_last = tl.load(dA_cumsum_ptr + (start_m + BLOCK_SIZE_M) * stride_dA_cs_csize).to(tl.float32) # TODO: seq_idx scale = tl.exp((dA_cs_last - dA_cs_m)) * dt_m # B *= scale B = B.to(x_ptr.dtype.element_ty) tmp = tl.dot(B, x) prev_states += tmp.to(prev_states.dtype) C_ptrs += BLOCK_SIZE_M * stride_C_seqlen B_ptrs += BLOCK_SIZE_M * stride_B_seqlen cb_ptrs += BLOCK_SIZE_M * stride_cb_csize_m + BLOCK_SIZE_M * stride_cb_csize_k x_ptrs += BLOCK_SIZE_M * stride_x_seqlen dt_ptrs += BLOCK_SIZE_M * stride_dt_csize out_ptrs += BLOCK_SIZE_M * stride_out_seqlen @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 32}), triton.Config({'BLOCK_SIZE_M': 64}), triton.Config({'BLOCK_SIZE_M': 128}), triton.Config({'BLOCK_SIZE_M': 256}), ], key=["chunk_size", "hdim"], ) @triton.jit def _chunk_scan_bwd_dz_kernel( # Pointers to matrices dout_ptr, out_ptr, z_ptr, x_ptr, D_ptr, outz_ptr, dz_ptr, dout_x_ptr, 
dD_ptr, ddA_cumsum_ptr, # Matrix dimensions chunk_size, hdim, batch, seqlen, # Strides stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, stride_D_head, stride_outz_batch, stride_outz_seqlen, stride_outz_head, stride_outz_hdim, stride_dz_batch, stride_dz_seqlen, stride_dz_head, stride_dz_hdim, stride_doutx_batch, stride_doutx_seqlen, stride_doutx_head, stride_doutx_hdim, stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, # Meta-parameters HAS_D: tl.constexpr, D_HAS_HDIM: tl.constexpr, HAS_DDACS: tl.constexpr, RECOMPUTE_OUTPUT: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) pid_m = tl.program_id(axis=0) dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head dout_x_ptr += pid_b * stride_doutx_batch + pid_c * chunk_size * stride_doutx_seqlen + pid_h * stride_doutx_head out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head dz_ptr += pid_b * stride_dz_batch + pid_c * chunk_size * stride_dz_seqlen + pid_h * stride_dz_head if RECOMPUTE_OUTPUT: outz_ptr += pid_b * stride_outz_batch + pid_c * chunk_size * stride_outz_seqlen + pid_h * stride_outz_head if HAS_DDACS: ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head if HAS_D: x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = tl.arange(0, BLOCK_SIZE_N) dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) dout_x_ptrs = dout_x_ptr + (offs_m[:, None] * stride_doutx_seqlen + offs_n[None, :] * stride_doutx_hdim) out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) z_ptrs = z_ptr + (offs_m[:, None] * stride_z_seqlen + offs_n[None, :] * stride_z_hdim) dz_ptrs = dz_ptr + (offs_m[:, None] * stride_dz_seqlen + offs_n[None, :] * stride_dz_hdim) if RECOMPUTE_OUTPUT: outz_ptrs = outz_ptr + (offs_m[:, None] * stride_outz_seqlen + offs_n[None, :] * stride_outz_hdim) if HAS_D: x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) if D_HAS_HDIM: dD_ptrs = dD_ptr + offs_n * stride_dD_hdim chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) z = tl.load(z_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) z_sigmoid = tl.sigmoid(z) if RECOMPUTE_OUTPUT: outz = out * z * z_sigmoid tl.store(outz_ptrs, outz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) dz = dout * out * z_sigmoid * (1 + z * (1 - 
z_sigmoid)) tl.store(dz_ptrs, dz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) dout *= z * z_sigmoid tl.store(dout_x_ptrs, dout, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) if HAS_D: x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) if D_HAS_HDIM: dD = tl.sum(dout * x, axis=0) tl.store(dD_ptrs, dD, mask=offs_n < hdim) D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) else: dD = tl.sum(dout * x) tl.store(dD_ptr, dD) D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) out -= x * D if HAS_DDACS: ddA_cs = tl.sum(dout * out, axis=1) tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size) @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), ], key=['hdim', 'dstate', 'chunk_size'], ) @triton.jit def _chunk_scan_bwd_dstates_kernel( # Pointers to matrices dout_ptr, c_ptr, dprev_states_ptr, dA_cumsum_ptr, seq_idx_ptr, # Matrix dimensions hdim, dstate, chunk_size, batch, seqlen, nchunks, nheads_ngroups_ratio, # Strides stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, stride_c_batch, stride_c_seqlen, stride_c_head, stride_c_dstate, stride_dprev_states_batch, stride_dprev_states_chunk, stride_dprev_states_head, stride_dprev_states_hdim, stride_dprev_states_dstate, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, # Meta-parameters HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n c_ptr += pid_b * stride_c_batch + pid_c * chunk_size * stride_c_seqlen + (pid_h // nheads_ngroups_ratio) * stride_c_head dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_hdim + offs_k[None, :] * stride_dout_seqlen) c_ptrs = c_ptr + (offs_n[None, :] * 
stride_c_dstate + offs_k[:, None] * stride_c_seqlen) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize if HAS_SEQ_IDX: seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) if HAS_SEQ_IDX: seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0).to(tl.float32) dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) if not HAS_SEQ_IDX: scale_k = tl.exp(dA_cs_k) else: seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) scale_k = tl.where(seq_idx_k == seq_idx_prev, tl.exp(dA_cs_k), 0.0) dout = (dout * scale_k).to(dout_ptr.dtype.element_ty) c = tl.load(c_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0) acc += tl.dot(dout, c) dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen c_ptrs += BLOCK_SIZE_K * stride_c_seqlen dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize if HAS_SEQ_IDX: seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen out = acc.to(dprev_states_ptr.dtype.element_ty) dprev_states_ptr += pid_b * stride_dprev_states_batch + pid_c * stride_dprev_states_chunk + pid_h * stride_dprev_states_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dprev_states_ptrs = dprev_states_ptr + (offs_m[:, None] * stride_dprev_states_hdim + offs_n[None, :] * stride_dprev_states_dstate) tl.store(dprev_states_ptrs, out, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)) @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), ], key=['chunk_size', 'dstate', 'hdim'], ) @triton.jit def _chunk_scan_bwd_dc_kernel( # Pointers to matrices dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr, dc_ptr, ddA_cumsum_ptr, # Matrix dimensions chunk_size, dstate, hdim, batch, seqlen, nheads, nheads_per_program, ngroups, # Strides stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate, stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, stride_dc_batch, 
stride_dc_seqlen, stride_dc_split, stride_dc_group, stride_dc_dstate, stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, # Meta-parameters HAS_DDA_CS: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_sg = tl.program_id(axis=2) pid_s = pid_sg // ngroups pid_g = pid_sg - pid_s * ngroups num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head dc_ptr += pid_b * stride_dc_batch + pid_c * chunk_size * stride_dc_seqlen + pid_g * stride_dc_group + pid_s * stride_dc_split prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_prev_states_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head if HAS_DDA_CS: C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + pid_g * stride_C_head ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim) dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize if HAS_DDA_CS: C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate) ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) if HAS_DDA_CS: c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) if HAS_SEQ_IDX: seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) for h in range(nheads_iter): dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) prev_states = prev_states.to(dout_ptrs.dtype.element_ty) dc = tl.dot(dout, prev_states) dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) if not HAS_SEQ_IDX: scale = tl.exp(dA_cs_m) else: scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) dc *= scale[:, None] if HAS_DDA_CS: ddA_cs = tl.sum(dc * c, axis=1) tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) acc += dc dout_ptrs += stride_dout_head prev_states_ptrs += stride_prev_states_head dA_cumsum_ptrs += 
stride_dA_cs_head if HAS_DDA_CS: ddA_cumsum_ptrs += stride_ddA_cs_head # if HAS_SEQ_IDX: # seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) # acc = tl.where(seq_idx_m[:, None] == seq_idx_prev, acc, 0.0) offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dc_ptrs = dc_ptr + (offs_m[:, None] * stride_dc_seqlen + offs_n[None, :] * stride_dc_dstate) tl.store(dc_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), ], key=['chunk_size', 'hdim'], ) @triton.jit def _chunk_scan_bwd_dx_kernel( # Pointers to matrices x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, D_ptr, dx_ptr, ddt_ptr, # dD_ptr, # Matrix dimensions chunk_size, hdim, batch, seqlen, nheads_ngroups_ratio, # Strides stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_D_head, stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, # stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_hdim, stride_dD_csize, # Meta-parameters HAS_D: tl.constexpr, D_HAS_HDIM: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head dt_ptr += pid_b * 
stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head # if HAS_D: # dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # Idk why limiting K_MAX gives wrong results, is it a Triton bug? # K_MAX = min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) K_MAX = chunk_size_limit for k in range(0, K_MAX, BLOCK_SIZE_K): # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range, # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf. # Multiplying with cb, which is 0.0 outside the range, will make the result NaN. # This will cause NaN in acc, and hence NaN in dx and ddt. 
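        # [Editor's note, not in the original source] Concretely: for columns past the valid
        # range, cb is loaded as 0.0 and the masked dA_cumsum values make the exponent
        # difference above meaningless; if that difference is large and positive, tl.exp
        # overflows to inf (float32 overflows for arguments above ~88.7) and 0.0 * inf == nan.
        # The tl.where below clears those entries only because its mask also includes the
        # k + offs_k[None, :] < K_MAX term; the causal condition alone can still hold for such
        # columns and would let the NaN flow into tl.dot, acc, dx and ddt.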
mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX) cb = tl.where(mask, cb, 0.0) cb = cb.to(dout_ptr.dtype.element_ty) acc += tl.dot(cb, dout) cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dt_ptrs = dt_ptr + offs_m * stride_dt_csize dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) dx = acc * dt_m[:, None] dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) if HAS_D: dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) if D_HAS_HDIM: D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) else: D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) dx += dout_res * D tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) ddt = tl.sum(acc * x, axis=1) ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) # if HAS_D: # dout_new_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize + offs_n[None, :] * stride_dout_hdim) # dout = tl.load(dout_new_ptrs, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0).to(tl.float32) # dD = tl.sum(x * dout, axis=0) # tl.store(dD_ptr + offs_n * stride_dD_hdim, dD, mask=offs_n < N) # Disabling HAS_DDA_CS for now since it's much slower @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8), # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), ], key=['chunk_size', 'hdim'], ) # @triton.heuristics({"BLOCK_SIZE_N": lambda args: max(triton.next_power_of_2(args["chunk_size"]), 16)}) # @triton.heuristics({"BLOCK_SIZE_N": lambda args: 32}) @triton.jit def _chunk_scan_bwd_dcb_kernel( # Pointers to matrices x_ptr, dout_ptr, cb_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, dcb_ptr, ddA_cumsum_ptr, # Matrix dimensions chunk_size, hdim, batch, seqlen, nheads, nheads_per_program, ngroups, # Strides stride_x_batch, 
stride_x_seqlen, stride_x_head, stride_x_hdim, stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, stride_dcb_batch, stride_dcb_chunk, stride_dcb_split, stride_dcb_group, stride_dcb_csize_m, stride_dcb_csize_n, stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n, # Meta-parameters HAS_DDA_CS: tl.constexpr, HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_sg = tl.program_id(axis=2) pid_s = pid_sg // ngroups pid_g = pid_sg - pid_s * ngroups num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head if HAS_DDA_CS: cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + pid_g * stride_cb_head ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) dt_ptrs = dt_ptr + offs_n * stride_dt_csize if HAS_DDA_CS: cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) tl.store(dcb_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dcb_ptr.dtype.element_ty), mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) return chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) if HAS_DDA_CS: cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32) nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) for h in range(nheads_iter): dout = tl.load(dout_ptrs, 
mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) dcb = tl.dot(dout, x) dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) dcb *= dt_n dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32) dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) if HAS_DDA_CS: tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet") ddA_cs = dcb * cb mask = offs_m[:, None] >= offs_n[None, :] + 1 ddA_cs = tl.where(mask, ddA_cs, 0.0) ddA_cs = tl.cumsum(ddA_cs, axis=1) ddA_cs = tl.where(mask, ddA_cs, 0.0) ddA_cs = tl.sum(ddA_cs, axis=0) tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1) tl.store(ddA_cumsum_ptr, 0.0) acc += dcb dout_ptrs += stride_dout_head x_ptrs += stride_x_head dt_ptrs += stride_dt_head dA_cumsum_ptr += stride_dA_cs_head if HAS_DDA_CS: ddA_cumsum_ptr += stride_ddA_cs_head ddA_cumsum_ptrs += stride_ddA_cs_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) if HAS_SEQ_IDX: seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2) acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) mask = offs_m[:, None] >= offs_n[None, :] acc = tl.where(mask, acc, 0.0) dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) tl.store(dcb_ptrs, acc, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) # Not numerically stable and should not be used. Leaving here for reference. 
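# [Editor's note, not in the original source] The "unstable" variant below reconstructs
# ddA_cumsum from the already-computed output as sum(dout * out) (optionally removing the
# D * x contribution and, with SUBTRACT_DDTDT, subtracting dt * ddt). Recovering a gradient
# by subtracting nearly equal quantities like this loses precision in low-precision dtypes,
# which is presumably why the _stable kernels further down instead rebuild each contribution
# from cb, dt and exp(dA_cs_m - dA_cs_n) and accumulate it with running row sums.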
@triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 32}), triton.Config({'BLOCK_SIZE_M': 64}), triton.Config({'BLOCK_SIZE_M': 128}), triton.Config({'BLOCK_SIZE_M': 256}), ], key=["chunk_size", "hdim"], ) @triton.jit def _chunk_scan_bwd_ddAcs_unstable_kernel( # Pointers to matrices dout_ptr, out_ptr, dt_ptr, ddt_ptr, x_ptr, D_ptr, ddA_cumsum_ptr, dD_ptr, # Matrix dimensions chunk_size, hdim, batch, seqlen, # Strides stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, stride_D_head, stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, # Meta-parameters HAS_D: tl.constexpr, D_HAS_HDIM: tl.constexpr, SUBTRACT_DDTDT: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) pid_m = tl.program_id(axis=0) dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head if HAS_D: x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = tl.arange(0, BLOCK_SIZE_N) dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) if HAS_D: x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) if D_HAS_HDIM: dD_ptrs = dD_ptr + offs_n * stride_dD_hdim chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) if HAS_D: x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) if D_HAS_HDIM: dD = tl.sum(dout * x, axis=0) tl.store(dD_ptrs, dD, mask=offs_n < hdim) D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) else: dD = tl.sum(dout * x) tl.store(dD_ptr, dD) D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) out -= x * D ddA_cs = tl.sum(dout * out, axis=1) if SUBTRACT_DDTDT: dt = tl.load(dt_ptr + offs_m * stride_dt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) ddt = tl.load(ddt_ptr + offs_m * stride_ddt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) ddA_cs -= dt * ddt tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size) @triton.autotune( configs=[ # 
triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8), triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), ], key=['chunk_size', 'hdim'], ) @triton.jit def _chunk_scan_bwd_ddAcs_stable_kernel_old( # Pointers to matrices x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr, ddAcs_ptr, # Matrix dimensions chunk_size, hdim, batch, seqlen, nheads_ngroups_ratio, # Strides stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, stride_ddAcs_batch, stride_ddAcs_chunk, stride_ddAcs_head, stride_ddAcs_csize_m, stride_ddAcs_csize_n, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) dt_ptrs = dt_ptr + offs_n * stride_dt_csize cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) # Doing a matmul loop with cumsum later on will cause Triton to crash # Instead we do just one big matmul # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), 
dtype=tl.float32) # for k in range(0, hdim, BLOCK_SIZE_K): # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0) # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0) # acc += tl.dot(dout, x) # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim # x_ptrs += BLOCK_SIZE_K * stride_x_hdim dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) acc = tl.dot(dout, x) cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32) acc *= cb dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) acc *= dt_n dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32) acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) mask = offs_m[:, None] >= offs_n[None, :] + 1 acc = tl.where(mask, acc, 0.0) acc = tl.cumsum(acc, axis=1) acc = tl.where(mask, acc, 0.0) ddA_cs = tl.sum(acc, axis=0) ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n tl.store(ddAcs_ptrs + stride_ddAcs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1) tl.store(ddAcs_ptr, 0.0) # offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, 64) # offs_k = tl.arange(0, BLOCK_SIZE_K) # dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) # x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) # dt_ptrs = dt_ptr + offs_n * stride_dt_csize # cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) # chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) # chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) # rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) # ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m # ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n # for n in range(0, chunk_size_limit_n, 64): # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n - n), other=0.0) # acc = tl.dot(dout, x) # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - n), other=0.0).to(tl.float32) # acc *= cb # dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32) # acc *= dt_n # dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32) # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) # mask = offs_m[:, None] >= offs_n[None, :] + 1 + n # acc = tl.where(mask, acc, 0.0) # acc = tl.cumsum(acc, axis=1) # acc = tl.where(mask, acc, 0.0) # ddA_cs = tl.sum(acc, axis=0) # tl.store(ddAcs_ptrs, ddA_cs, mask=offs_n < chunk_size - 1 - n) # # tl.store(ddAcs_ptr, 0.0) @triton.autotune( configs=[ 
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), ], key=['chunk_size', 'hdim'], ) @triton.jit def _chunk_scan_bwd_ddAcs_stable_kernel( # Pointers to matrices x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr, ddA_cumsum_ptr, # Matrix dimensions chunk_size, hdim, batch, seqlen, nheads_ngroups_ratio, # Strides stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) pid_m = tl.program_id(axis=0) x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) dt_ptrs = dt_ptr + offs_n * stride_dt_csize cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) ddAcs_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n tl.store(ddA_cumsum_ptr, 0.0) chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) # Actually hi is (pid_m + 1) * BLOCK_SIZE_M - 1 but subtracting 1 makes it slower lo, hi = 0, (pid_m + 1) * BLOCK_SIZE_M # lo, hi = 0, chunk_size for start_n in range(lo, hi, BLOCK_SIZE_N): start_n = tl.multiple_of(start_n, BLOCK_SIZE_N) # Doing a matmul loop with cumsum 
later on will cause Triton to crash # Instead we do just one big matmul # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # for k in range(0, hdim, BLOCK_SIZE_K): # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0) # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0) # acc += tl.dot(dout, x) # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim # x_ptrs += BLOCK_SIZE_K * stride_x_hdim # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit - start_n), other=0.0) acc = tl.dot(dout, x) dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) acc *= dt_n # If there's seq_idx, we already zero'ed out cb[i, j] for seq_idx[i] != seq_idx[j] cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32) acc *= cb dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1 acc = tl.where(mask, acc, 0.0) rowsum_new = rowsum + tl.sum(acc, axis=1) acc = rowsum[:, None] + tl.cumsum(acc, axis=1) rowsum = rowsum_new acc = tl.where(mask, acc, 0.0) ddA_cs = tl.sum(acc, axis=0) tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - start_n - 1) x_ptrs += BLOCK_SIZE_N * stride_x_seqlen dt_ptrs += BLOCK_SIZE_N * stride_dt_csize cb_ptrs += BLOCK_SIZE_N * stride_cb_csize_n ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n # Need to zero out the rest, since we'll be summing the rows together for start_n in range(hi, chunk_size, BLOCK_SIZE_N): tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), mask=offs_n < chunk_size - start_n - 1) ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), ], key=['chunk_size', 'dstate', 'hdim'], ) @triton.jit def _chunk_scan_bwd_ddAcs_prev_kernel( # Pointers to matrices dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr, ddA_cumsum_ptr, # Matrix dimensions chunk_size, dstate, hdim, batch, seqlen, nchunks, nheads_ngroups_ratio, # Strides stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate, stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, 
stride_seq_idx_seqlen, stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, # Meta-parameters HAS_SEQ_IDX: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): pid_bc = tl.program_id(axis=1) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + pid_h * stride_prev_states_head C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim) C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate) dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) prev_states = prev_states.to(dout_ptrs.dtype.element_ty) acc = tl.dot(dout, prev_states) c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) ddA_cs = tl.sum(acc * c, axis=1) dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) if not HAS_SEQ_IDX: scale = tl.exp(dA_cs_m) if HAS_SEQ_IDX: seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) ddA_cs *= scale offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape _, _, ngroups, dstate = C.shape assert nheads % ngroups == 0 assert C.shape == (batch, seqlen, ngroups, dstate) assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) if z is not None: assert z.shape == x.shape if D is not None: assert D.shape == (nheads, headdim) or D.shape == (nheads,) assert dt.shape == (batch, nheads, nchunks, chunk_size) assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) assert states.shape == (batch, nchunks, nheads, headdim, dstate) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) # Allocates output. 
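    # out matches x in shape and dtype; when z is given, out_x additionally keeps the output
    # before the SiLU(z) gating (the backward pass needs it) and is asserted below to share out's strides.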
out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) if z is not None: out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) assert out_x.stride() == out.stride() else: out_x = None grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) if z is not None else (0, 0, 0, 0)) _chunk_scan_fwd_kernel[grid]( cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D, chunk_size, headdim, dstate, batch, seqlen, nheads // ngroups, cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), x.stride(0), x.stride(1), x.stride(2), x.stride(3), z_strides[0], z_strides[1], z_strides[2], z_strides[3], out.stride(0), out.stride(1), out.stride(2), out.stride(3), dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), C.stride(0), C.stride(1), C.stride(2), C.stride(3), states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), D.stride(0) if D is not None else 0, True, D is not None, D.dim() == 2 if D is not None else True, BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), HAS_Z=z is not None, HAS_SEQ_IDX=seq_idx is not None, IS_TRITON_22=TRITON_22, ) return out, out_x def _chunk_scan_fwd_wip(cb, x, dt, dA_cumsum, C, B, states, D=None, z=None, seq_idx=None): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape _, _, ngroups, dstate = C.shape assert nheads % ngroups == 0 assert C.shape == (batch, seqlen, ngroups, dstate) assert B.shape == C.shape assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) if z is not None: assert z.shape == x.shape if D is not None: assert D.shape == (nheads, headdim) or D.shape == (nheads,) assert dt.shape == (batch, nheads, nchunks, chunk_size) assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) assert states.shape == (batch, nchunks, nheads, headdim, dstate) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) # Allocates output. 
out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) if z is not None: out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) assert out_x.stride() == out.stride() else: out_x = None grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) if z is not None else (0, 0, 0, 0)) _chunk_scan_fwd_kernel_wip[grid]( cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, B, states, D, chunk_size, headdim, dstate, batch, seqlen, nheads // ngroups, cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), x.stride(0), x.stride(1), x.stride(2), x.stride(3), z_strides[0], z_strides[1], z_strides[2], z_strides[3], out.stride(0), out.stride(1), out.stride(2), out.stride(3), dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), C.stride(0), C.stride(1), C.stride(2), C.stride(3), B.stride(0), B.stride(1), B.stride(2), B.stride(3), states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), D.stride(0) if D is not None else 0, D is not None, D.dim() == 2 if D is not None else True, BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), BLOCK_SIZE_M=128, HAS_Z=z is not None, HAS_SEQ_IDX=seq_idx is not None, ) return out, out_x def _chunk_scan_bwd_dz(x, z, out, dout, chunk_size, has_ddAcs=True, D=None, dz=None, recompute_output=False): batch, seqlen, nheads, headdim = x.shape assert z.shape == x.shape assert out.shape == x.shape assert dout.shape == out.shape nchunks = math.ceil(seqlen / chunk_size) if D is not None: assert D.shape == (nheads, headdim) or D.shape == (nheads,) assert D.stride(-1) == 1 if has_ddAcs: ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) if D is not None: BLOCK_SIZE_min = 32 dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) else: dD = None if dz is not None: assert dz.shape == z.shape else: dz = torch.empty_like(z) if recompute_output: outz = torch.empty_like(x) dout_x = torch.empty_like(dout) dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) if D is not None else (0, 0, 0, 0, 0)) grid_dz = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): _chunk_scan_bwd_dz_kernel[grid_dz]( dout, out, z, x, D, outz if recompute_output else None, dz, dout_x, dD, ddA_cumsum if has_ddAcs else None, chunk_size, headdim, batch, seqlen, dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), out.stride(0), out.stride(1), out.stride(2), out.stride(3), z.stride(0), z.stride(1), z.stride(2), z.stride(3), x.stride(0), x.stride(1), x.stride(2), x.stride(3), D.stride(0) if D is not None else 0, *((outz.stride(0), outz.stride(1), outz.stride(2), outz.stride(3)) if recompute_output else (0, 0, 0, 0)), dz.stride(0), dz.stride(1), dz.stride(2), dz.stride(3), dout_x.stride(0), dout_x.stride(1), dout_x.stride(2), dout_x.stride(3), dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], *((ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3)) if has_ddAcs else (0, 0, 0, 0)), D is not None, D.dim() == 2 if D is not None else True, has_ddAcs, 
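            # The whole head dimension is handled in a single block: BLOCK_SIZE_N is headdim rounded up
            # to the next power of two (minimum 16).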
BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16), RECOMPUTE_OUTPUT=recompute_output, ) if D is not None: BLOCK_SIZE_actual = _chunk_scan_bwd_dz_kernel.best_config.kwargs["BLOCK_SIZE_M"] n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) if D.dim() == 1: dD = rearrange(dD, "h 1 -> h") return_vals = (dz, dout_x, dD, ddA_cumsum) if has_ddAcs else (dz, dout_x, dD) return return_vals if not recompute_output else (*return_vals, outz) def _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=None, dtype=None): batch, seqlen, nheads, headdim = dout.shape _, _, nchunks, chunk_size = dA_cumsum.shape _, _, ngroups, dstate = C.shape assert nheads % ngroups == 0 assert C.shape == (batch, seqlen, ngroups, dstate) assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) dtype = C.dtype if dtype is None else dtype dprev_states = torch.empty(batch, nchunks, nheads, headdim, dstate, device=C.device, dtype=dtype) grid_dstates = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) with torch.cuda.device(C.device.index): _chunk_scan_bwd_dstates_kernel[grid_dstates]( dout, C, dprev_states, dA_cumsum, seq_idx, headdim, dstate, chunk_size, batch, seqlen, nchunks, nheads // ngroups, dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), C.stride(0), C.stride(1), C.stride(2), C.stride(3), dprev_states.stride(0), dprev_states.stride(1), dprev_states.stride(2), dprev_states.stride(3), dprev_states.stride(4), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), HAS_SEQ_IDX=seq_idx is not None, ) return dprev_states def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngroups=1): batch, nchunks, nheads, headdim, dstate = prev_states.shape _, seqlen, _, _ = dout.shape _, _, _, chunk_size = dA_cumsum.shape assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) assert dout.shape == (batch, seqlen, nheads, headdim) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) if C is not None: assert C.shape == (batch, seqlen, ngroups, dstate) C_strides = (C.stride(0), C.stride(1), C.stride(2), C.stride(3)) ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) ddA_cumsum_prev_strides = (ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3)) else: C_strides = (0, 0, 0, 0) ddA_cumsum_prev = None ddA_cumsum_prev_strides = (0, 0, 0, 0) nheads_ngroups_ratio = nheads // ngroups sm_count = torch.cuda.get_device_properties(dout.device).multi_processor_count nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) dC = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=dout.device, dtype=torch.float32) grid_dc = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), batch * nchunks, nsplits * ngroups) with torch.cuda.device(dout.device.index): _chunk_scan_bwd_dc_kernel[grid_dc]( dout, prev_states, C, dA_cumsum, seq_idx, dC, ddA_cumsum_prev, chunk_size, dstate, headdim, batch, seqlen, nheads, nheads_per_program, 
ngroups, dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4), *C_strides, dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), dC.stride(0), dC.stride(1), dC.stride(2), dC.stride(3), dC.stride(4), *ddA_cumsum_prev_strides, HAS_DDA_CS=ddA_cumsum_prev is not None, HAS_SEQ_IDX=seq_idx is not None, BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), ) dC = dC.sum(2) return dC if C is None else (dC, ddA_cumsum_prev) def _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=None, CB=None, ngroups=1): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape assert dt.shape == (batch, nheads, nchunks, chunk_size) assert dA_cumsum.shape == dt.shape assert dout.shape == x.shape if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) if CB is not None: assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) CB_strides = (CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(3), CB.stride(4)) BLOCK_SIZE_M_min = 16 ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), chunk_size, device=x.device, dtype=torch.float32) ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4)) else: CB_strides = (0, 0, 0, 0, 0) ddA_cumsum = None ddA_cumsum_strides = (0, 0, 0, 0, 0) nheads_ngroups_ratio = nheads // ngroups sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) dcb = torch.empty(batch, nchunks, nsplits, ngroups, chunk_size, chunk_size, device=x.device, dtype=torch.float32) grid_dcb = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']), batch * nchunks, nsplits * ngroups) with torch.cuda.device(x.device.index): _chunk_scan_bwd_dcb_kernel[grid_dcb]( x, dout, CB, dt, dA_cumsum, seq_idx, dcb, ddA_cumsum, chunk_size, headdim, batch, seqlen, nheads, nheads_per_program, ngroups, x.stride(0), x.stride(1), x.stride(2), x.stride(3), dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), *CB_strides, dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), dcb.stride(0), dcb.stride(1), dcb.stride(2), dcb.stride(3), dcb.stride(4), dcb.stride(5), *ddA_cumsum_strides, HAS_DDA_CS=ddA_cumsum is not None, HAS_SEQ_IDX=seq_idx is not None, BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), ) dcb = dcb.sum(2) if ddA_cumsum is not None: BLOCK_SIZE_M_actual = _chunk_scan_bwd_dcb_kernel.best_config.kwargs["BLOCK_SIZE_M"] n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) return dcb if CB is None else (dcb, ddA_cumsum) def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape ngroups = cb.shape[2] assert nheads % ngroups == 0 assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) assert dt.shape == (batch, nheads, nchunks, chunk_size) assert dA_cumsum.shape == 
dt.shape assert dout.shape == x.shape # if D is not None: # BLOCK_SIZE_M_min = 32 # dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_M_min), batch, nchunks, nheads, headdim, device=D.device, dtype=torch.float32) # else: # dD = None dx = torch.empty_like(x) ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): _chunk_scan_bwd_dx_kernel[grid_dx]( x, cb, dout, dt, dA_cumsum, D, dx, ddt, # dD, chunk_size, headdim, batch, seqlen, nheads // ngroups, x.stride(0), x.stride(1), x.stride(2), x.stride(3), cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(-1), cb.stride(-2), dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), D.stride(0) if D is not None else 0, dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), # dD.stride(1) if dD is not None else 0, dD.stride(2) if dD is not None else 0, dD.stride(3) if dD is not None else 0, dD.stride(4) if dD is not None else 0, dD.stride(0) if dD is not None else 0, D is not None, D.dim() == 2 if D is not None else True, ) # if D is not None: # BLOCK_SIZE_actual = _chunk_scan_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] # n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual # dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) return dx, ddt.to(dtype=dt.dtype) def _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtract_ddtdt=True): """Not numerically stable and should not be used. Leaving here for reference. 
""" batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape assert dt.shape == (batch, nheads, nchunks, chunk_size) assert ddt.shape == dt.shape assert out.shape == x.shape assert dout.shape == x.shape if D is not None: assert D.shape == (nheads, headdim) or D.shape == (nheads,) ddA_cumsum = torch.empty_like(dt) grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) if D is not None: # Triton gives wrong results if we write to the same location BLOCK_SIZE_min = 32 dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) else: dD = None dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) if D is not None else (0, 0, 0, 0, 0)) with torch.cuda.device(x.device.index): _chunk_scan_bwd_ddAcs_unstable_kernel[grid_ddtcs]( dout, out, dt, ddt, x, D, ddA_cumsum, dD, chunk_size, headdim, batch, seqlen, dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), out.stride(0), out.stride(1), out.stride(2), out.stride(3), dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), x.stride(0), x.stride(1), x.stride(2), x.stride(3), D.stride(0) if D is not None else 0, ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], D is not None, D.dim() == 2 if D is not None else True, subtract_ddtdt, BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16), ) if D is not None: BLOCK_SIZE_actual = _chunk_scan_bwd_ddAcs_unstable_kernel.best_config.kwargs["BLOCK_SIZE_M"] n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) if D.dim() == 1: dD = rearrange(dD, "h 1 -> h") return ddA_cumsum, dD def _chunk_scan_bwd_ddAcs_stable_old(x, dt, dA_cumsum, dout, cb): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape assert dt.shape == (batch, nheads, nchunks, chunk_size) assert dout.shape == x.shape assert dA_cumsum.shape == dt.shape ngroups = cb.shape[2] assert nheads % ngroups == 0 assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) BLOCK_SIZE_M_min = 16 ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), chunk_size, device=x.device, dtype=torch.float32) grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): _chunk_scan_bwd_ddAcs_stable_kernel_old[grid_ddtcs]( x, dout, dt, dA_cumsum, cb, ddA_cumsum, chunk_size, headdim, batch, seqlen, nheads // ngroups, x.stride(0), x.stride(1), x.stride(2), x.stride(3), dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4), BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), BLOCK_SIZE_N=max(triton.next_power_of_2(chunk_size), 16), ) BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel_old.best_config.kwargs["BLOCK_SIZE_M"] n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual ddA_cumsum = ddA_cumsum[:, :, :, 
:n_valid_blocks].sum(dim=3) return ddA_cumsum def _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, cb): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape assert dt.shape == (batch, nheads, nchunks, chunk_size) assert dout.shape == x.shape assert dA_cumsum.shape == dt.shape ngroups = cb.shape[2] assert nheads % ngroups == 0 assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) BLOCK_SIZE_M_min = 32 ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), chunk_size, device=x.device, dtype=torch.float32) grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) with torch.cuda.device(x.device.index): _chunk_scan_bwd_ddAcs_stable_kernel[grid_ddtcs]( x, dout, dt, dA_cumsum, cb, ddA_cumsum, chunk_size, headdim, batch, seqlen, nheads // ngroups, x.stride(0), x.stride(1), x.stride(2), x.stride(3), dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4), BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), ) BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel.best_config.kwargs["BLOCK_SIZE_M"] n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) return ddA_cumsum def _chunk_scan_bwd_ddAcs_prev(prev_states, C, dout, dA_cumsum, seq_idx=None): batch, nchunks, nheads, headdim, dstate = prev_states.shape _, seqlen, _, _ = dout.shape _, _, _, chunk_size = dA_cumsum.shape assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) assert dout.shape == (batch, seqlen, nheads, headdim) ngroups = C.shape[2] assert nheads % ngroups == 0 assert C.shape == (batch, seqlen, ngroups, dstate) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) grid_ddAcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) with torch.cuda.device(dout.device.index): _chunk_scan_bwd_ddAcs_prev_kernel[grid_ddAcs]( dout, prev_states, C, dA_cumsum, seq_idx, ddA_cumsum_prev, chunk_size, dstate, headdim, batch, seqlen, nchunks, nheads // ngroups, dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4), C.stride(0), C.stride(1), C.stride(2), C.stride(3), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3), HAS_SEQ_IDX=seq_idx is not None, BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), ) return ddA_cumsum_prev class ChunkScanFn(torch.autograd.Function): @staticmethod def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): # Check constraints. 
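        # CB[c] = C[c] @ B[c]^T is computed once per chunk by _bmm_chunk_fwd below; it is reused by
        # _chunk_scan_fwd and saved for the backward kernels.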
        batch, seqlen, nheads, headdim = x.shape
        _, _, ngroups, dstate = B.shape
        assert B.shape == (batch, seqlen, ngroups, dstate)
        _, _, nchunks, chunk_size = dt.shape
        assert seqlen == nchunks * chunk_size
        assert C.shape == B.shape
        if z is not None:
            assert z.shape == x.shape
        if D is not None:
            assert D.shape == (nheads, headdim) or D.shape == (nheads,)
        assert dt.shape == (batch, nheads, nchunks, chunk_size)
        assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
        assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
        if B.stride(-1) != 1:
            B = B.contiguous()
        if C.stride(-1) != 1:
            C = C.contiguous()
        if x.stride(-1) != 1 and x.stride(1) != 1:  # Either M or K dimension should be contiguous
            x = x.contiguous()
        if z is not None and z.stride(-1) != 1 and z.stride(1) != 1:  # Either M or K dimension should be contiguous
            z = z.contiguous()
        if D is not None and D.stride(-1) != 1:
            D = D.contiguous()
        CB = _bmm_chunk_fwd(C, B, chunk_size)
        out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, prev_states, D=D, z=z)
        ctx.save_for_backward(out if z is None else out_x, B, C, CB, x, dt, dA_cumsum, prev_states, D, z)
        return out

    @staticmethod
    def backward(ctx, dout):
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        out, B, C, CB, x, dt, dA_cumsum, prev_states, D, z = ctx.saved_tensors
        batch, seqlen, nheads, headdim = x.shape
        _, _, nchunks, chunk_size = dt.shape
        _, _, ngroups, dstate = B.shape
        assert dout.shape == (batch, seqlen, nheads, headdim)
        if z is not None:
            dz, dout, dD, ddA_cumsum = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, D=D)
        else:
            dz = None
        dprev_states = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, dtype=prev_states.dtype)
        dC = _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, ngroups=ngroups)
        dC = dC.to(C.dtype)
        dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, ngroups=ngroups)
        dCB = dCB.to(CB.dtype)
        dB = _bmm_chunk_bwd(C, dCB)
        dC = _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC)
        dx, ddt = _chunk_scan_bwd_dx(CB, x, dt, dA_cumsum, dout, D=D)
        # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
        # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
        if z is not None:
            ddA_cumsum -= ddt * dt
        else:
            # When z is not None, ddA_cumsum and dD were already computed above by _chunk_scan_bwd_dz;
            # here (z is None) we still need to compute them.
            ddA_cumsum, dD = _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=D)
        ddA_cumsum = ddA_cumsum.to(dA_cumsum.dtype)
        return dB, dC, dx, ddt, ddA_cumsum, dprev_states, dD, dz


def chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
    """
    prev_states contains the initial_states at index 0, and the state for the next-to-last chunk at index -1.
    Argument:
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
        prev_states: (batch, nchunks, nheads, headdim, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    return ChunkScanFn.apply(B, C, x, dt, dA_cumsum, prev_states, D, z)


def chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
    """
    Argument:
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
        prev_states: (batch, nchunks, nheads, headdim, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, ngroups, dstate = B.shape
    assert B.shape == (batch, seqlen, ngroups, dstate)
    _, _, nchunks, chunk_size = dt.shape
    assert seqlen == nchunks * chunk_size
    assert C.shape == B.shape
    B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
    C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups)
    CB = torch.einsum("bclhn,bcshn->bchls",
                      rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
                      rearrange(B, "b (c s) h n -> b c s h n", c=nchunks))
    # (batch, nheads, nchunks, chunksize, chunksize)
    dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
    decay = torch.exp(dt_segment_sum)
    scores_decay = CB * rearrange(decay, "b h c l s -> b c h l s")
    causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
    scores_decay = scores_decay.masked_fill(~causal_mask, 0)
    out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
                       rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
    state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
    out_prev = torch.einsum('bclhn,bchpn->bclhp',
                            rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
                            prev_states.to(C.dtype)) * state_decay_out
    out = out + out_prev
    out = rearrange(out, "b c l h p -> b (c l) h p")
    if D is not None:
        if D.dim() == 1:
            D = rearrange(D, "h -> h 1")
        out = out + x * D
    return out if z is None else out * F.silu(z)
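

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the kernel library):
# builds random inputs with the shapes documented in the docstrings above and
# checks the Triton-backed chunk_scan against chunk_scan_ref. The sizes, the
# dtypes, the negative per-head A used to form dA_cumsum, and the all-zero
# prev_states are assumptions chosen for this example; a CUDA GPU is required.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    assert torch.cuda.is_available(), "the Triton kernels in this file require a CUDA device"
    torch.manual_seed(0)
    device = "cuda"
    batch, seqlen, chunk_size = 1, 256, 64
    nheads, ngroups, headdim, dstate = 4, 1, 64, 16
    nchunks = seqlen // chunk_size  # chunk_scan asserts seqlen == nchunks * chunk_size

    x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.float16)
    z = torch.randn_like(x)
    B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=torch.float16)
    C = torch.randn_like(B)
    D = torch.randn(nheads, headdim, device=device, dtype=torch.float32)
    dt = F.softplus(torch.randn(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32))
    # A negative per-head decay keeps exp(dA_cumsum differences) <= 1 within each chunk.
    A = -torch.exp(torch.randn(nheads, device=device, dtype=torch.float32))
    dA_cumsum = torch.cumsum(dt * A[None, :, None, None], dim=-1)
    # Zero initial / inter-chunk states keep the example self-contained.
    prev_states = torch.zeros(batch, nchunks, nheads, headdim, dstate, device=device, dtype=torch.float16)

    out = chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=D, z=z)
    out_ref = chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=D, z=z)
    print("max abs diff vs reference:", (out.float() - out_ref.float()).abs().max().item())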