import pytest | |
import torch | |
from triton_kernels.tensor_details.layout import BlackwellMXScaleLayout | |
# ------------------------------------------------------------ | |
# Torch tests | |
# ------------------------------------------------------------ | |
@pytest.mark.parametrize(
    "shape",
    [
        # Mix of large, small and irregular (non power-of-two) sizes so the
        # swizzle/unswizzle roundtrip is exercised beyond the trivially aligned
        # case. NOTE(review): confirm these cover the padding paths of
        # BlackwellMXScaleLayout.
        (3, 4096, 1024),
        (10, 254, 60),
        (1, 320, 160),
        (2, 16, 512),
        (3, 2, 36),
    ],
)
def test_mxfp4_scale_roundtrip(shape):
    """Check that BlackwellMXScaleLayout swizzling is losslessly invertible.

    A random uint8 tensor of the given ``shape`` (values in [0, 256)) is
    created on CUDA. It is passed through ``swizzle_data`` and then
    ``unswizzle_data``, and the result must equal the input element-wise.

    Args:
        shape: Tensor shape to test, supplied by ``pytest.mark.parametrize``.
    """
    x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda")
    layout = BlackwellMXScaleLayout(x.shape)
    res = layout.unswizzle_data(layout.swizzle_data(x))
    assert (res == x).all()