Implement make_opt_flags function for XPU #5
by YangKai0616 - opened

Files changed:
- tests/conftest.py +13 -2
- tests/test_matmul.py +13 -12
- torch-ext/triton_kernels/matmul_ogs.py +2 -1
- torch-ext/triton_kernels/matmul_ogs_details/_common.py +13 -1
- torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py +6 -5
- torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +3 -3
- torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py +80 -1
- torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py +41 -0
- torch-ext/triton_kernels/numerics_details/flexpoint.py +2 -1
- torch-ext/triton_kernels/swiglu.py +1 -1
- torch-ext/triton_kernels/target_info.py +47 -26
tests/conftest.py
CHANGED

@@ -1,5 +1,5 @@
 import pytest
-
+import triton

 def pytest_addoption(parser):
     parser.addoption("--device", action="store", default="cuda")
@@ -12,8 +12,19 @@ def device(request):

 @pytest.fixture
 def fresh_knobs(monkeypatch):
+    try:
+        _ver_str = getattr(triton, "__version__", "0.0.0").split("+")[0]
+        _parts = _ver_str.split(".")
+        _ver_tuple = tuple(int(p) for p in (_parts + ["0", "0", "0"])[:3])
+    except Exception:
+        _ver_tuple = (0, 0, 0)
+
     from triton._internal_testing import _fresh_knobs_impl
-    fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
+    if _ver_tuple > (3, 4, 0):
+        fresh_function, reset_function = _fresh_knobs_impl()
+    else:
+        fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
+
     try:
         yield fresh_function()
     finally:
tests/test_matmul.py
CHANGED

@@ -20,7 +20,7 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
 # testing utilities
 from triton_kernels.testing import assert_close, compute_actual_scale
 # target-specific utilities
-from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
+from triton_kernels.target_info import is_hip, is_xpu, is_hip_cdna3, is_cuda, is_hip_cdna4

 # ---------------
 # initialize data
@@ -70,7 +70,7 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
     if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
         gs0 = None
         gs1 = None
-    if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
         w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
     return x, w, bias, gs0, gs1

@@ -291,14 +291,15 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     if hbm_swizzling:
         if is_hip():
             pytest.skip("NYI. HBM swizzling just implemented for CUDA.")
-        if torch.cuda.get_device_capability()[0] < 9:
-            pytest.skip("NYI. Ampere swizzling.")
-        if torch.cuda.get_device_capability()[0] < 10:
-            if "mxfloat4" not in weight_dtype_str:
-                pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
-            if k % 64 != 0 or n % 64 != 0:
-                # Automatic padding not implemented for Hopper swizzle
-                pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
+        if is_cuda():
+            if torch.cuda.get_device_capability()[0] < 9:
+                pytest.skip("NYI. Ampere swizzling.")
+            if torch.cuda.get_device_capability()[0] < 10:
+                if "mxfloat4" not in weight_dtype_str:
+                    pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
+                if k % 64 != 0 or n % 64 != 0:
+                    # Automatic padding not implemented for Hopper swizzle
+                    pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")

     # launch metadata for batched / mx types may not work yet.
     test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str)
@@ -306,7 +307,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     torch.manual_seed(0)

     block_k = None
-    if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
         # Override block_k for testing correctness. The default is temporarily 128 for
         # performance reasons which doesn't work with persistent matmul.
         # TODO: revisit when Triton is better for H100 + MXFP4
@@ -462,7 +463,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas

     round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y
     ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref, #
-                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref)
+                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref, device=device)
     scale = lambda val, scal: val if scal is None else val / scal
     if n_expt_shards > 1:
         if do_scatter:
torch-ext/triton_kernels/matmul_ogs.py
CHANGED

@@ -602,6 +602,7 @@ def matmul_ogs_torch(x, w, bias,
                      betas = None,
                      gammas = None,
                      round_x = None, round_y = None,
+                     device: str = "cuda",
                      ):
     is_input_batched = x.ndim == 3
     assert x.dtype.itemsize > 1
@@ -641,7 +642,7 @@ def matmul_ogs_torch(x, w, bias,
         else:
             idx = gather_indx.src_indx[lo:hi] // n_expts_act
         batch = i if is_input_batched else 0
-        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(),
+        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device=device)).float(),
                            w[i].float())
         if bias is not None:
             out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
