# triton_kernels/tests/test_matmul.py
# Uploaded by marcsun13 (HF Staff) via huggingface_hub, revision 567c8ad (verified).
# isort: off
# fmt: off
from dataclasses import dataclass, fields, replace
import pytest
import torch
from typing import Union
import triton
# routing utilities
from triton_kernels.routing import routing
# matmul utilities
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs, FnName, Epilogue
from triton_kernels.matmul_ogs import matmul_ogs_set_idle_sms, matmul_ogs, matmul_ogs_torch
from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig
from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4
from triton_kernels.tensor_details import layout
# numerics utilities
from triton_kernels.numerics import InFlexData, OutFlexData
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp, dequantize_mxfp8_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE
# testing utilities
from triton_kernels.testing import assert_close, compute_actual_scale
# target-specific utilities
from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
# ---------------
# initialize data
# ---------------
def alloc_rand(shape, device, dtype, requires_grad=True):
if dtype.itemsize == 1:
tmp = 2**-(torch.randint(4, 8, shape, device=device, dtype=torch.float16))
return tmp.to(dtype).requires_grad_(requires_grad)
return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad)
def alloc_rand_like(x):
    """Return a fresh random tensor with the same shape, device, dtype and requires_grad flag as ``x``."""
    return alloc_rand(shape=x.shape, device=x.device, dtype=x.dtype, requires_grad=x.requires_grad)
def mask_indx(idx, n_expts_act):
    """Invalidate the last ``n_expts_act`` destination slots of a gather/scatter index, in place.

    The source entries those slots point at are set to -1 first (using the
    original destination values), and then the slots themselves are set to -1.

    Args:
        idx: Object exposing integer tensors ``src_indx`` and ``dst_indx``.
        n_expts_act: Number of trailing ``dst_indx`` entries to invalidate.
            A non-positive count leaves ``idx`` untouched.

    Returns:
        The same ``idx`` object, mutated in place.
    """
    if n_expts_act <= 0:
        # ``t[-0:]`` is the whole tensor, so without this guard a count of zero
        # would invalidate every entry instead of none.
        return idx
    tail = idx.dst_indx[-n_expts_act:]
    # Must happen before the tail is overwritten: it uses the tail's current values.
    idx.src_indx[tail] = -1
    idx.dst_indx[-n_expts_act:] = -1
    return idx
def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter, device="cuda"):
    """Route ``m`` tokens with random float16 logits over ``n_expts_tot`` experts.

    Args:
        m: Number of tokens (rows of the logits).
        n_expts_tot: Number of experts (columns of the logits).
        n_expts_act: Experts selected per token, forwarded to ``routing``.
        n_expt_shards: Forwarded to ``routing`` as ``simulated_ep``.
        do_gather: Keep the gather index; otherwise it is replaced by ``None``.
        do_scatter: Keep the scatter index; otherwise it is replaced by ``None``.
        device: Device for the logits.

    Returns:
        ``(m, routing_data, gather_idx, scatter_idx)``. ``routing_data.gate_scal``
        is cleared to ``None``.
    """
    logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=device, requires_grad=True)
    routing_data, gather_idx, scatter_idx = routing(logits, n_expts_act, simulated_ep=n_expt_shards)
    # Drop the gating scales (presumably so neither the kernel nor the reference applies them).
    routing_data.gate_scal = None
    if not do_gather:
        gather_idx = None
    if not do_scatter:
        scatter_idx = None
    return m, routing_data, gather_idx, scatter_idx
