|
|
|
|
|
from typing import Optional |
|
|
|
import torch |
|
import triton |
|
import triton.language as tl |
|
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time |
|
|
|
from flash_attn.ops.triton.k_activations import ( |
|
gelu, |
|
gelu_approx, |
|
gelu_approx_grad, |
|
gelu_grad, |
|
squared_relu, |
|
squared_relu_grad, |
|
) |
|
|
|
|
|
|
|
|
|
def init_to_zero(name):
    """Return an autotune pre-hook that zero-initializes the named kernel argument."""
    return lambda nargs: nargs[name].zero_()
|
|
|
|
|
def get_configs_io_bound():
    """Generate small-tile autotuning configs aimed at IO-bound (small-M) problem sizes."""
|
configs = [] |
|
for num_stages in [2, 3, 4, 5, 6]: |
|
for block_m in [16, 32]: |
|
for block_k in [32, 64]: |
|
for block_n in [32, 64, 128, 256]: |
|
num_warps = 2 if block_n <= 64 else 4 |
|
configs.append( |
|
triton.Config( |
|
{ |
|
"BLOCK_M": block_m, |
|
"BLOCK_N": block_n, |
|
"BLOCK_K": block_k, |
|
"SPLIT_K": 1, |
|
}, |
|
num_stages=num_stages, |
|
num_warps=num_warps, |
|
) |
|
) |
|
|
|
|
|
|
|
|
|
|
|
return configs |
|
|
|
|
|
@triton.autotune( |
|
configs=[ |
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2 |
|
), |
|
|
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, |
|
num_stages=3, |
|
num_warps=8, |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, |
|
num_stages=3, |
|
num_warps=8, |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, |
|
num_stages=4, |
|
num_warps=4, |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2 |
|
), |
|
] |
|
+ get_configs_io_bound(), |
|
key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], |
|
prune_configs_by={ |
|
"early_config_prune": early_config_prune, |
|
"perf_model": estimate_matmul_time, |
|
"top_k": 10, |
|
}, |
|
) |
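# EVEN_K lets the kernel skip the load masks along K whenever K is an exact multiple
# of BLOCK_K * SPLIT_K.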
|
@triton.heuristics( |
|
{ |
|
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, |
|
} |
|
) |
|
@triton.jit |
|
def kernel_fwd( |
|
C, |
|
ACT_INPUT, |
|
A, |
|
B, |
|
bias, |
|
|
|
M, |
|
N, |
|
K, |
|
CACHE_KEY_M, |
|
CACHE_KEY_N, |
|
CACHE_KEY_K, |
|
|
|
|
|
|
|
stride_cm, |
|
|
|
stride_am, |
|
stride_ak, |
|
stride_bn, |
|
stride_bk, |
|
|
|
BLOCK_M: tl.constexpr, |
|
GROUP_M: tl.constexpr, |
|
BLOCK_N: tl.constexpr, |
|
BLOCK_K: tl.constexpr, |
|
|
|
SPLIT_K: tl.constexpr, |
|
EVEN_K: tl.constexpr, |
|
A_ROWMAJOR: tl.constexpr, |
|
B_COLMAJOR: tl.constexpr, |
|
BIAS: tl.constexpr, |
|
SAVE_ACT_INPUT: tl.constexpr, |
|
ACTIVATION: tl.constexpr, |
|
): |
|
|
|
""" |
|
Kernel for computing Out = activation(A x W + C) |
|
- Input has shape (M, K) |
|
- Weight has shape (K, N) |
|
- Bias has shape (N,) |
|
- Output has shape (M, N) |
|
- ActInputs (optional) has shape (M, N) |
|
'ActInputs' optionally saves the A x W + C intermediate for backward computations |
|
This kernel will consolidate over K |
|
""" |
|
|
|
pid = tl.program_id(axis=0) |
|
|
|
grid_m = (M + BLOCK_M - 1) // BLOCK_M |
|
grid_n = (N + BLOCK_N - 1) // BLOCK_N |
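    # Re-order the program IDs in groups of GROUP_M row-tiles so that programs sharing
    # the same column tiles of B run close together, improving L2 cache reuse.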
|
|
|
width = GROUP_M * grid_n |
|
group_id = pid // width |
|
group_size = min(grid_m - group_id * GROUP_M, GROUP_M) |
|
pid_m = group_id * GROUP_M + (pid % group_size) |
|
pid_n = (pid % width) // (group_size) |
|
|
|
|
|
|
|
|
|
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) |
|
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) |
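    # Wrap the row / column indices modulo M / N so out-of-range lanes still point at
    # valid memory, and hint contiguity so Triton can vectorize the loads.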
|
|
|
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) |
|
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) |
|
rk = tl.arange(0, BLOCK_K) |
|
|
|
if A_ROWMAJOR: |
|
A = A + (ram[:, None] * stride_am + rk[None, :]) |
|
else: |
|
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) |
|
if B_COLMAJOR: |
|
B = B + (rk[:, None] + rbn[None, :] * stride_bn) |
|
else: |
|
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) |
|
|
|
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) |
|
|
|
for k in range(K, 0, -BLOCK_K): |
|
if EVEN_K: |
|
a = tl.load(A) |
|
b = tl.load(B) |
|
else: |
|
a = tl.load(A, mask=rk[None, :] < k, other=0.0) |
|
b = tl.load(B, mask=rk[:, None] < k, other=0.0) |
|
acc += tl.dot(a, b) |
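        # Advance the A and B pointers to the next BLOCK_K tile along K.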
|
|
|
if A_ROWMAJOR: |
|
A += BLOCK_K |
|
else: |
|
A += BLOCK_K * stride_ak |
|
if B_COLMAJOR: |
|
B += BLOCK_K |
|
else: |
|
B += BLOCK_K * stride_bk |
|
|
|
|
|
if BIAS: |
|
bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32) |
|
acc += bias[None, :] |
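    # Optionally stash the pre-activation result (A @ B + bias) so the backward kernel
    # can evaluate the activation gradient without recomputing the matmul.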
|
|
|
|
|
if SAVE_ACT_INPUT: |
|
|
|
act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] |
|
tl.store(act_in_ptrs, acc) |
|
|
|
|
|
if ACTIVATION == "gelu": |
|
acc = gelu(acc) |
|
elif ACTIVATION == "gelu_approx": |
|
acc = gelu_approx(acc) |
|
elif ACTIVATION == "squared_relu": |
|
acc = squared_relu(acc) |
|
|
|
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) |
|
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) |
|
|
|
|
|
|
|
C = C + rm[:, None] * stride_cm + rn[None, :] |
|
mask = (rm < M)[:, None] & (rn < N)[None, :] |
|
    tl.store(C, acc, mask=mask)
|
|
|
|
|
def triton_linear_act( |
|
x: torch.Tensor, |
|
weight: torch.Tensor, |
|
bias: Optional[torch.Tensor] = None, |
|
activation: str = "id", |
|
save_act_input: bool = False, |
|
) -> torch.Tensor: |
|
""" |
|
Compute e = activation(x @ weight.T + bias). |
|
This wrapper kicks the `kernel_fwd` Triton kernel |
|
:param x: input tensor |
|
:param weight: weight matrix |
|
:param bias: an optional bias tensor |
|
:param activation: Activation name. Needs to be a Triton kernel. |
|
:param act_input: an optional tensor to save the activation inputs (for backward) |
|
:return: result tensor |
|
""" |
|
|
|
|
|
|
|
|
|
assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] |
|
|
|
batch_shape, n = x.shape[:-1], x.shape[-1] |
|
batch_dim = batch_shape.numel() |
|
x_reshaped = x.reshape(batch_dim, n) |
|
|
|
if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1: |
|
x_reshaped = x_reshaped.contiguous() |
|
if weight.stride(0) > 1 and weight.stride(1) > 1: |
|
weight = weight.contiguous() |
|
bias = bias.contiguous() if bias is not None else None |
|
|
|
assert ( |
|
x.dtype == weight.dtype |
|
), f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" |
|
if bias is not None: |
|
assert ( |
|
x.dtype == bias.dtype |
|
), f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}" |
|
assert ( |
|
x_reshaped.shape[1] == weight.shape[1] |
|
), f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}" |
|
|
|
    assert (
        bias is None or bias.shape[0] == weight.shape[0]
    ), "Incompatible dimensions between weight and bias"
|
|
|
M, K = x_reshaped.shape |
|
N, K = weight.shape |
|
|
|
output = torch.empty((M, N), device=x.device, dtype=x.dtype) |
|
act_input = torch.empty_like(output) if save_act_input else None |
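    # 1D launch grid: one Triton program per (BLOCK_M, BLOCK_N) tile of the output.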
|
|
|
|
|
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) |
|
|
|
kernel_fwd[grid]( |
|
output, |
|
act_input, |
|
x_reshaped, |
|
weight, |
|
        bias if bias is not None else x,  # placeholder pointer; never dereferenced when BIAS is False
|
M, |
|
N, |
|
K, |
|
M // 32, |
|
N // 32, |
|
K // 32, |
|
stride_cm=output.stride(0), |
|
|
|
stride_am=x_reshaped.stride(0), |
|
stride_ak=x_reshaped.stride(1), |
|
stride_bk=weight.stride(1), |
|
stride_bn=weight.stride(0), |
|
BIAS=bias is not None, |
|
SAVE_ACT_INPUT=save_act_input, |
|
ACTIVATION=activation, |
|
A_ROWMAJOR=x_reshaped.stride(1) == 1, |
|
B_COLMAJOR=weight.stride(1) == 1, |
|
GROUP_M=8, |
|
) |
|
|
|
if not save_act_input: |
|
return output.reshape(*batch_shape, output.shape[-1]) |
|
else: |
|
return ( |
|
output.reshape(*batch_shape, output.shape[-1]), |
|
act_input.reshape(*batch_shape, act_input.shape[-1]), |
|
) |
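
# Illustrative usage sketch (the shapes, dtype, and tensor names below are assumptions
# made only for this example):
#
#     x = torch.randn(4, 512, 1024, device="cuda", dtype=torch.float16)
#     w = torch.randn(4096, 1024, device="cuda", dtype=torch.float16)  # (out_features, in_features)
#     b = torch.randn(4096, device="cuda", dtype=torch.float16)
#     out, pre_act = triton_linear_act(x, w, b, activation="gelu_approx", save_act_input=True)
#     # out.shape == (4, 512, 4096); pre_act holds x @ w.T + b for the backward pass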
|
|
|
|
|
@triton.autotune( |
|
configs=[ |
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2 |
|
), |
|
|
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, |
|
num_stages=3, |
|
num_warps=8, |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, |
|
num_stages=3, |
|
num_warps=8, |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, |
|
num_stages=4, |
|
num_warps=4, |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4 |
|
), |
|
triton.Config( |
|
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2 |
|
), |
|
] |
|
+ get_configs_io_bound(), |
|
key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], |
|
prune_configs_by={ |
|
"early_config_prune": early_config_prune, |
|
"perf_model": estimate_matmul_time, |
|
"top_k": 10, |
|
}, |
|
) |
|
@triton.heuristics( |
|
{ |
|
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, |
|
} |
|
) |
|
@triton.jit |
|
def kernel_bwd( |
|
C, |
|
ACT_INPUT, |
|
A, |
|
B, |
|
|
|
M, |
|
N, |
|
K, |
|
CACHE_KEY_M, |
|
CACHE_KEY_N, |
|
CACHE_KEY_K, |
|
|
|
|
|
|
|
stride_cm, |
|
|
|
stride_am, |
|
stride_ak, |
|
stride_bk, |
|
stride_bn, |
|
|
|
BLOCK_M: tl.constexpr, |
|
GROUP_M: tl.constexpr, |
|
BLOCK_N: tl.constexpr, |
|
BLOCK_K: tl.constexpr, |
|
|
|
SPLIT_K: tl.constexpr, |
|
EVEN_K: tl.constexpr, |
|
ACTIVATION: tl.constexpr, |
|
): |
|
|
|
""" |
|
Kernel for computing Out = activation(A x W + C) |
|
- Input has shape (M, K) |
|
- Weight has shape (K, N) |
|
- Output has shape (M, N) |
|
- ActInputs (optional) has shape (M, N) |
|
'ActInputs' optionally saves the A x W + C intermediate for backward computations |
|
This kernel will consolidate over K |
|
""" |
|
|
|
pid = tl.program_id(axis=0) |
|
|
|
grid_m = (M + BLOCK_M - 1) // BLOCK_M |
|
grid_n = (N + BLOCK_N - 1) // BLOCK_N |
|
|
|
width = GROUP_M * grid_n |
|
group_id = pid // width |
|
group_size = min(grid_m - group_id * GROUP_M, GROUP_M) |
|
pid_m = group_id * GROUP_M + (pid % group_size) |
|
pid_n = (pid % width) // (group_size) |
|
|
|
|
|
|
|
|
|
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) |
|
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) |
|
|
|
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) |
|
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) |
|
rk = tl.arange(0, BLOCK_K) |
|
|
|
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) |
|
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) |
|
|
|
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) |
|
|
|
for k in range(K, 0, -BLOCK_K): |
|
if EVEN_K: |
|
a = tl.load(A) |
|
b = tl.load(B) |
|
else: |
|
a = tl.load(A, mask=rk[None, :] < k, other=0.0) |
|
b = tl.load(B, mask=rk[:, None] < k, other=0.0) |
|
acc += tl.dot(a, b) |
|
|
|
A += BLOCK_K * stride_ak |
|
B += BLOCK_K * stride_bk |
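    # Chain rule: scale the matmul result by the activation derivative evaluated at the
    # pre-activation values that the forward kernel saved in ACT_INPUT.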
|
|
|
|
|
if ACTIVATION != "id": |
|
act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] |
|
act_input = tl.load(act_in_ptrs).to(acc.dtype) |
|
if ACTIVATION == "gelu": |
|
acc *= gelu_grad(act_input) |
|
elif ACTIVATION == "gelu_approx": |
|
acc *= gelu_approx_grad(act_input) |
|
elif ACTIVATION == "squared_relu": |
|
acc *= squared_relu_grad(act_input) |
|
|
|
|
|
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) |
|
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) |
|
|
|
|
|
C = C + rm[:, None] * stride_cm + rn[None, :] |
|
mask = (rm < M)[:, None] & (rn < N)[None, :] |
|
tl.store(C, acc, mask=mask) |
|
|
|
|
|
def triton_dgrad_act( |
|
grad_output: torch.Tensor, |
|
weight: torch.Tensor, |
|
activation: str = "id", |
|
act_input: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
|
""" |
|
Compute e = activation(grad_output @ weight + bias). |
|
This wrapper kicks the `kernel_fwd` Triton kernel |
|
:param grad_output: input tensor |
|
:param weight: weight matrix |
|
:param activation: Activation name. Needs to be a Triton kernel. |
|
:param act_input: an optional tensor to save the activation inputs (for backward) |
|
:return: result tensor |
|
""" |
|
assert activation in ["id", "gelu", "gelu_approx", "squared_relu"] |
|
|
|
batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1] |
|
batch_dim = batch_shape.numel() |
|
grad_output_reshaped = grad_output.reshape(batch_dim, n) |
|
|
|
if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1: |
|
grad_output_reshaped = grad_output_reshaped.contiguous() |
|
if weight.stride(0) > 1 and weight.stride(1) > 1: |
|
weight = weight.contiguous() |
|
|
|
assert ( |
|
grad_output.dtype == weight.dtype |
|
), f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}" |
|
assert ( |
|
grad_output_reshaped.shape[1] == weight.shape[0] |
|
), f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}" |
|
if activation != "id": |
|
assert act_input is not None, f"act_input is required for activation {activation}" |
|
|
|
|
|
M, K = grad_output_reshaped.shape |
|
K, N = weight.shape |
|
|
|
grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype) |
|
|
|
|
|
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) |
|
|
|
kernel_bwd[grid]( |
|
grad_input, |
|
act_input, |
|
grad_output_reshaped, |
|
weight, |
|
M, |
|
N, |
|
K, |
|
M // 32, |
|
N // 32, |
|
K // 32, |
|
stride_cm=grad_input.stride(0), |
|
|
|
stride_am=grad_output_reshaped.stride(0), |
|
stride_ak=grad_output_reshaped.stride(1), |
|
stride_bk=weight.stride(0), |
|
stride_bn=weight.stride(1), |
|
ACTIVATION=activation, |
|
GROUP_M=8, |
|
) |
|
|
|
return grad_input.reshape(*batch_shape, grad_input.shape[-1]) |
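
# Illustrative usage sketch (continues the example after triton_linear_act above; the
# tensor names and shapes are assumptions made only for this example):
#
#     grad_out = torch.randn_like(out)
#     grad_x = triton_dgrad_act(grad_out, w, activation="gelu_approx", act_input=pre_act)
#     # grad_x.shape == x.shape == (4, 512, 1024)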
|
|