torch-ext/triton_kernels/matmul_ogs_details/_common.py
CHANGED

@@ -7,9 +7,21 @@ from triton.tools.tensor_descriptor import TensorDescriptor
 # -----------------------------------------------------------------------------
 # Utilities
 # -----------------------------------------------------------------------------
+try:
+    _ver_str = getattr(triton, "__version__", "0.0.0").split("+")[0]
+    _parts = _ver_str.split(".")
+    _ver_tuple = tuple(int(p) for p in (_parts + ["0", "0", "0"])[:3])
+except Exception:
+    _ver_tuple = (0, 0, 0)

+if _ver_tuple > (3, 4, 0) and hasattr(triton, "constexpr_function"):
+    _constexpr_function = triton.constexpr_function
+else:
+    _constexpr_function = tl.constexpr_function

-@tl.constexpr_function
+
+
+@_constexpr_function
 def get_scaled_dot_format_string(dtype: tl.dtype):
     mapping = {
         tl.float16: "fp16",
torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py
CHANGED

@@ -4,25 +4,26 @@ from ..numerics_details.flexpoint import float_to_flex, load_scale, update_scale
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ..target_info import cuda_capability_geq as _cuda_capability_geq
 from ..target_info import is_hip as _is_hip
+from ._common import _constexpr_function


 # fmt: off
-@tl.constexpr_function
+@_constexpr_function
 def is_hip():
     return _is_hip()


-@tl.constexpr_function
+@_constexpr_function
 def cuda_capability_geq(x, y):
     return _cuda_capability_geq(x, y)


-@tl.constexpr_function
+@_constexpr_function
 def log2(n):
     return len(bin(n)) - 3


-@tl.constexpr_function
+@_constexpr_function
 def _permute_to_end_order(n: int, axis: int):
     """
     Returns the order of the axes of a tensor to permute `axis` to the end.
@@ -105,7 +106,7 @@ def _finalize_matmul_launch_metadata(grid, kernel, args):
     return ret


-@tl.constexpr_function
+@_constexpr_function
 def _accumulate_f16_into_f32_and_track_absmax_ptx(n_inputs: int, src_type: str, absmax_reg_name: str | None):
     """
     Generate PTX code to take fp16 inputs and sum them into an f32 accumulator using mixed-precision
torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
CHANGED

@@ -12,14 +12,14 @@ from ..numerics_details.flexpoint import (
     compute_scale,
 )
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
-from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
+from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string, _constexpr_function


-@tl.constexpr_function
+@_constexpr_function
 def cuda_capability_geq(major, minor):
     return target_info.cuda_capability_geq(major, minor)


-@tl.constexpr_function
+@_constexpr_function
 def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
     if isinstance(tensor_or_desc, tl.tensor):
         return tensor_or_desc.dtype.element_ty
torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py
CHANGED

@@ -4,7 +4,7 @@ from dataclasses import dataclass
 import triton
 from ..target_info import get_cdna_version
 import torch
-from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
+from .opt_flags_details import opt_flags_amd, opt_flags_nvidia, opt_flags_intel


 @dataclass
@@ -30,6 +30,83 @@ class OptFlags:
             raise ValueError("Not supported")


+def make_default_opt_flags_intel(
+    out_dtype,
+    lhs_dtype,
+    rhs_dtype,
+    precision_config,
+    m,
+    n,
+    k,
+    routing_data,
+    can_use_persistent_tma,
+    can_use_fused_scatter,
+    enforce_bitwise_invariance,
+    epilogue_effective_itemsize,
+    constraints,
+):
+    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages"]
+    assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
+    # tokens per expert
+    if routing_data is None:
+        tokens_per_expt = m
+    elif routing_data.expected_tokens_per_expt is None:
+        tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+    else:
+        tokens_per_expt = routing_data.expected_tokens_per_expt
+    # pid swizzling
+    group_m = 8
+    xcd_swizzle = 1
+    # block_m
+    if constraints.get("block_m", None):
+        block_m = constraints["block_m"]
+    elif enforce_bitwise_invariance:
+        block_m = 128
+    else:
+        block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+    # block n
+    block_n = opt_flags_intel.compute_block_n(n)
+    # is_persistent
+    is_persistent = constraints.get("is_persistent", False)
+    # block k
+    if constraints.get("block_k", None) is not None:
+        block_k = constraints["block_k"]
+    else:
+        block_k = opt_flags_intel.compute_block_k(k, is_persistent, precision_config)
+    # split_k
+    if constraints.get("split_k", None) is not None:
+        split_k = constraints["split_k"]
+    elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
+        split_k = 1
+    else:
+        estimated_actual_grid_size = opt_flags_intel.compute_grid_size(None, m, n, block_m, block_n)
+        split_k = opt_flags_intel.compute_split_k(block_k, k, estimated_actual_grid_size)
+
+    epilogue_subtile = constraints.get('epilogue_subtile', None)
+    if epilogue_subtile is None:
+        epilogue_subtile = 1
+
+    ret = OptFlags(
+        block_m=block_m,
+        block_n=block_n,
+        block_k=block_k,
+        num_warps=opt_flags_intel.compute_num_warps(block_m, block_n),
+        num_stages=constraints.get("num_stages", 2),
+        fused_scatter=constraints.get('fused_scatter', False),
+        group_m=group_m,
+        xcd_swizzle=xcd_swizzle,
+        w_cache_modifier=None,
+        split_k=split_k,
+        is_persistent=is_persistent,
+        epilogue_subtile=epilogue_subtile,
+        arch=None,
+        target_kernel_kwargs=dict(),
+        idle_sms=0,
+    )
+    # check constraints
+    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    return ret
+

 def make_default_opt_flags_amd(
     out_dtype,
@@ -292,6 +369,8 @@ def make_opt_flags(
             enforce_bitwise_invariance, epilogue_effective_itemsize,
             _opt_flags_constraints]
     backend = triton.runtime.driver.active.get_current_target().backend
+    if backend == "xpu":
+        return make_default_opt_flags_intel(*args)
     if backend == "hip":
         return make_default_opt_flags_amd(*args)
     if backend == "cuda":
torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py
ADDED

@@ -0,0 +1,41 @@
+import torch
+import triton
+
+
+def compute_grid_size(routing_data, m, n, block_m, block_n):
+    if routing_data is not None:
+        grid_m = routing_data.n_blocks(m, block_m)
+    else:
+        grid_m = triton.cdiv(m, block_m)
+    grid_n = (n + block_n - 1) // block_n
+    return grid_m * grid_n
+
+
+def compute_block_n(n: int):
+    # block_n:
+    return max(16, min(128, triton.next_power_of_2(n)))
+
+
+def compute_block_k(k: int | None, is_persistent: bool, precision_config):
+    if k is not None:
+        block_k = max(32, min(128, triton.next_power_of_2(k)))
+    has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
+    if is_persistent and has_mx_weight_scale:
+        block_k = min(block_k, 128)
+    return block_k
+
+
+def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
+    device_props = torch.xpu.get_device_properties(0)
+    n_sms = device_props.gpu_subslice_count
+    split_k = n_sms // grid_size
+    if k is not None:
+        # avoid split_k for small k
+        num_block_k = triton.cdiv(k, block_k)
+        split_k = min(split_k, num_block_k // 4)
+    split_k = max(split_k, 1)
+    return split_k
+
+
+def compute_num_warps(block_m, block_n):
+    return max(block_m * block_n // 4096, 4)
torch-ext/triton_kernels/numerics_details/flexpoint.py
CHANGED

@@ -1,5 +1,6 @@
 from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
 from .. import target_info
+from ..matmul_ogs_details._common import _constexpr_function
 import triton
 import triton.language as tl

@@ -52,7 +53,7 @@ def rcp_max_finite(dtype):
     tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")


-@tl.constexpr_function
+@_constexpr_function
 def cuda_capability_geq(major, minor):
     return target_info.cuda_capability_geq(major, minor)
torch-ext/triton_kernels/swiglu.py
CHANGED

@@ -35,7 +35,7 @@ class SwiGLU(torch.autograd.Function):
         # optimization hyperparameters
         BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
         num_warps = 4
-        kwargs = {'maxnreg': 64} if not target_info.is_hip() else {}
+        kwargs = {'maxnreg': 64} if not target_info.is_hip() and not target_info.is_xpu() else {}
         # launch semi-persistent kernel
         N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
         num_sms = target_info.num_sms()
torch-ext/triton_kernels/target_info.py
CHANGED

@@ -1,54 +1,70 @@
 import torch
 import triton

-cached_capabilities = {}
+from .matmul_ogs_details._common import _constexpr_function
+from triton.runtime import driver

+def current_target():
+    try:
+        active_driver = driver.active
+    except RuntimeError:
+        # If there is no active driver, return None
+        return None
+    return active_driver.get_current_target()

+current_target.__triton_builtin__ = True
+
+
+@_constexpr_function
 def is_cuda():
-    if "is_cuda" not in cached_capabilities:
-        target = triton.runtime.driver.active.get_current_target()
-        cached_capabilities["is_cuda"] = False if target is None else target.backend == "cuda"
-    return cached_capabilities["is_cuda"]
+    target = current_target()
+    return target is not None and target.backend == "cuda"


+@_constexpr_function
 def is_hip():
-    if "is_hip" not in cached_capabilities:
-        cached_capabilities["is_hip"] = torch.cuda.is_available() and bool(torch.version.hip)
-    return cached_capabilities["is_hip"]
+    target = current_target()
+    return target is not None and target.backend == "hip"
+

+@_constexpr_function
+def is_xpu():
+    target = current_target()
+    return target is not None and target.backend == "xpu"

+
+@_constexpr_function
 def is_hip_cdna3():
-    if "is_hip_cdna3" not in cached_capabilities:
-        target = triton.runtime.driver.active.get_current_target()
-        cached_capabilities["is_hip_cdna3"] = (target is not None and target.backend == 'hip'
-                                               and target.arch == 'gfx942')
-    return cached_capabilities["is_hip_cdna3"]
+    target = current_target()
+    return target is not None and target.arch == "gfx942"


+@_constexpr_function
 def is_hip_cdna4():
-    if "is_hip_cdna4" not in cached_capabilities:
-        target = triton.runtime.driver.active.get_current_target()
-        cached_capabilities["is_hip_cdna4"] = (target is not None and target.backend == 'hip'
-                                               and target.arch == 'gfx950')
-    return cached_capabilities["is_hip_cdna4"]
+    target = current_target()
+    return target is not None and target.arch == "gfx950"


+@_constexpr_function
 def cuda_capability_geq(major, minor=0):
     """
     Determines whether we have compute capability >= (major, minor) and
     returns this as a constexpr boolean. This can be used for guarding
     inline asm implementations that require a certain compute capability.
     """
-    if is_hip():
+    """
+    Determines whether we have compute capability >= (major, minor) and
+    returns this as a constexpr boolean. This can be used for guarding
+    inline asm implementations that require a certain compute capability.
+    """
+    target = current_target()
+    if target is None or target.backend != "cuda":
         return False
-    if "cuda" not in cached_capabilities:
-        if torch.cuda.is_available():
-            cached_capabilities["cuda"] = torch.cuda.get_device_capability()
-        else:
-            cached_capabilities["cuda"] = (0, 0)
-    return cached_capabilities["cuda"] >= (major, minor)
+    assert isinstance(target.arch, int)
+    return target.arch >= major * 10 + minor


+@_constexpr_function
 def get_cdna_version():
     """
     Gets the AMD architecture version, i.e. CDNA3 or CDNA4, currently
@@ -65,13 +81,18 @@ def get_cdna_version():
     return -1


+@_constexpr_function
 def has_tma_gather():
     return cuda_capability_geq(10, 0)


+@_constexpr_function
 def has_native_mxfp():
     return cuda_capability_geq(10, 0)


 def num_sms():
-    return torch.cuda.get_device_properties(0).multi_processor_count
+    if is_cuda():
+        return torch.cuda.get_device_properties(0).multi_processor_count
+    if is_xpu():
+        return torch.xpu.get_device_properties(0).max_compute_units