def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode, act_dtype, weight_dtype,
                      has_y_gammas, requires_grad=True, device="cuda"):
    """Create seeded random operands ``(x, w, bias, gs0, gs1)`` for a matmul test.

    Shapes:
        ``x`` has ``m * n_expts_act`` rows when there is no gather index
        (``m`` rows otherwise). It gains a leading ``n_expts_tot`` dimension in
        "batched" mode.

        ``w`` is ``(k, n)`` in "plain" mode and
        ``(n_expts_tot // n_expt_shards, k, n)`` otherwise. ``bias`` has the
        same leading dimension.

    Gammas:
        ``gs0``/``gs1`` hold ``m * n_expts_act`` powers of two in
        [2**-5, 2**-1]. Both are returned as ``None`` in batched mode, when
        ``has_y_gammas`` is false, or when there is a gather index and the
        activation dtype is at least two bytes wide.

    ``sindx`` is accepted for signature symmetry but is not used here.
    """
    torch.manual_seed(0)
    assert mode in {'batched', "plain", 'ragged'}
    rows = m if gindx is not None else m * n_expts_act
    x_shape = (n_expts_tot, rows, k) if mode == 'batched' else (rows, k)
    batch_dims = tuple() if mode == "plain" else (n_expts_tot // n_expt_shards, )
    x = alloc_rand(x_shape, device=device, dtype=act_dtype, requires_grad=requires_grad)
    w = alloc_rand(batch_dims + (k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad)
    bias = alloc_rand(batch_dims + (n, ), device=device, dtype=torch.float32, requires_grad=requires_grad)

    n_gammas = m * n_expts_act
    raw0 = 2**torch.randint(-5, 0, (n_gammas, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
    raw1 = 2**torch.randint(-5, 0, (n_gammas, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
    gs0 = raw0.detach().requires_grad_(requires_grad)
    gs1 = raw1.detach().requires_grad_(requires_grad)

    drop_gammas = (mode == 'batched' or not has_y_gammas or (gindx is not None and act_dtype.itemsize >= 2))
    if drop_gammas:
        gs0 = gs1 = None

    if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
        # Same logical (..., k, n) shape, but with the K dimension contiguous in memory.
        w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
    return x, w, bias, gs0, gs1
# ---------------
# numerics stuff
# ---------------
def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, n_expts_tot=1, device="cuda"):
    """Build the ``PrecisionConfig`` used by the matmul tests.

    Flexpoint on the weights is enabled only for one-byte, non-MX weight dtypes.

    When activation flexpoint is on:
        the lhs gets scale 1.25 and the output an expected scale of 4.0.

    When weight flexpoint is on:
        the rhs gets a per-expert scale vector alternating 1.50 / 1.25.

    ``acc_scale`` is 2.0 if either flexpoint is enabled, else 1.0.
    """
    weight_use_flexpoint = weight_dtype.itemsize == 1 and not weight_mxfp

    def _scalar(val):
        return torch.tensor([val], dtype=torch.float32, device=device)

    def _per_expert(val0, val1):
        # Alternate val0/val1 across experts; an odd count ends with an extra val0.
        values = [val0, val1] * (n_expts_tot // 2)
        if n_expts_tot % 2:
            values = values + [val0]
        return torch.tensor(values, dtype=torch.float32, device=device)

    if act_use_flexpoint:
        lhs_data = InFlexData(dtype=out_dtype, scale=_scalar(1.25))
    else:
        lhs_data = InFlexData()
    if weight_use_flexpoint:
        rhs_data = InFlexData(dtype=weight_dtype, scale=_per_expert(1.50, 1.25))
    else:
        rhs_data = InFlexData()
    if act_use_flexpoint:
        out_data = OutFlexData(dtype=out_dtype, expected_scale=_scalar(4.00), actual_scale=_scalar(0),
                               checksum_scale=_scalar(0))
    else:
        out_data = OutFlexData()

    flex_ctx = FlexCtx(lhs_data=lhs_data, rhs_data=rhs_data, out_data=out_data)
    acc_scale = 2.0 if act_use_flexpoint or weight_use_flexpoint else 1.0
    return PrecisionConfig(flex_ctx=flex_ctx, acc_scale=acc_scale, out_dtype=out_dtype)
def apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_config):
    """Derive detached reference copies of the kernel inputs.

    ``x_tri`` and ``w_tri`` are multiplied (in float32) by the lhs/rhs flex
    scales from ``precision_config.flex_ctx`` when those scales are set.

    - A single-element scale is broadcast.
    - A longer scale is applied per leading slice of a 3-D tensor.
    - Without a scale the tensor is just cloned.

    Bias and gammas are always cloned (gammas stay ``None`` if absent). Every
    returned tensor is a fresh leaf with ``requires_grad=True``.
    """
    flex_ctx = precision_config.flex_ctx

    def _to_reference(t, scale):
        if scale is None:
            out = t.clone()
        elif scale.numel() == 1:
            out = t.float() * scale
        else:
            # Per-expert scale: one value per slice along the leading dimension.
            assert t.ndim == 3
            assert scale.numel() == t.shape[0]
            out = t.float() * scale[:, None, None]
        return out.detach().requires_grad_()

    def _optional(t):
        return None if t is None else _to_reference(t, None)

    x_ref = _to_reference(x_tri, flex_ctx.lhs_data.scale)
    w_ref = _to_reference(w_tri, flex_ctx.rhs_data.scale)
    bias_ref = _to_reference(bias_tri, None)
    gs0_ref = _optional(gs0_tri)
    gs1_ref = _optional(gs1_tri)
    return x_ref, w_ref, bias_ref, gs0_ref, gs1_ref
def dtype_str_to_torch(dtype_str: str) -> torch.dtype:
    """Map a dtype name such as ``"bfloat16"`` to the matching ``torch`` dtype.

    ``"float4_e2m1"`` has no torch dtype of the same name and maps to ``torch.uint8``.
    """
    if dtype_str == "float4_e2m1":
        return torch.uint8
    return getattr(torch, dtype_str)
# Scope to ensure that the opt_flags_constraints are reset after the test
@pytest.fixture
def opt_flags_scope(request):
    """Reset the opt_flags constraints after the requesting test finishes.

    Tests in this file call ``opt_flags.update_opt_flags_constraints``. This
    fixture undoes that in its teardown via ``reset_opt_flags_constraints``.
    """
    try:
        yield
    finally:
        opt_flags.reset_opt_flags_constraints()
# ---------------
# unit tests
# ---------------
@dataclass
class Case:
    """One parameter set for ``test_op``.

    Field order matters: cases are constructed positionally, and
    ``fields(Case)`` supplies the parametrize argument names in this order.
    """
    m: int  # token rows before routing (x gets m * n_expts_act rows when there is no gather index)
    n: int  # output features; weights are allocated as (..., k, n)
    k: int  # reduction dimension
    mode: str  # "batched", "plain" or "ragged" (validated in init_compute_data)
    act_dtype_str: str  # activation dtype name; an "mx" prefix selects MX (mxfp) quantized activations
    weight_dtype_str: str  # weight dtype name; an "mx" prefix selects MX (mxfp) quantized weights
    n_expts_tot: int = 1  # total number of experts
    n_expts_act: int = 1  # experts selected per token (passed to routing)
    n_expt_shards: int = 1  # passed to routing as simulated_ep
    split_k: int = 1  # split-K factor, forced through the opt_flags constraints
    hbm_swizzling: bool = False  # use the swizzled mxfp4 weight / weight-scale layouts
    epilogue_subtile: Union[int, None] = None  # opt_flags constraint; None presumably leaves the choice to the heuristics
@pytest.mark.parametrize(
    ", ".join(f.name for f in fields(Case)),
    [
        tuple(getattr(case, f.name) for f in fields(Case)) for case in [
            # Non-mx types:
            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4),
            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=2),
            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=4),
            Case(16, 256, 256, "ragged", "float16", "float16", 4, 1, n_expt_shards=2),
            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3),
            Case(300, 400, 400, "batched", "float8_e5m2", "float8_e5m2", 5, 1),
            Case(16, 256, 256, "batched", "float16", "float16", 5, 1),
            Case(16, 256, 256, "ragged", "float16", "float16", 3, 1),
            Case(256, 256, 256, "ragged", "float16", "float16", 4, 1),
            Case(256, 256, 256, "ragged", "float16", "float16", 4, 1, split_k=3),
            Case(300, 400, 400, "batched", "float16", "float16", 5, 1),
            Case(300, 400, 400, "ragged", "float16", "float16"),
            Case(300, 400, 400, "ragged", "float8_e5m2", "float8_e5m2"),
            Case(1000, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 3, 1),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=1),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=2),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=4),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, n_expt_shards=2),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 1, n_expt_shards=2),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, split_k=2),
            Case(1000, 400, 400, "ragged", "float16", "float16", 3, 1),
            Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2),
            Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9),
            # mx types:
            Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1),
            Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
            Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1),
            Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
            Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
            Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
            Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9),
            Case(1000, 512, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
            Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4),
            Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
            Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4),
            Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
            Case(1, 2880, 2880, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),
            Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
            Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
            Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1),
            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9),
            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2),
            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
            Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4),
            Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
            Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4),
            Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
            Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4),
            Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
            Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
            Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=False),
            Case(16, 256, 256, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
            Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
            Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
            Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
            Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
            Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4),
            Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
            Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4),
            Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
            # AMD
            Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"),
            Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1),
            Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2),
            Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, n_expt_shards=2),
            Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, split_k=2),
            Case(300, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn"),
            Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1),
            Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2),
            Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2, n_expt_shards=2),
        ]
    ],
)
@pytest.mark.parametrize("block_m", [16, 128])
@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [
    (False, False, False),
    (True, False, False),
    (False, True, False),
    (True, True, False),
    (True, True, True),
])
@pytest.mark.parametrize("has_y_gammas", [False, True])
@pytest.mark.parametrize("is_persistent", [False, True])
def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot,
            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
            device, opt_flags_scope, fresh_knobs):
    """Compare ``matmul_ogs`` against ``matmul_ogs_torch`` for one parameter combination.

    Unsupported hardware/dtype combinations are skipped up front. Inapplicable
    opt_flags constraints are skipped at launch time. For ragged, non-MX-weight
    cases the test also checks the byte count reported through the Triton
    launch-metadata hook.
    """
    # --- hardware / dtype support gates ---
    # TODO: remove when Triton FP8 supports proper RTNE
    if is_cuda():
        if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9:
            pytest.skip("Float8 not tested on A100")
        if "float16" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] >= 10:
            pytest.skip("float16 x mx not supported with cuda capability >= 10")
        if weight_dtype_str.startswith("mx"):
            if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10:
                pytest.skip("float8 x mx not supported with cuda capability < 10")
        if act_dtype_str == "mxfloat8_e4m3fn":
            if is_persistent:
                pytest.skip("mx x mx not supported with persistent kernel")
        if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9:
            pytest.skip("Not enough memory on A100")
    elif is_hip():
        if "float8" in act_dtype_str and "mx" in weight_dtype_str and not is_hip_cdna4():
            pytest.skip("float8 x mx only supported on CDNA4")
        if "float8" in act_dtype_str and "mxfloat8" in weight_dtype_str:
            pytest.skip("NYI: float8 x mxfloat8 not tested on AMD GPU")
        if act_dtype_str.startswith("mx") and weight_dtype_str.startswith("mx"):
            pytest.skip("NYI: mx x mx not tested on AMD GPU")
        if is_persistent:
            pytest.skip("NYI: Persistent kernel not supported on AMD GPU")
        if split_k > 1:
            pytest.skip("splitK hasn't been fully tested on AMD GPU.")

    if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
        pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform")

    if fused_scatter and split_k > 1:
        pytest.skip("fused scatter scratchpad not supported with split_k")
    if hbm_swizzling:
        if is_hip():
            pytest.skip("NYI. HBM swizzling just implemented for CUDA.")
        if torch.cuda.get_device_capability()[0] < 9:
            pytest.skip("NYI. Ampere swizzling.")
        if torch.cuda.get_device_capability()[0] < 10:
            if "mxfloat4" not in weight_dtype_str:
                pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
            if k % 64 != 0 or n % 64 != 0:
                # Automatic padding not implemented for Hopper swizzle
                pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")

    # launch metadata for batched / mx types may not work yet.
    test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str)

    torch.manual_seed(0)

    # --- kernel configuration ---
    block_k = None
    if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
        # Override block_k for testing correctness. The default is temporarily 128 for
        # performance reasons which doesn't work with persistent matmul.
        # TODO: revisit when Triton is better for H100 + MXFP4
        block_k = 256

    constraints = {
        "block_m": block_m,
        "block_k": block_k,
        "split_k": split_k,
        "fused_scatter": fused_scatter,
        "is_persistent": is_persistent,
        "epilogue_subtile": epilogue_subtile,
    }
    opt_flags.update_opt_flags_constraints(constraints)

    # --- dtypes and precision ---
    weight_mxfp = weight_dtype_str.startswith("mx")
    if weight_mxfp:
        weight_dtype_str = weight_dtype_str[2:]
    act_mxfp8 = act_dtype_str.startswith("mx")
    act_is_float8 = act_dtype_str.startswith("float8")
    if act_mxfp8:
        act_dtype_str = act_dtype_str[2:]
        dequantize_mxfp8_spec = FnSpecs(
            FnName.DEQUANTIZE_MXFP8.name, dequantize_mxfp8_fn, (), ()
        )

    test_bwd = False
    weight_dtype = dtype_str_to_torch(weight_dtype_str)
    act_dtype = dtype_str_to_torch(act_dtype_str)
    precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, weight_mxfp, n_expts_tot // n_expt_shards, device=device)
    # precision_opt.x_pad_trans_requires_flexpoint = False

    # --- inputs ---
    if mode == "ragged":
        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
                                                   device=device)
    else:
        rdata = gindx = sindx = None
    x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act,
                                                                 n_expt_shards, mode, torch.bfloat16 if act_mxfp8 else act_dtype,  #
                                                                 torch.bfloat16 if weight_mxfp else weight_dtype,
                                                                 has_y_gammas, requires_grad=test_bwd, device=device)
    x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)

    if w_tri.shape[0] == 1:
        # Test the case when weight has dim 2, i.e., shape (K, N).
        w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
        w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)

    if weight_mxfp:
        mx_axis = w_tri.ndim - 2
        # compute layouts
        w_layout, w_layout_opts = layout.StridedLayout, dict()
        w_scale_layout, w_scale_layout_opts = layout.StridedLayout, dict()
        if hbm_swizzling and "float4" in weight_dtype_str:
            w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=mx_axis)
            w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
                mx_axis=mx_axis, num_warps=8)
        # downcast to mxfp
        w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
        w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
        w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
        w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
        w_scale_tri = wrap_torch_tensor(w_scale_tri)
        # convert layouts
        w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
        w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
        precision_opt.weight_scale = w_scale_tri

    epilogue = None
    if act_mxfp8:
        x_tri, x_mx_scales_tri = downcast_to_mxfp(x_tri, act_dtype, axis=-1)
        x_ref = upcast_from_mxfp(x_tri, x_mx_scales_tri, torch.bfloat16, axis=-1)
        # Size the MX scale buffer for the kernel output.
        is_input_batched = x_tri.ndim == 3
        y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
        n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
        y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
        if sindx is None or mode == "batched":
            if not is_input_batched:
                y_shape = (y_shape[1], y_shape[2])
        else:
            y_shape = (n_rows // rdata.n_expts_act, y_shape[-1])
        y_mx_scale_shape = y_shape[:-1] + (triton.cdiv(y_shape[-1], MXFP_BLOCK_SIZE),)
        y_mx_scale = torch.empty(y_mx_scale_shape, dtype=torch.uint8, device=x_tri.device)
        precision_opt = replace(precision_opt, act_scale=x_mx_scales_tri, out_scale=y_mx_scale)
        epilogue = Epilogue(dequantize_mxfp8_spec, tuple(), tuple(), effective_itemsize=6.0)

    # --- launch-metadata byte accounting ---
    if test_launch_metadata:

        def _clobber(t, used_mask):
            # Fill the unread part of the tensor with garbage, to be sure that
            # we don't actually read from the part.
            if len(used_mask) == 1:
                return
            elif t.element_size() == 1:
                t.view(torch.int8)[~used_mask] = 127
            else:
                t[~used_mask] = torch.inf

        if rdata is not None:
            n_tokens = rdata.expt_hist.sum().item()
            used_expts = (rdata.expt_hist > 0)
            _clobber(w_tri, used_expts)
            n_w_bytes = used_expts.sum().item() * n * k * w_tri.element_size()
        else:
            n_tokens = m
            n_w_bytes = w_tri.numel() * w_tri.element_size()

        if gindx is not None:
            used_x_rows = (gindx.dst_indx.view(-1, n_expts_act) != -1).any(dim=1)
            _clobber(x_tri, used_x_rows)
            n_x_bytes = used_x_rows.sum().item() * k * x_tri.element_size()
        elif rdata is not None:
            n_x_bytes = n_tokens * k * x_tri.element_size()
        else:
            n_x_bytes = x_tri.numel() * x_tri.element_size()

        nbytes = None

        def _hook(launch_metadata):
            nonlocal nbytes
            metadata = launch_metadata.get()
            if "matmul_ogs" in metadata["name"]:
                nbytes = metadata["bytes"]

        triton.knobs.runtime.launch_enter_hook = _hook

    if mode == "batched":
        rdata, gindx, sindx = None, None, None
    flex = precision_opt.flex_ctx

    # triton
    try:
        tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref, epilogue=epilogue)
    except (opt_flags.InapplicableConstraint, NotImplementedError):
        pytest.skip("inapplicable opt_flags constraint")
    finally:
        if test_launch_metadata:
            # Uninstall the hook even when the launch raises or the test is skipped,
            # so it cannot observe kernel launches from later tests.
            triton.knobs.runtime.launch_enter_hook = None
    # If split_k > 1, then the intermediate tensor is fp32.
    sep_gather = mode == "ragged" and do_gather and n_expts_act > 1 and split_k == 1
    sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1
    # Output flex scale used to emulate the intermediate rounding in the reference.
    y_scale = flex.out_data.expected_scale if act_is_float8 else 1

    if test_launch_metadata:
        if gindx is not None:
            n_y_bytes = (gindx.src_indx != -1).sum().item() * n * tri_y.element_size()
        elif rdata is not None:
            n_y_bytes = n_tokens * n * tri_y.element_size()
        else:
            n_y_bytes = tri_y.numel() * tri_y.element_size()
        assert nbytes == n_x_bytes + n_y_bytes + n_w_bytes, (
            f"launch metadata reported {nbytes} bytes; expected x={n_x_bytes} + y={n_y_bytes} + w={n_w_bytes}")

    # --- reference ---
    def round_x(x, idx):
        return x.to(act_dtype).to(torch.float32) if sep_gather else x

    def round_y(y):
        return (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y

    def scale(val, scal):
        return val if scal is None else val / scal

    ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref,  #
                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref)
    if n_expt_shards > 1:
        if do_scatter:
            indx = sindx.dst_indx[sindx.dst_indx != -1]
            ref_y = ref_y[indx // n_expts_act, :]
            if act_is_float8:
                tri_y = tri_y.view(torch.int8)
            tri_y = tri_y[indx // n_expts_act, :]
            if act_is_float8:
                tri_y = tri_y.view(act_dtype)
        else:
            n_rows = rdata.expt_hist.sum()
            assert n_rows > 0
            ref_y = ref_y[:n_rows]
            tri_y = tri_y[:n_rows]
    if act_mxfp8:
        tri_y = upcast_from_mxfp(tri_y, precision_opt.out_scale, dtype=torch.bfloat16, axis=-1).to(ref_y.dtype)
        ref_y_quant, ref_y_scale = downcast_to_mxfp_torch(ref_y, act_dtype, axis=-1)
        ref_y = upcast_from_mxfp_torch(ref_y_quant, ref_y_scale, target_dtype=ref_y.dtype, axis=-1)
        maxtol = 4e-1
        rmstol = 4e-2
    else:
        maxtol = None
        rmstol = None
    assert_close(scale(ref_y, flex.out_data.expected_scale), tri_y, maxtol=maxtol, rmstol=rmstol)

    if act_is_float8:
        tri_y_scale = flex.out_data.actual_scale.clone()
        ref_y_scale = compute_actual_scale(ref_y, tri_y.dtype)
        assert (ref_y_scale -
                tri_y_scale).abs() < 1e-10, f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}"
def test_set_idle_sms():
    """Check that ``matmul_ogs_set_idle_sms`` is reflected in the flags chosen by ``make_opt_flags``.

    The idle-SM setting is global, so it is undone in ``finally`` to keep it
    from leaking into later tests.
    """
    if not is_cuda():
        pytest.skip("Only supported on CUDA")
    num_idle_sms = 24
    matmul_ogs_set_idle_sms(num_idle_sms)
    try:
        flags = opt_flags.make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(),
                                         1024, 1024, 1024, None, True, False, 1)
        assert flags.idle_sms == num_idle_sms
    finally:
        # NOTE(review): assumes matmul_ogs_set_idle_sms records the count as an opt_flags
        # constraint, so resetting the constraints (as the opt_flags_scope fixture does)
        # undoes it — confirm against triton_kernels.matmul_ogs.
        opt_flags.reset_opt_flags_constraints()
