# isort: off
# fmt: off
from dataclasses import dataclass, fields, replace
import pytest
import torch
from typing import Union
import triton
# routing utilities
from triton_kernels.routing import routing
# matmul utilities
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs, FnName, Epilogue
from triton_kernels.matmul_ogs import matmul_ogs_set_idle_sms, matmul_ogs, matmul_ogs_torch
from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig
from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4
from triton_kernels.tensor_details import layout
# numerics utilities
from triton_kernels.numerics import InFlexData, OutFlexData
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp, dequantize_mxfp8_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE
# testing utilities
from triton_kernels.testing import assert_close, compute_actual_scale
# target-specific utilities
from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4

# ---------------
# initialize data
# ---------------


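# Allocate a random tensor; for 1-byte (fp8) dtypes, draw small powers of two so the
# values are exactly representable after the downcast.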
def alloc_rand(shape, device, dtype, requires_grad=True):
    if dtype.itemsize == 1:
        tmp = 2**-(torch.randint(4, 8, shape, device=device, dtype=torch.float16))
        return tmp.to(dtype).requires_grad_(requires_grad)
    return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad)


def alloc_rand_like(x):
    return alloc_rand(x.shape, x.device, x.dtype, x.requires_grad)


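# Invalidate the last `n_expts_act` gather/scatter entries by setting their indices to -1.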
def mask_indx(idx, n_expts_act):
    idx.src_indx[idx.dst_indx[-n_expts_act:]] = -1
    idx.dst_indx[-n_expts_act:] = -1
    return idx


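# Route random logits to experts and return the routing metadata together with the
# gather/scatter index buffers (None when gather/scatter is disabled).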
def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter, device="cuda"):
    logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=device, requires_grad=True)
    routing_data, gather_idx, scatter_idx = routing(logits, n_expts_act, simulated_ep=n_expt_shards)
    routing_data.gate_scal = None
    gather_idx = gather_idx if do_gather else None
    scatter_idx = scatter_idx if do_scatter else None
    return m, routing_data, gather_idx, scatter_idx


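# Build the matmul operands: activations `x`, weights `w`, `bias`, and optional
# per-token gammas `gs0`/`gs1`; batched/ragged modes add a leading expert dimension
# to the weights and bias.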
def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode, act_dtype, weight_dtype,
                      has_y_gammas, requires_grad=True, device="cuda"):
    torch.manual_seed(0)
    assert mode in {'batched', "plain", 'ragged'}
    in_m = m * (n_expts_act if gindx is None else 1)
    shape_x = (n_expts_tot, in_m, k) if mode == 'batched' else (in_m, k)
    shape_batch = tuple() if mode == "plain" else (n_expts_tot // n_expt_shards, )
    x = alloc_rand(shape_x, device=device, dtype=act_dtype, requires_grad=requires_grad)
    w = alloc_rand(shape_batch + (k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad)
    bias = alloc_rand(shape_batch + (n, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
    gs0 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
    gs1 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad)
    gs0 = gs0.detach().requires_grad_(requires_grad)
    gs1 = gs1.detach().requires_grad_(requires_grad)
    if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
        gs0 = None
        gs1 = None
    if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
        w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
    return x, w, bias, gs0, gs1


# ---------------
# numerics stuff
# ---------------


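# Build a PrecisionConfig whose FlexCtx carries a scalar flexpoint scale for the
# activations/output and a per-expert scale for the weights, enabled only when the
# corresponding operand is a non-mx 8-bit type.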
def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, n_expts_tot=1, device="cuda"):
    weight_use_flexpoint = weight_dtype.itemsize == 1 and not weight_mxfp
    # flexpoint
    make_tensor = lambda val0, val1: torch.tensor([val0, val1] * (n_expts_tot // 2) +
                                                  ([val0]
                                                   if n_expts_tot % 2 else []), dtype=torch.float32, device=device)
    make_scalar = lambda val: torch.tensor([val], dtype=torch.float32, device=device)
    in_flex_data = lambda scale, use_flex: InFlexData(dtype=out_dtype, scale=make_scalar(scale)
                                                      ) if use_flex else InFlexData()
    in_flex_edata = lambda scale0, scale1, use_flex: InFlexData(dtype=weight_dtype, scale=make_tensor(scale0, scale1)
                                                                ) if use_flex else InFlexData()
    out_flex_data = lambda scale, use_flex: OutFlexData(dtype=out_dtype, expected_scale=make_scalar(
        scale), actual_scale=make_scalar(0), checksum_scale=make_scalar(0)) if use_flex else OutFlexData()
    flex_ctx = FlexCtx(
        lhs_data=in_flex_data(1.25, act_use_flexpoint),
        rhs_data=in_flex_edata(1.50, 1.25, weight_use_flexpoint),
        out_data=out_flex_data(4.00, act_use_flexpoint),
    )
    return PrecisionConfig(flex_ctx=flex_ctx, acc_scale=2.0 if act_use_flexpoint or weight_use_flexpoint else 1.0,
                           out_dtype=out_dtype)


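# Recover float32 reference operands by multiplying the quantized inputs by their
# flexpoint scales (per-expert for the weights).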
def apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_config):
    flex_ctx = precision_config.flex_ctx

    def apply(x, scale):
        if scale is None:
            x = x.clone()
        elif scale.numel() == 1:
            x = x.float() * scale
        else:
            assert x.ndim == 3
            assert scale.numel() == x.shape[0]
            x = x.float() * scale[:, None, None]
        return x.detach().requires_grad_()

    return (
        apply(x_tri, flex_ctx.lhs_data.scale),
        apply(w_tri, flex_ctx.rhs_data.scale),
        apply(bias_tri, None),
        None if gs0_tri is None else apply(gs0_tri, None),
        None if gs1_tri is None else apply(gs1_tri, None),
    )


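# torch has no native fp4 dtype, so "float4_e2m1" maps to uint8 storage; every other
# name maps to the torch dtype of the same name.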
def dtype_str_to_torch(dtype_str: str) -> torch.dtype:
    return torch.uint8 if dtype_str == "float4_e2m1" else getattr(torch, dtype_str)


# Fixture that makes sure the opt_flags constraints are reset after each test
@pytest.fixture
def opt_flags_scope(request):
    yield
    opt_flags.reset_opt_flags_constraints()


# ---------------
# unit tests
# ---------------


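# One parametrization of test_op: problem shape, routing mode, activation/weight dtypes,
# and optional expert, split-k, swizzling, and epilogue settings.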
@dataclass
class Case:
    m: int
    n: int
    k: int
    mode: str
    act_dtype_str: str
    weight_dtype_str: str
    n_expts_tot: int = 1
    n_expts_act: int = 1
    n_expt_shards: int = 1
    split_k: int = 1
    hbm_swizzling: bool = False
    epilogue_subtile: Union[int, None] = None


@pytest.mark.parametrize(
    ", ".join(f.name for f in fields(Case)),
    [
        tuple(getattr(case, f.name) for f in fields(Case)) for case in [
            # Non-mx types:
            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4),
            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=2),
            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=4),
            Case(16, 256, 256, "ragged", "float16", "float16", 4, 1, n_expt_shards=2),
            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3),
            Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3),
            Case(300, 400, 400, "batched", "float8_e5m2", "float8_e5m2", 5, 1),
            Case(16, 256, 256, "batched", "float16", "float16", 5, 1),
            Case(16, 256, 256, "ragged", "float16", "float16", 3, 1),
            Case(256, 256, 256, "ragged", "float16", "float16", 4, 1),
            Case(256, 256, 256, "ragged", "float16", "float16", 4, 1, split_k=3),
            Case(300, 400, 400, "batched", "float16", "float16", 5, 1),
            Case(300, 400, 400, "ragged", "float16", "float16"),
            Case(300, 400, 400, "ragged", "float8_e5m2", "float8_e5m2"),
            Case(1000, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 3, 1),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=1),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=2),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=4),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, n_expt_shards=2),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 1, n_expt_shards=2),
            Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, split_k=2),
            Case(1000, 400, 400, "ragged", "float16", "float16", 3, 1),
            Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2),
            Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9),
            # mx types:
            Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1),
            Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
            Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1),
            Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
            Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
            Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
            Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9),
            Case(1000, 512, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
            Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4),
            Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
            Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4),
            Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
            Case(1, 2880, 2880, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),
            Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
            Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
            Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
            Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1),
            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9),
            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2),
            Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
            Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4),
            Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
            Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4),
            Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
            Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4),
            Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
            Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
            Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=False),
            Case(16, 256, 256, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
            Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
            Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1),
            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9),
            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True),
            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2),
            Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True),
            Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4),
            Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
            Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4),
            Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
            Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4),
            Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
            # AMD
            Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"),
            Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1),
            Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2),
            Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, n_expt_shards=2),
            Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, split_k=2),
            Case(300, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn"),
            Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1),
            Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2),
            Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2, n_expt_shards=2),
        ]
    ],
)
@pytest.mark.parametrize("block_m", [16, 128])
@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [
    (False, False, False),
    (True, False, False),
    (False, True, False),
    (True, True, False),
    (True, True, True),
])
@pytest.mark.parametrize("has_y_gammas", [False, True])
@pytest.mark.parametrize("is_persistent", [False, True])
def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot,
            n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
            device, opt_flags_scope, fresh_knobs):
    # TODO: remove when Triton FP8 supports proper RTNE
    if is_cuda():
        if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9:
            pytest.skip("Float8 not tested on A100")
        if "float16" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] >= 10:
            pytest.skip("float16 x mx not supported with cuda capability >= 10")
        if weight_dtype_str.startswith("mx"):
            if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10:
                pytest.skip("float8 x mx not supported with cuda capability < 10")
            if act_dtype_str == "mxfloat8_e4m3fn":
                if is_persistent:
                    pytest.skip("mx x mx not supported with persistent kernel")
        if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9:
            pytest.skip("Not enough memory on A100")

    elif is_hip():
        if "float8" in act_dtype_str and "mx" in weight_dtype_str and not is_hip_cdna4():
            pytest.skip("float8 x mx only supported on CDNA4")
        if "float8" in act_dtype_str and "mxfloat8" in weight_dtype_str:
            pytest.skip("NYI: float8 x mxfloat8 not tested on AMD GPU")
        if act_dtype_str.startswith("mx") and weight_dtype_str.startswith("mx"):
            pytest.skip("NYI: mx x mx not tested on AMD GPU")
        if is_persistent:
            pytest.skip("NYI: Persistent kernel not supported on AMD GPU")
        if split_k > 1:
            pytest.skip("splitK hasn't been fully tested on AMD GPU.")

    if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
        pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform")

    if fused_scatter and split_k > 1:
        pytest.skip("fused scatter scratchpad not supported with split_k")
    if hbm_swizzling:
        if is_hip():
            pytest.skip("NYI. HBM swizzling just implemented for CUDA.")
        if torch.cuda.get_device_capability()[0] < 9:
            pytest.skip("NYI. Ampere swizzling.")
        if torch.cuda.get_device_capability()[0] < 10:
            if "mxfloat4" not in weight_dtype_str:
                pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
            if k % 64 != 0 or n % 64 != 0:
                # Automatic padding not implemented for Hopper swizzle
                pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")

    # launch metadata for batched / mx types may not work yet.
    test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str)

    torch.manual_seed(0)

    block_k = None
    if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
        # Override block_k for correctness testing. The default is temporarily 128 for
        # performance reasons, which doesn't work with the persistent matmul.
        # TODO: revisit when Triton is better for H100 + MXFP4
        block_k = 256

    constraints = {
        "block_m": block_m,
        "block_k": block_k,
        "split_k": split_k,
        "fused_scatter": fused_scatter,
        "is_persistent": is_persistent,
        "epilogue_subtile": epilogue_subtile,
    }
    opt_flags.update_opt_flags_constraints(constraints)

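    # Split the dtype strings into an "is mx" flag plus the underlying element type;
    # mx scales are handled explicitly below.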
    weight_mxfp = weight_dtype_str.startswith("mx")
    if weight_mxfp:
        weight_dtype_str = weight_dtype_str[2:]
    act_mxfp8 = act_dtype_str.startswith("mx")
    act_is_float8 = act_dtype_str.startswith("float8")
    if act_mxfp8:
        act_dtype_str = act_dtype_str[2:]
        dequantize_mxfp8_spec = FnSpecs(
            FnName.DEQUANTIZE_MXFP8.name, dequantize_mxfp8_fn, (), ()
        )

    test_bwd = False
    weight_dtype = dtype_str_to_torch(weight_dtype_str)
    act_dtype = dtype_str_to_torch(act_dtype_str)
    precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, weight_mxfp, n_expts_tot // n_expt_shards, device=device)
    # precision_opt.x_pad_trans_requires_flexpoint = False
    if mode == "ragged":
        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
                                                   device=device)
    else:
        rdata = gindx = sindx = None
    x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act,
                                                                 n_expt_shards, mode, torch.bfloat16 if act_mxfp8 else act_dtype,  #
                                                                 torch.bfloat16 if weight_mxfp else weight_dtype,
                                                                 has_y_gammas, requires_grad=test_bwd, device=device)
    x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt)

    if w_tri.shape[0] == 1:
        # Test the case when weight has dim 2, i.e., shape (K, N).
        w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
        w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)

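    # Quantize the weights to the micro-scaled (mx) format along K, rebuild the bf16
    # reference from the dequantized values, and apply the requested layouts.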
    if weight_mxfp:
        mx_axis = w_tri.ndim - 2
        # compute layouts
        w_layout, w_layout_opts = layout.StridedLayout, dict()
        w_scale_layout, w_scale_layout_opts = layout.StridedLayout, dict()
        if hbm_swizzling and "float4" in weight_dtype_str:
            w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=mx_axis)
            w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
                mx_axis=mx_axis, num_warps=8)
        # downcast to mxfp
        w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
        w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
        w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
        w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
        w_scale_tri = wrap_torch_tensor(w_scale_tri)
        # convert layouts
        w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
        w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
        precision_opt.weight_scale = w_scale_tri
    epilogue = None
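    # mxfp8 activations: quantize x along K, keep its scales, and allocate a buffer for
    # the output mx scales written by the epilogue.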
    if act_mxfp8:
        x_tri, x_mx_scales_tri = downcast_to_mxfp(x_tri, act_dtype, axis=-1)
        x_ref = upcast_from_mxfp(x_tri, x_mx_scales_tri, torch.bfloat16, axis=-1)
        is_input_batched = x_tri.ndim == 3
        y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape
        n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0]
        y_shape = (y_shape[0], n_rows, w_tri.shape[-1])
        if sindx is None or mode == "batched":
            if not is_input_batched:
                y_shape = (y_shape[1], y_shape[2])
        else:
            y_shape = (n_rows // rdata.n_expts_act, y_shape[-1])
        y_scale_shape = y_shape[:-1] + (triton.cdiv(y_shape[-1], MXFP_BLOCK_SIZE),)
        y_scale = torch.empty(y_scale_shape, dtype=torch.uint8, device=x_tri.device)
        precision_opt = replace(precision_opt, act_scale=x_mx_scales_tri, out_scale=y_scale)
        epilogue = Epilogue(dequantize_mxfp8_spec, tuple(), tuple(), effective_itemsize=6.0)
    else:
        y_scale = None

    if test_launch_metadata:

        def _clobber(t, used_mask):
            # Fill the unread part of the tensor with garbage, to be sure that
            # we don't actually read from that part.
            if len(used_mask) == 1:
                return
            elif t.element_size() == 1:
                t.view(torch.int8)[~used_mask] = 127
            else:
                t[~used_mask] = torch.inf

        if rdata is not None:
            n_tokens = rdata.expt_hist.sum().item()
            used_expts = (rdata.expt_hist > 0)
            _clobber(w_tri, used_expts)
            n_w_bytes = used_expts.sum().item() * n * k * w_tri.element_size()
        else:
            n_tokens = m
            n_w_bytes = w_tri.numel() * w_tri.element_size()

        if gindx is not None:
            used_x_rows = (gindx.dst_indx.view(-1, n_expts_act) != -1).any(dim=1)
            _clobber(x_tri, used_x_rows)
            n_x_bytes = used_x_rows.sum().item() * k * x_tri.element_size()
        elif rdata is not None:
            n_x_bytes = n_tokens * k * x_tri.element_size()
        else:
            n_x_bytes = x_tri.numel() * x_tri.element_size()

        nbytes = None

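        # Launch hook that records the byte count reported by the matmul_ogs kernel's
        # launch metadata.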
        def _hook(launch_metadata):
            nonlocal nbytes
            metadata = launch_metadata.get()
            if "matmul_ogs" in metadata["name"]:
                nbytes = metadata["bytes"]

        triton.knobs.runtime.launch_enter_hook = _hook

    if mode == "batched":
        rdata, gindx, sindx = None, None, None
    flex = precision_opt.flex_ctx

    # triton
    try:
        tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref, epilogue=epilogue)
    except (opt_flags.InapplicableConstraint, NotImplementedError):
        pytest.skip("inapplicable opt_flags constraint")
    # If split_k > 1, then the intermediate tensor is fp32.
    sep_gather = mode == "ragged" and do_gather and n_expts_act > 1 and split_k == 1
    sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1
    y_scale = flex.out_data.expected_scale if act_is_float8 else 1

    if test_launch_metadata:
        if gindx is not None:
            n_y_bytes = (gindx.src_indx != -1).sum().item() * n * tri_y.element_size()
        elif rdata is not None:
            n_y_bytes = n_tokens * n * tri_y.element_size()
        else:
            n_y_bytes = tri_y.numel() * tri_y.element_size()
        assert nbytes == n_x_bytes + n_y_bytes + n_w_bytes
        triton.knobs.runtime.launch_enter_hook = None

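    # Match the kernel numerics: when gather/scatter happens in a separate pass
    # (sep_gather / sep_scatter), the reference rounds values through the activation dtype.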
    def round_x(x, idx):
        return x.to(act_dtype).to(torch.float32) if sep_gather else x

    round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y
    ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref,  #
                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref)
    scale = lambda val, scal: val if scal is None else val / scal
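    # With simulated expert parallelism (n_expt_shards > 1) only part of the output is
    # produced; restrict the comparison to the valid rows.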
    if n_expt_shards > 1:
        if do_scatter:
            indx = sindx.dst_indx[sindx.dst_indx != -1]
            ref_y = ref_y[indx // n_expts_act, :]
            if act_is_float8:
                tri_y = tri_y.view(torch.int8)
            tri_y = tri_y[indx // n_expts_act, :]
            if act_is_float8:
                tri_y = tri_y.view(act_dtype)
        else:
            n_rows = rdata.expt_hist.sum()
            assert n_rows > 0
            ref_y = ref_y[:n_rows]
            tri_y = tri_y[:n_rows]
    if act_mxfp8:
        tri_y = upcast_from_mxfp(tri_y, precision_opt.out_scale, dtype=torch.bfloat16, axis=-1).to(ref_y.dtype)
        ref_y_quant, ref_y_scale = downcast_to_mxfp_torch(ref_y, act_dtype, axis=-1)
        ref_y = upcast_from_mxfp_torch(ref_y_quant, ref_y_scale, target_dtype=ref_y.dtype, axis=-1)
        maxtol = 4e-1
        rmstol = 4e-2
    else:
        maxtol = None
        rmstol = None
    assert_close(scale(ref_y, flex.out_data.expected_scale), tri_y, maxtol=maxtol, rmstol=rmstol)

    if act_is_float8:
        tri_y_scale = flex.out_data.actual_scale.clone()
        ref_y_scale = compute_actual_scale(ref_y, tri_y.dtype)
        assert (ref_y_scale -
                tri_y_scale).abs() < 1e-10, f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}"


def test_set_idle_sms():
    if not is_cuda():
        pytest.skip("Only supported on CUDA")
    from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags
    num_idle_sms = 24
    matmul_ogs_set_idle_sms(num_idle_sms)
    flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \
                           1024, 1024, 1024, None, True, False, 1)
    assert flags.idle_sms == num_idle_sms


@pytest.mark.parametrize("m, n, k, mode", [
    (1200, 704, 608, "ragged"),
    (800, 800, 400, "batched"),
])
@pytest.mark.parametrize("split_k", [1, 2])
@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [
    (False, False, False),
    (True, False, False),
    (False, True, False),
    (True, True, False),
    (True, True, True),
])
@pytest.mark.parametrize("is_persistent, epilogue_subtile", [
    (False, None),
    (True, 1),
    (True, 4),
])
@pytest.mark.parametrize("swiglu_alpha, swiglu_limit", [
    (1.1, 1.4),
    (1.0, 1.2),
    (0.7, 1.0),
])
def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter, is_persistent, epilogue_subtile,
                   swiglu_alpha, swiglu_limit, device, opt_flags_scope):
    if fused_scatter and split_k > 1:
        pytest.skip("fused scatter scratchpad not supported with split_k")
    torch.manual_seed(0)
    constraints = {
        "is_persistent": is_persistent,
        "epilogue_subtile": epilogue_subtile,
        "fused_scatter": fused_scatter,
        "split_k": split_k,
    }
    n_expts_tot, n_expts_act, n_expt_shards = 1, 1, 1
    opt_flags.update_opt_flags_constraints(constraints)

    weight_dtype, act_dtype = torch.float16, torch.float16
    if mode == "ragged":
        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
                                                   device=device)
    else:
        rdata = gindx = sindx = None

    precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, n_expts_tot // n_expt_shards, device=device)
    x, w, bias, _, _ = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode,
                                         act_dtype, weight_dtype, False, requires_grad=False, device=device)

    if mode == "batched":
        rdata, gindx, sindx = None, None, None

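    # The fused swiglu activation inside matmul_ogs must match applying swiglu
    # separately to the matmul output.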
    try:
        a = swiglu(matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt), swiglu_alpha,
                   precision_config=SwiGLUPrecisionConfig(swiglu_limit))
        b = matmul_ogs(
            x, w, bias, rdata, gindx, sindx, precision_opt,
            fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
                                             (swiglu_alpha, swiglu_limit), 2))
    except opt_flags.InapplicableConstraint:
        pytest.skip("inapplicable constraint")
    assert_close(a, b)