import pytest
from triton._internal_testing import is_cuda
from triton_kernels.tensor import wrap_torch_tensor, convert_layout, FP4
from triton_kernels.tensor_details.layout import HopperMXScaleLayout, HopperMXValueLayout
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
from triton_kernels.target_info import cuda_capability_geq
import triton.language as tl
import triton
import torch

# ------------------------------------------------------------
# Torch tests
# ------------------------------------------------------------


@pytest.mark.parametrize("shape", [(16, 32), (16, 64), (32, 32), (32, 64), (64, 128), (128, 128)])
@pytest.mark.parametrize("trans", [False, True])
@pytest.mark.parametrize("mx_axis", [0, 1])
@pytest.mark.parametrize("mma_version", [2, 3])
def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version):
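    # Swizzling packed fp4 values into the Hopper MMA layout and unswizzling
    # them again must reproduce the original tensor exactly.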
    x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
    if trans:
        x = x.mT
    if x.shape[1 - mx_axis] < 32:
        pytest.skip("Not enough elements along non-mx axis")
    layout = HopperMXValueLayout(x.shape, mx_axis, mma_version)
    res = layout.unswizzle_data(layout.swizzle_data(x))
    assert (res == x).all()


@pytest.mark.parametrize("mx_axis", [0, 1])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.parametrize("shape", [(256, 64), (256, 128), (256, 256)])
def test_mxfp4_scale_roundtrip(shape, mx_axis, num_warps):
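    # Round-trip check for the per-block scales; the unswizzled result may be
    # padded, so only the original extent is compared below.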
    x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
    layout = HopperMXScaleLayout(x.shape, mx_axis=mx_axis, num_warps=num_warps)
    res = layout.unswizzle_data(layout.swizzle_data(x))
    assert (res[:shape[0], :shape[1]] == x).all()


# ------------------------------------------------------------
# Triton tests
# ------------------------------------------------------------

# ------------------ upcast mxfp4 to bf16 --------------------


@triton.jit
def _upcast_mxfp4_to_bf16(Y, X, XScale, x_stride_m, x_stride_n, x_scale_stride_m, x_scale_stride_n, y_stride_m,
                          y_stride_n, X_BLOCK_M: tl.constexpr, X_BLOCK_N: tl.constexpr, Y_BLOCK_M: tl.constexpr,
                          Y_BLOCK_N: tl.constexpr, SCALE_BLOCK_M: tl.constexpr, SCALE_BLOCK_N: tl.constexpr,
                          mx_axis: tl.constexpr):
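    # Load a tile of swizzled fp4 values and swizzled scales, undo the scale
    # swizzle in registers, upcast to bf16, and store the dequantized tile.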
    offs_m_val = tl.arange(0, X_BLOCK_M)
    offs_n_val = tl.arange(0, X_BLOCK_N)
    offs_m_scale = tl.arange(0, SCALE_BLOCK_M)
    offs_n_scale = tl.arange(0, SCALE_BLOCK_N)
    # load values
    offs_x = offs_m_val[:, None] * x_stride_m + offs_n_val[None, :] * x_stride_n
    x = tl.load(X + offs_x)
    # load scales
    offs_x_scale = offs_m_scale[:, None] * x_scale_stride_m + offs_n_scale[None, :] * x_scale_stride_n
    x_scale = tl.load(XScale + offs_x_scale)
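    # undo the Hopper scale swizzle before dequantizing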
    x_scale = unswizzle_mxfp4_scale_hopper(x_scale, mx_axis=mx_axis, num_warps=tl.extra.cuda.num_warps())
    y = mxfp4_to_bf16_triton(x, x_scale, mx_axis=mx_axis)
    # write back output
    offs_m_val = tl.arange(0, Y_BLOCK_M)
    offs_n_val = tl.arange(0, Y_BLOCK_N)
    offs_y = offs_m_val[:, None] * y_stride_m + offs_n_val[None, :] * y_stride_n
    tl.store(Y + offs_y, y)


@pytest.mark.skipif(not is_cuda(), reason="Only supported on cuda")
@pytest.mark.skipif(not cuda_capability_geq(9), reason="Only supported for capability >= 9")
def test_upcast_mxfp4_to_bf16():
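    # End-to-end check: quantize a random bf16 tensor to mxfp4, convert values
    # and scales to the Hopper swizzled layouts, and verify the in-kernel upcast
    # matches the upcast_from_mxfp reference exactly.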
    mx_axis = 0
    num_warps = 4
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    shape = (256, 128)
    x = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
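    # quantize to mxfp4 and build the bf16 reference by dequantizing outside the kernel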
    x_fp4_val, x_fp4_scale = downcast_to_mxfp(x, torch.uint8, axis=mx_axis)
    x_bf16 = upcast_from_mxfp(x_fp4_val, x_fp4_scale, x.dtype, axis=mx_axis)
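    # wrap the packed values/scales and convert them to the swizzled Hopper layouts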
    x_fp4_val = wrap_torch_tensor(x_fp4_val, dtype=FP4)
    x_fp4_scale = wrap_torch_tensor(x_fp4_scale)
    x_fp4_val = convert_layout(x_fp4_val, HopperMXValueLayout, mx_axis=mx_axis)
    x_fp4_scale = convert_layout(x_fp4_scale, HopperMXScaleLayout, mx_axis=mx_axis, num_warps=num_warps)
    y = torch.empty_like(x_bf16)
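    # single-program launch over the swizzled storage: block sizes are the storage
    # shapes for the values and scales, and the logical shape for the output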
    _upcast_mxfp4_to_bf16[(1, )](
        y, x_fp4_val.storage.data, x_fp4_scale.storage.data,  #
        x_fp4_val.storage.data.stride(0), x_fp4_val.storage.data.stride(1),  #
        x_fp4_scale.storage.data.stride(0), x_fp4_scale.storage.data.stride(1),  #
        y.stride(0), y.stride(1),  #
        x_fp4_val.storage.data.shape[0], x_fp4_val.storage.data.shape[1],  #
        shape[0], shape[1],  #
        x_fp4_scale.storage.data.shape[0], x_fp4_scale.storage.data.shape[1],  #
        mx_axis=mx_axis, num_warps=num_warps)
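    # the in-kernel upcast must match the reference dequantization bit-exactly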
    assert (y == x_bf16).all()