@pytest.mark.parametrize("m, n, k, mode", [
    (1200, 704, 608, "ragged"),
    (800, 800, 400, "batched"),
])
@pytest.mark.parametrize("split_k", [1, 2])
@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [
    (False, False, False),
    (True, False, False),
    (False, True, False),
    (True, True, False),
    (True, True, True),
])
@pytest.mark.parametrize("is_persistent, epilogue_subtile", [
    (False, None),
    (True, 1),
    (True, 4),
])
@pytest.mark.parametrize("swiglu_alpha, swiglu_limit", [
    (1.1, 1.4),
    (1.0, 1.2),
    (0.7, 1.0),
])
def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter, is_persistent, epilogue_subtile,
                   swiglu_alpha, swiglu_limit, device, opt_flags_scope):
    """Check that SwiGLU fused into ``matmul_ogs`` matches ``swiglu`` applied to the unfused output.

    Uses float16 activations and weights with a single expert.
    """
    if fused_scatter and split_k > 1:
        pytest.skip("fused scatter scratchpad not supported with split_k")
    torch.manual_seed(0)
    opt_flags.update_opt_flags_constraints({
        "is_persistent": is_persistent,
        "epilogue_subtile": epilogue_subtile,
        "fused_scatter": fused_scatter,
        "split_k": split_k,
    })

    n_expts_tot = n_expts_act = n_expt_shards = 1
    weight_dtype = act_dtype = torch.float16

    # rdata/gindx/sindx stay None unless the mode is "ragged".
    rdata = gindx = sindx = None
    if mode == "ragged":
        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
                                                   device=device)
    act_is_float8 = str(act_dtype).startswith("torch.float8")
    precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, False, n_expts_tot // n_expt_shards,
                                   device=device)
    x, w, bias, _, _ = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode,
                                         act_dtype, weight_dtype, False, requires_grad=False, device=device)

    try:
        unfused = matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt)
        expected = swiglu(unfused, swiglu_alpha, precision_config=SwiGLUPrecisionConfig(swiglu_limit))
        fused_activation = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
                                           (swiglu_alpha, swiglu_limit), 2)
        actual = matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt, fused_activation=fused_activation)
    except opt_flags.InapplicableConstraint:
        pytest.skip("inapplicable constraint")
    assert_close(expected, actual)