|
|
|
|
|
from dataclasses import dataclass, fields, replace |
|
import pytest |
|
import torch |
|
from typing import Union |
|
import triton |
|
|
|
from triton_kernels.routing import routing |
|
|
|
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags |
|
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs, FnName, Epilogue |
|
from triton_kernels.matmul_ogs import matmul_ogs_set_idle_sms, matmul_ogs, matmul_ogs_torch |
|
from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig |
|
from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4 |
|
from triton_kernels.tensor_details import layout |
|
|
|
from triton_kernels.numerics import InFlexData, OutFlexData |
|
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp, dequantize_mxfp8_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE |
|
|
|
from triton_kernels.testing import assert_close, compute_actual_scale |
|
|
|
from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
def alloc_rand(shape, device, dtype, requires_grad=True): |
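    # fp8 (1-byte) dtypes cannot be sampled with randn directly; draw small powers
    # of two in fp16 and downcast so the values stay exactly representable.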
|
if dtype.itemsize == 1: |
|
tmp = 2**-(torch.randint(4, 8, shape, device=device, dtype=torch.float16)) |
|
return tmp.to(dtype).requires_grad_(requires_grad) |
|
return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad) |
|
|
|
|
|
def alloc_rand_like(x): |
|
return alloc_rand(x.shape, x.device, x.dtype, x.requires_grad) |
|
|
|
|
|
def mask_indx(idx, n_expts_act): |
|
idx.src_indx[idx.dst_indx[-n_expts_act:]] = -1 |
|
idx.dst_indx[-n_expts_act:] = -1 |
|
return idx |
|
|
|
|
|
def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter, device="cuda"): |
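    # Route random logits through the routing kernel; the gather/scatter indices
    # are dropped when the test does not exercise them.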
|
logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=device, requires_grad=True) |
|
routing_data, gather_idx, scatter_idx = routing(logits, n_expts_act, simulated_ep=n_expt_shards) |
|
routing_data.gate_scal = None |
|
gather_idx = gather_idx if do_gather else None |
|
scatter_idx = scatter_idx if do_scatter else None |
|
return m, routing_data, gather_idx, scatter_idx |
|
|
|
|
|
def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode, act_dtype, weight_dtype, |
|
has_y_gammas, requires_grad=True, device="cuda"): |
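    # Allocate activations, weights, bias, and the optional per-row gammas for the
    # requested mode ('plain', 'ragged', or 'batched').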
|
torch.manual_seed(0) |
|
    assert mode in {"batched", "plain", "ragged"}
|
in_m = m * (n_expts_act if gindx is None else 1) |
|
shape_x = (n_expts_tot, in_m, k) if mode == 'batched' else (in_m, k) |
|
shape_batch = tuple() if mode == "plain" else (n_expts_tot // n_expt_shards, ) |
|
x = alloc_rand(shape_x, device=device, dtype=act_dtype, requires_grad=requires_grad) |
|
w = alloc_rand(shape_batch + (k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad) |
|
bias = alloc_rand(shape_batch + (n, ), device=device, dtype=torch.float32, requires_grad=requires_grad) |
|
gs0 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad) |
|
gs1 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad) |
|
gs0 = gs0.detach().requires_grad_(requires_grad) |
|
gs1 = gs1.detach().requires_grad_(requires_grad) |
|
if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2): |
|
gs0 = None |
|
gs1 = None |
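    # On pre-Blackwell GPUs, make the fp8 weight K-contiguous (column-major) via a
    # transpose round-trip.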
|
if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10: |
|
w = w.transpose(-1, -2).contiguous().transpose(-1, -2) |
|
return x, w, bias, gs0, gs1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, n_expts_tot=1, device="cuda"): |
|
weight_use_flexpoint = weight_dtype.itemsize == 1 and not weight_mxfp |
|
|
|
make_tensor = lambda val0, val1: torch.tensor([val0, val1] * (n_expts_tot // 2) + |
|
([val0] |
|
if n_expts_tot % 2 else []), dtype=torch.float32, device=device) |
|
make_scalar = lambda val: torch.tensor([val], dtype=torch.float32, device=device) |
|
in_flex_data = lambda scale, use_flex: InFlexData(dtype=out_dtype, scale=make_scalar(scale) |
|
) if use_flex else InFlexData() |
|
in_flex_edata = lambda scale0, scale1, use_flex: InFlexData(dtype=weight_dtype, scale=make_tensor(scale0, scale1) |
|
) if use_flex else InFlexData() |
|
out_flex_data = lambda scale, use_flex: OutFlexData(dtype=out_dtype, expected_scale=make_scalar( |
|
scale), actual_scale=make_scalar(0), checksum_scale=make_scalar(0)) if use_flex else OutFlexData() |
|
flex_ctx = FlexCtx( |
|
lhs_data=in_flex_data(1.25, act_use_flexpoint), |
|
rhs_data=in_flex_edata(1.50, 1.25, weight_use_flexpoint), |
|
out_data=out_flex_data(4.00, act_use_flexpoint), |
|
) |
|
return PrecisionConfig(flex_ctx=flex_ctx, acc_scale=2.0 if act_use_flexpoint or weight_use_flexpoint else 1.0, |
|
out_dtype=out_dtype) |
|
|
|
|
|
def apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_config): |
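    # Produce float32 reference operands with the flexpoint scales applied, so the
    # torch reference sees the same effective values as the kernel.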
|
flex_ctx = precision_config.flex_ctx |
|
|
|
def apply(x, scale): |
|
if scale is None: |
|
x = x.clone() |
|
elif scale.numel() == 1: |
|
x = x.float() * scale |
|
else: |
|
assert x.ndim == 3 |
|
assert scale.numel() == x.shape[0] |
|
x = x.float() * scale[:, None, None] |
|
return x.detach().requires_grad_() |
|
|
|
return ( |
|
apply(x_tri, flex_ctx.lhs_data.scale), |
|
apply(w_tri, flex_ctx.rhs_data.scale), |
|
apply(bias_tri, None), |
|
None if gs0_tri is None else apply(gs0_tri, None), |
|
None if gs1_tri is None else apply(gs1_tri, None), |
|
) |
|
|
|
|
|
def dtype_str_to_torch(dtype_str: str) -> torch.dtype: |
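    # FP4 (float4_e2m1) values are kept packed in uint8 tensors; everything else
    # maps to the torch dtype of the same name.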
|
return torch.uint8 if dtype_str == "float4_e2m1" else getattr(torch, dtype_str) |
|
|
|
|
|
|
|
@pytest.fixture |
|
def opt_flags_scope(request): |
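    # Undo any global opt-flags constraints a test installed so cases stay independent.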
|
yield |
|
opt_flags.reset_opt_flags_constraints() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
class Case: |
|
m: int |
|
n: int |
|
k: int |
|
mode: str |
|
act_dtype_str: str |
|
weight_dtype_str: str |
|
n_expts_tot: int = 1 |
|
n_expts_act: int = 1 |
|
n_expt_shards: int = 1 |
|
split_k: int = 1 |
|
hbm_swizzling: bool = False |
|
epilogue_subtile: Union[int, None] = None |
|
|
|
|
|
@pytest.mark.parametrize( |
|
", ".join(f.name for f in fields(Case)), |
|
[ |
|
tuple(getattr(case, f.name) for f in fields(Case)) for case in [ |
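            # fp16 / fp8 weights (non-mx formats).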
|
|
|
Case(16, 256, 256, "ragged", "float16", "float16", 128, 4), |
|
Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=2), |
|
Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=4), |
|
Case(16, 256, 256, "ragged", "float16", "float16", 4, 1, n_expt_shards=2), |
|
Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3), |
|
Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3), |
|
Case(300, 400, 400, "batched", "float8_e5m2", "float8_e5m2", 5, 1), |
|
Case(16, 256, 256, "batched", "float16", "float16", 5, 1), |
|
Case(16, 256, 256, "ragged", "float16", "float16", 3, 1), |
|
Case(256, 256, 256, "ragged", "float16", "float16", 4, 1), |
|
Case(256, 256, 256, "ragged", "float16", "float16", 4, 1, split_k=3), |
|
Case(300, 400, 400, "batched", "float16", "float16", 5, 1), |
|
Case(300, 400, 400, "ragged", "float16", "float16"), |
|
Case(300, 400, 400, "ragged", "float8_e5m2", "float8_e5m2"), |
|
Case(1000, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 3, 1), |
|
Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=1), |
|
Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=2), |
|
Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=4), |
|
Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2), |
|
Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, n_expt_shards=2), |
|
Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 1, n_expt_shards=2), |
|
Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, split_k=2), |
|
Case(1000, 400, 400, "ragged", "float16", "float16", 3, 1), |
|
Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2), |
|
Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9), |
|
|
|
Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1), |
|
Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True), |
|
Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1), |
|
Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True), |
|
Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2), |
|
Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), |
|
Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9), |
|
Case(1000, 512, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), |
|
Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4), |
|
Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), |
|
Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4), |
|
Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2), |
|
Case(1, 2880, 2880, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4), |
|
Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), |
|
Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), |
|
Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), |
|
Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1), |
|
Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9), |
|
Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), |
|
Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2), |
|
Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), |
|
Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4), |
|
Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), |
|
Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4), |
|
Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True), |
|
Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4), |
|
Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True), |
|
Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), |
|
Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=False), |
|
Case(16, 256, 256, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), |
|
Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), |
|
Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1), |
|
Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9), |
|
Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), |
|
Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2), |
|
Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), |
|
Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4), |
|
Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), |
|
Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4), |
|
Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True), |
|
Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4), |
|
Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True), |
|
|
|
Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"), |
|
Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1), |
|
Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2), |
|
Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, n_expt_shards=2), |
|
Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, split_k=2), |
|
Case(300, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn"), |
|
Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1), |
|
Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2), |
|
Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2, n_expt_shards=2), |
|
] |
|
], |
|
) |
|
@pytest.mark.parametrize("block_m", [16, 128]) |
|
@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [ |
|
(False, False, False), |
|
(True, False, False), |
|
(False, True, False), |
|
(True, True, False), |
|
(True, True, True), |
|
]) |
|
@pytest.mark.parametrize("has_y_gammas", [False, True]) |
|
@pytest.mark.parametrize("is_persistent", [False, True]) |
|
def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot, |
|
n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile, |
|
device, opt_flags_scope, fresh_knobs): |
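    # End-to-end correctness test: matmul_ogs against the matmul_ogs_torch reference
    # across dtypes, routing modes, and kernel configurations.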
|
|
|
if is_cuda(): |
|
if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9: |
|
pytest.skip("Float8 not tested on A100") |
|
if "float16" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] >= 10: |
|
pytest.skip("float16 x mx not supported with cuda capability >= 10") |
|
if weight_dtype_str.startswith("mx"): |
|
if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10: |
|
pytest.skip("float8 x mx not supported with cuda capability < 10") |
|
if act_dtype_str == "mxfloat8_e4m3fn": |
|
if is_persistent: |
|
pytest.skip("mx x mx not supported with persistent kernel") |
|
if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9: |
|
pytest.skip("Not enough memory on A100") |
|
|
|
elif is_hip(): |
|
if "float8" in act_dtype_str and "mx" in weight_dtype_str and not is_hip_cdna4(): |
|
pytest.skip("float8 x mx only supported on CDNA4") |
|
if "float8" in act_dtype_str and "mxfloat8" in weight_dtype_str: |
|
pytest.skip("NYI: float8 x mxfloat8 not tested on AMD GPU") |
|
if act_dtype_str.startswith("mx") and weight_dtype_str.startswith("mx"): |
|
pytest.skip("NYI: mx x mx not tested on AMD GPU") |
|
if is_persistent: |
|
pytest.skip("NYI: Persistent kernel not supported on AMD GPU") |
|
if split_k > 1: |
|
pytest.skip("splitK hasn't been fully tested on AMD GPU.") |
|
|
|
if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3(): |
|
pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform") |
|
|
|
if fused_scatter and split_k > 1: |
|
pytest.skip("fused scatter scratchpad not supported with split_k") |
|
if hbm_swizzling: |
|
if is_hip(): |
|
pytest.skip("NYI. HBM swizzling just implemented for CUDA.") |
|
if torch.cuda.get_device_capability()[0] < 9: |
|
pytest.skip("NYI. Ampere swizzling.") |
|
if torch.cuda.get_device_capability()[0] < 10: |
|
if "mxfloat4" not in weight_dtype_str: |
|
pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.") |
|
if k % 64 != 0 or n % 64 != 0: |
|
|
|
pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).") |
|
|
|
|
|
test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str) |
|
|
|
torch.manual_seed(0) |
|
|
|
block_k = None |
|
if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10: |
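        # Pin block_k so the persistent mx-weight path on pre-Blackwell GPUs picks a
        # supported configuration.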
|
|
|
|
|
|
|
block_k = 256 |
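    # Force the configuration under test rather than letting the heuristics choose;
    # unsupported combinations surface as InapplicableConstraint below.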
|
|
|
constraints = { |
|
"block_m": block_m, |
|
"block_k": block_k, |
|
"split_k": split_k, |
|
"fused_scatter": fused_scatter, |
|
"is_persistent": is_persistent, |
|
"epilogue_subtile": epilogue_subtile, |
|
} |
|
opt_flags.update_opt_flags_constraints(constraints) |
|
|
|
weight_mxfp = weight_dtype_str.startswith("mx") |
|
if weight_mxfp: |
|
weight_dtype_str = weight_dtype_str[2:] |
|
act_mxfp8 = act_dtype_str.startswith("mx") |
|
act_is_float8 = act_dtype_str.startswith("float8") |
|
if act_mxfp8: |
|
act_dtype_str = act_dtype_str[2:] |
|
dequantize_mxfp8_spec = FnSpecs( |
|
FnName.DEQUANTIZE_MXFP8.name, dequantize_mxfp8_fn, (), () |
|
) |
|
|
|
test_bwd = False |
|
weight_dtype = dtype_str_to_torch(weight_dtype_str) |
|
act_dtype = dtype_str_to_torch(act_dtype_str) |
|
precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, weight_mxfp, n_expts_tot // n_expt_shards, device=device) |
|
|
|
if mode == "ragged": |
|
m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter, |
|
device=device) |
|
else: |
|
rdata = gindx = sindx = None |
|
x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, |
|
n_expt_shards, mode, torch.bfloat16 if act_mxfp8 else act_dtype, |
|
torch.bfloat16 if weight_mxfp else weight_dtype, |
|
has_y_gammas, requires_grad=test_bwd, device=device) |
|
x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt) |
|
|
|
if w_tri.shape[0] == 1: |
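        # Single expert: also exercise the 2D (K, N) weight path.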
|
|
|
w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd) |
|
w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd) |
|
|
|
if weight_mxfp: |
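        # mx scales are taken along the K axis of the weight.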
|
mx_axis = w_tri.ndim - 2 |
|
|
|
w_layout, w_layout_opts = layout.StridedLayout, dict() |
|
w_scale_layout, w_scale_layout_opts = layout.StridedLayout, dict() |
|
if hbm_swizzling and "float4" in weight_dtype_str: |
|
w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=mx_axis) |
|
w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout( |
|
mx_axis=mx_axis, num_warps=8) |
|
|
|
w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis) |
|
w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis) |
|
w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype |
|
w_tri = wrap_torch_tensor(w_tri, w_tri_dtype) |
|
w_scale_tri = wrap_torch_tensor(w_scale_tri) |
|
|
|
w_tri = convert_layout(w_tri, w_layout, **w_layout_opts) |
|
w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts) |
|
precision_opt.weight_scale = w_scale_tri |
|
epilogue = None |
|
if act_mxfp8: |
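        # Quantize the activations to mxfp8, keep a bf16 reference, and register a
        # dequantizing epilogue plus an output-scale buffer for the mxfp8 result.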
|
x_tri, x_mx_scales_tri = downcast_to_mxfp(x_tri, act_dtype, axis=-1) |
|
x_ref = upcast_from_mxfp(x_tri, x_mx_scales_tri, torch.bfloat16, axis=-1) |
|
is_input_batched = x_tri.ndim == 3 |
|
y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape |
|
n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0] |
|
y_shape = (y_shape[0], n_rows, w_tri.shape[-1]) |
|
if sindx is None or mode == "batched": |
|
if not is_input_batched: |
|
y_shape = (y_shape[1], y_shape[2]) |
|
else: |
|
y_shape = (n_rows // rdata.n_expts_act, y_shape[-1]) |
|
y_scale_shape = y_shape[:-1] + (triton.cdiv(y_shape[-1], MXFP_BLOCK_SIZE),) |
|
y_scale = torch.empty(y_scale_shape, dtype=torch.uint8, device=x_tri.device) |
|
precision_opt = replace(precision_opt, act_scale=x_mx_scales_tri, out_scale=y_scale) |
|
epilogue = Epilogue(dequantize_mxfp8_spec, tuple(), tuple(), effective_itemsize=6.0) |
|
else: |
|
y_scale = None |
|
|
|
if test_launch_metadata: |
|
|
|
def _clobber(t, used_mask): |
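            # Overwrite rows/experts the kernel must not read with garbage so any
            # stray read corrupts the output and fails the accuracy check.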
|
|
|
|
|
if len(used_mask) == 1: |
|
return |
|
elif t.element_size() == 1: |
|
t.view(torch.int8)[~used_mask] = 127 |
|
else: |
|
t[~used_mask] = torch.inf |
|
|
|
if rdata is not None: |
|
n_tokens = rdata.expt_hist.sum().item() |
|
used_expts = (rdata.expt_hist > 0) |
|
_clobber(w_tri, used_expts) |
|
n_w_bytes = used_expts.sum().item() * n * k * w_tri.element_size() |
|
else: |
|
n_tokens = m |
|
n_w_bytes = w_tri.numel() * w_tri.element_size() |
|
|
|
if gindx is not None: |
|
used_x_rows = (gindx.dst_indx.view(-1, n_expts_act) != -1).any(dim=1) |
|
_clobber(x_tri, used_x_rows) |
|
n_x_bytes = used_x_rows.sum().item() * k * x_tri.element_size() |
|
elif rdata is not None: |
|
n_x_bytes = n_tokens * k * x_tri.element_size() |
|
else: |
|
n_x_bytes = x_tri.numel() * x_tri.element_size() |
|
|
|
nbytes = None |
|
|
|
def _hook(launch_metadata): |
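            # Record the byte count the matmul_ogs kernel reports in its launch
            # metadata; it is checked against the expected traffic after the call.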
|
nonlocal nbytes |
|
metadata = launch_metadata.get() |
|
if "matmul_ogs" in metadata["name"]: |
|
nbytes = metadata["bytes"] |
|
|
|
triton.knobs.runtime.launch_enter_hook = _hook |
|
|
|
if mode == "batched": |
|
rdata, gindx, sindx = None, None, None |
|
flex = precision_opt.flex_ctx |
|
|
|
|
|
try: |
|
tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref, epilogue=epilogue) |
|
except (opt_flags.InapplicableConstraint, NotImplementedError): |
|
pytest.skip("inapplicable opt_flags constraint") |
|
|
|
sep_gather = mode == "ragged" and do_gather and n_expts_act > 1 and split_k == 1 |
|
sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1 |
|
y_scale = flex.out_data.expected_scale if act_is_float8 else 1 |
|
|
|
if test_launch_metadata: |
|
if gindx is not None: |
|
n_y_bytes = (gindx.src_indx != -1).sum().item() * n * tri_y.element_size() |
|
elif rdata is not None: |
|
n_y_bytes = n_tokens * n * tri_y.element_size() |
|
else: |
|
n_y_bytes = tri_y.numel() * tri_y.element_size() |
|
assert nbytes == n_x_bytes + n_y_bytes + n_w_bytes |
|
triton.knobs.runtime.launch_enter_hook = None |
|
|
|
def round_x(x, idx): |
|
return x.to(act_dtype).to(torch.float32) if sep_gather else x |
|
|
|
round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y |
|
ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref, |
|
rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref) |
|
scale = lambda val, scal: val if scal is None else val / scal |
|
if n_expt_shards > 1: |
|
if do_scatter: |
|
indx = sindx.dst_indx[sindx.dst_indx != -1] |
|
ref_y = ref_y[indx // n_expts_act, :] |
|
if act_is_float8: |
|
tri_y = tri_y.view(torch.int8) |
|
tri_y = tri_y[indx // n_expts_act, :] |
|
if act_is_float8: |
|
tri_y = tri_y.view(act_dtype) |
|
else: |
|
n_rows = rdata.expt_hist.sum() |
|
assert n_rows > 0 |
|
ref_y = ref_y[:n_rows] |
|
tri_y = tri_y[:n_rows] |
|
if act_mxfp8: |
|
tri_y = upcast_from_mxfp(tri_y, precision_opt.out_scale, dtype=torch.bfloat16, axis=-1).to(ref_y.dtype) |
|
ref_y_quant, ref_y_scale = downcast_to_mxfp_torch(ref_y, act_dtype, axis=-1) |
|
ref_y = upcast_from_mxfp_torch(ref_y_quant, ref_y_scale, target_dtype=ref_y.dtype, axis=-1) |
|
maxtol = 4e-1 |
|
rmstol = 4e-2 |
|
else: |
|
maxtol = None |
|
rmstol = None |
|
assert_close(scale(ref_y, flex.out_data.expected_scale), tri_y, maxtol=maxtol, rmstol=rmstol) |
|
|
|
if act_is_float8: |
|
tri_y_scale = flex.out_data.actual_scale.clone() |
|
ref_y_scale = compute_actual_scale(ref_y, tri_y.dtype) |
|
assert (ref_y_scale - |
|
tri_y_scale).abs() < 1e-10, f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}" |
|
|
|
|
|
def test_set_idle_sms(): |
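    # The idle-SM count set via matmul_ogs_set_idle_sms should show up in the opt
    # flags computed afterwards.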
|
if not is_cuda(): |
|
pytest.skip("Only supported on CUDA") |
|
from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags |
|
num_idle_sms = 24 |
|
matmul_ogs_set_idle_sms(num_idle_sms) |
|
    flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(),
|
1024, 1024, 1024, None, True, False, 1) |
|
assert flags.idle_sms == num_idle_sms |
|
|
|
|
|
@pytest.mark.parametrize("m, n, k, mode", [ |
|
(1200, 704, 608, "ragged"), |
|
(800, 800, 400, "batched"), |
|
]) |
|
@pytest.mark.parametrize("split_k", [1, 2]) |
|
@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [ |
|
(False, False, False), |
|
(True, False, False), |
|
(False, True, False), |
|
(True, True, False), |
|
(True, True, True), |
|
]) |
|
@pytest.mark.parametrize("is_persistent, epilogue_subtile", [ |
|
(False, None), |
|
(True, 1), |
|
(True, 4), |
|
]) |
|
@pytest.mark.parametrize("swiglu_alpha, swiglu_limit", [ |
|
(1.1, 1.4), |
|
(1.0, 1.2), |
|
(0.7, 1.0), |
|
]) |
|
def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter, is_persistent, epilogue_subtile, |
|
swiglu_alpha, swiglu_limit, device, opt_flags_scope): |
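    # Fusing swiglu into the matmul epilogue must match applying swiglu separately
    # on the matmul output.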
|
if fused_scatter and split_k > 1: |
|
pytest.skip("fused scatter scratchpad not supported with split_k") |
|
torch.manual_seed(0) |
|
constraints = { |
|
"is_persistent": is_persistent, |
|
"epilogue_subtile": epilogue_subtile, |
|
"fused_scatter": fused_scatter, |
|
"split_k": split_k, |
|
} |
|
n_expts_tot, n_expts_act, n_expt_shards = 1, 1, 1 |
|
opt_flags.update_opt_flags_constraints(constraints) |
|
|
|
weight_dtype, act_dtype = torch.float16, torch.float16 |
|
if mode == "ragged": |
|
m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter, |
|
device=device) |
|
else: |
|
rdata = gindx = sindx = None |
|
|
|
precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, n_expts_tot // n_expt_shards, device=device) |
|
x, w, bias, _, _ = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode, |
|
act_dtype, weight_dtype, False, requires_grad=False, device=device) |
|
|
|
if mode == "batched": |
|
rdata, gindx, sindx = None, None, None |
|
|
|
try: |
|
a = swiglu(matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt), swiglu_alpha, |
|
precision_config=SwiGLUPrecisionConfig(swiglu_limit)) |
|
b = matmul_ogs( |
|
x, w, bias, rdata, gindx, sindx, precision_opt, |
|
fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), |
|
(swiglu_alpha, swiglu_limit), 2)) |
|
except opt_flags.InapplicableConstraint: |
|
pytest.skip("inapplicable constraint") |
|
assert_close(a, b) |
|
|