Implement make_opt_flags function for XPU

#5
tests/conftest.py CHANGED
@@ -1,5 +1,5 @@
 import pytest
-
+import triton
 
 def pytest_addoption(parser):
     parser.addoption("--device", action="store", default="cuda")
@@ -12,8 +12,19 @@ def device(request):
 
 @pytest.fixture
 def fresh_knobs(monkeypatch):
+    try:
+        _ver_str = getattr(triton, "__version__", "0.0.0").split("+")[0]
+        _parts = _ver_str.split(".")
+        _ver_tuple = tuple(int(p) for p in (_parts + ["0", "0", "0"])[:3])
+    except Exception:
+        _ver_tuple = (0, 0, 0)
+
     from triton._internal_testing import _fresh_knobs_impl
-    fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
+    if _ver_tuple > (3, 4, 0):
+        fresh_function, reset_function = _fresh_knobs_impl()
+    else:
+        fresh_function, reset_function = _fresh_knobs_impl(monkeypatch)
+
     try:
         yield fresh_function()
     finally:
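
Context for the fixture change: Triton releases newer than 3.4.0 no longer expect `monkeypatch` in `_fresh_knobs_impl`, so the fixture parses `triton.__version__` (with any `+local` build suffix stripped) and picks the matching call. A standalone sketch of that parsing, using hypothetical version strings and nothing beyond the standard library:

```python
# Minimal sketch of the version-tuple parsing above (standalone, not part of the PR).
def parse_triton_version(version: str) -> tuple[int, int, int]:
    base = version.split("+")[0]          # drop "+git..." local build suffixes
    parts = base.split(".")
    try:
        return tuple(int(p) for p in (parts + ["0", "0", "0"])[:3])
    except ValueError:
        return (0, 0, 0)                  # unparsable -> treat as oldest

assert parse_triton_version("3.4.0") == (3, 4, 0)
assert parse_triton_version("3.5.0+git1234abcd") == (3, 5, 0)
# Anything newer than 3.4.0 takes the no-argument _fresh_knobs_impl() branch.
assert parse_triton_version("3.5.0+git1234abcd") > (3, 4, 0)
```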
tests/test_matmul.py CHANGED
@@ -20,7 +20,7 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
 # testing utilities
 from triton_kernels.testing import assert_close, compute_actual_scale
 # target-specific utilities
-from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
+from triton_kernels.target_info import is_hip, is_xpu, is_hip_cdna3, is_cuda, is_hip_cdna4
 
 # ---------------
 # initialize data
@@ -70,7 +70,7 @@ def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_sh
     if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
         gs0 = None
         gs1 = None
-    if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
         w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
     return x, w, bias, gs0, gs1
 
@@ -291,14 +291,15 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     if hbm_swizzling:
         if is_hip():
             pytest.skip("NYI. HBM swizzling just implemented for CUDA.")
-        if torch.cuda.get_device_capability()[0] < 9:
-            pytest.skip("NYI. Ampere swizzling.")
-        if torch.cuda.get_device_capability()[0] < 10:
-            if "mxfloat4" not in weight_dtype_str:
-                pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
-            if k % 64 != 0 or n % 64 != 0:
-                # Automatic padding not implemented for Hopper swizzle
-                pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
+        if is_cuda():
+            if torch.cuda.get_device_capability()[0] < 9:
+                pytest.skip("NYI. Ampere swizzling.")
+            if torch.cuda.get_device_capability()[0] < 10:
+                if "mxfloat4" not in weight_dtype_str:
+                    pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.")
+                if k % 64 != 0 or n % 64 != 0:
+                    # Automatic padding not implemented for Hopper swizzle
+                    pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
 
     # launch metadata for batched / mx types may not work yet.
     test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str)
@@ -306,7 +307,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     torch.manual_seed(0)
 
     block_k = None
-    if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
+    if is_cuda() and is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
         # Override block_k for testing correctness. The default is temporarily 128 for
         # performance reasons which doesn't work with persistent matmul.
         # TODO: revisit when Triton is better for H100 + MXFP4
@@ -462,7 +463,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
 
     round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y
     ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref,  #
-                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref)
+                             rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref, device=device)
     scale = lambda val, scal: val if scal is None else val / scal
     if n_expt_shards > 1:
         if do_scatter:
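
The common thread in these test edits: every `torch.cuda.get_device_capability()` check is now guarded by `is_cuda()`, so collecting and running the suite on an XPU-only machine never touches the CUDA runtime, and the torch reference path receives the `device` fixture explicitly. A minimal sketch of the guard pattern (the helper names are illustrative, not part of the PR):

```python
import triton

def active_backend() -> str | None:
    # Backend string of the active Triton target: "cuda", "hip", or "xpu".
    target = triton.runtime.driver.active.get_current_target()
    return None if target is None else target.backend

def needs_fp8_transpose_workaround(weight_dtype: str, capability_major: int) -> bool:
    # Mirrors the guarded check above: only relevant on pre-Blackwell CUDA devices.
    return active_backend() == "cuda" and "float8" in weight_dtype and capability_major < 10
```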
torch-ext/triton_kernels/matmul_ogs.py CHANGED
@@ -602,6 +602,7 @@ def matmul_ogs_torch(x, w, bias,
                      betas = None,
                      gammas = None,
                      round_x = None, round_y = None,
+                     device: str = "cuda",
                      ):
     is_input_batched = x.ndim == 3
     assert x.dtype.itemsize > 1
@@ -641,7 +642,7 @@ def matmul_ogs_torch(x, w, bias,
         else:
             idx = gather_indx.src_indx[lo:hi] // n_expts_act
         batch = i if is_input_batched else 0
-        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(),
+        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device=device)).float(),
                            w[i].float())
         if bias is not None:
             out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
torch-ext/triton_kernels/matmul_ogs_details/_common.py CHANGED
@@ -7,9 +7,21 @@ from triton.tools.tensor_descriptor import TensorDescriptor
 # -----------------------------------------------------------------------------
 # Utilities
 # -----------------------------------------------------------------------------
+try:
+    _ver_str = getattr(triton, "__version__", "0.0.0").split("+")[0]
+    _parts = _ver_str.split(".")
+    _ver_tuple = tuple(int(p) for p in (_parts + ["0", "0", "0"])[:3])
+except Exception:
+    _ver_tuple = (0, 0, 0)
 
+if _ver_tuple > (3, 4, 0) and hasattr(triton, "constexpr_function"):
+    _constexpr_function = triton.constexpr_function
+else:
+    _constexpr_function = tl.constexpr_function
 
-@tl.constexpr_function
+
+
+@_constexpr_function
 def get_scaled_dot_format_string(dtype: tl.dtype):
     mapping = {
         tl.float16: "fp16",
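
This `_constexpr_function` shim is what the other modules in the PR now import instead of decorating with `@tl.constexpr_function` directly: Triton releases newer than 3.4.0 expose the decorator as `triton.constexpr_function`, older ones only under `triton.language`. A hedged sketch of an equivalent attribute-based fallback (it skips the PR's explicit version gate):

```python
import triton
import triton.language as tl

# Prefer the decorator's new top-level home when it exists, else the legacy one.
_constexpr_function = getattr(triton, "constexpr_function", None)
if _constexpr_function is None:
    _constexpr_function = tl.constexpr_function
```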
torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py CHANGED
@@ -4,25 +4,26 @@ from ..numerics_details.flexpoint import float_to_flex, load_scale, update_scale
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ..target_info import cuda_capability_geq as _cuda_capability_geq
 from ..target_info import is_hip as _is_hip
+from ._common import _constexpr_function
 
 
 # fmt: off
-@tl.constexpr_function
+@_constexpr_function
 def is_hip():
     return _is_hip()
 
 
-@tl.constexpr_function
+@_constexpr_function
 def cuda_capability_geq(x, y):
     return _cuda_capability_geq(x, y)
 
 
-@tl.constexpr_function
+@_constexpr_function
 def log2(n):
     return len(bin(n)) - 3
 
 
-@tl.constexpr_function
+@_constexpr_function
 def _permute_to_end_order(n: int, axis: int):
     """
     Returns the order of the axes of a tensor to permute `axis` to the end.
@@ -105,7 +106,7 @@ def _finalize_matmul_launch_metadata(grid, kernel, args):
     return ret
 
 
-@tl.constexpr_function
+@_constexpr_function
 def _accumulate_f16_into_f32_and_track_absmax_ptx(n_inputs: int, src_type: str, absmax_reg_name: str | None):
     """
     Generate PTX code to take fp16 inputs and sum them into an f32 accumulator using mixed-precision
torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py CHANGED
@@ -12,14 +12,14 @@ from ..numerics_details.flexpoint import (
     compute_scale,
 )
 from ..numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
-from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
+from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string, _constexpr_function
 
 
-@tl.constexpr_function
+@_constexpr_function
 def cuda_capability_geq(major, minor):
     return target_info.cuda_capability_geq(major, minor)
 
-@tl.constexpr_function
+@_constexpr_function
 def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
     if isinstance(tensor_or_desc, tl.tensor):
         return tensor_or_desc.dtype.element_ty
torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py CHANGED
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 import triton
 from ..target_info import get_cdna_version
 import torch
-from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
+from .opt_flags_details import opt_flags_amd, opt_flags_nvidia, opt_flags_intel
 
 
 @dataclass
@@ -30,6 +30,83 @@ class OptFlags:
         raise ValueError("Not supported")
 
 
+def make_default_opt_flags_intel(
+    out_dtype,
+    lhs_dtype,
+    rhs_dtype,
+    precision_config,
+    m,
+    n,
+    k,
+    routing_data,
+    can_use_persistent_tma,
+    can_use_fused_scatter,
+    enforce_bitwise_invariance,
+    epilogue_effective_itemsize,
+    constraints,
+):
+    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages"]
+    assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
+    # tokens per expert
+    if routing_data is None:
+        tokens_per_expt = m
+    elif routing_data.expected_tokens_per_expt is None:
+        tokens_per_expt = max(1, m // routing_data.n_expts_tot)
+    else:
+        tokens_per_expt = routing_data.expected_tokens_per_expt
+    # pid swizzling
+    group_m = 8
+    xcd_swizzle = 1
+    # block_m
+    if constraints.get("block_m", None):
+        block_m = constraints["block_m"]
+    elif enforce_bitwise_invariance:
+        block_m = 128
+    else:
+        block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
+    # block n
+    block_n = opt_flags_intel.compute_block_n(n)
+    # is_persistent
+    is_persistent = constraints.get("is_persistent", False)
+    # block k
+    if constraints.get("block_k", None) is not None:
+        block_k = constraints["block_k"]
+    else:
+        block_k = opt_flags_intel.compute_block_k(k, is_persistent, precision_config)
+    # split_k
+    if constraints.get("split_k", None) is not None:
+        split_k = constraints["split_k"]
+    elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
+        split_k = 1
+    else:
+        estimated_actual_grid_size = opt_flags_intel.compute_grid_size(None, m, n, block_m, block_n)
+        split_k = opt_flags_intel.compute_split_k(block_k, k, estimated_actual_grid_size)
+
+    epilogue_subtile = constraints.get('epilogue_subtile', None)
+    if epilogue_subtile is None:
+        epilogue_subtile = 1
+
+    ret = OptFlags(
+        block_m=block_m,
+        block_n=block_n,
+        block_k=block_k,
+        num_warps=opt_flags_intel.compute_num_warps(block_m, block_n),
+        num_stages=constraints.get("num_stages", 2),
+        fused_scatter=constraints.get('fused_scatter', False),
+        group_m=group_m,
+        xcd_swizzle=xcd_swizzle,
+        w_cache_modifier=None,
+        split_k=split_k,
+        is_persistent=is_persistent,
+        epilogue_subtile=epilogue_subtile,
+        arch=None,
+        target_kernel_kwargs=dict(),
+        idle_sms=0,
+    )
+    # check constraints
+    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
+    return ret
+
 
 def make_default_opt_flags_amd(
     out_dtype,
@@ -292,6 +369,8 @@ def make_opt_flags(
             enforce_bitwise_invariance, epilogue_effective_itemsize,
             _opt_flags_constraints]
     backend = triton.runtime.driver.active.get_current_target().backend
+    if backend == "xpu":
+        return make_default_opt_flags_intel(*args)
    if backend == "hip":
        return make_default_opt_flags_amd(*args)
    if backend == "cuda":
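
In `make_default_opt_flags_intel`, `block_m` is derived from the expected tokens per expert and `block_n` from the output width, both rounded to a power of two and clamped to [16, 128]. A self-contained illustration of those two heuristics, re-implemented locally so it runs without Triton (the real code calls `triton.next_power_of_2` and `opt_flags_intel.compute_block_n`):

```python
def next_power_of_2(x: int) -> int:
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

def pick_block_m(tokens_per_expt: int) -> int:
    # block_m grows with the expected per-expert workload, clamped to [16, 128].
    return max(16, min(next_power_of_2(tokens_per_expt), 128))

def pick_block_n(n: int) -> int:
    # block_n follows the output width, clamped to the same range.
    return max(16, min(128, next_power_of_2(n)))

assert pick_block_m(5) == 16      # small expert load -> smallest tile
assert pick_block_m(100) == 128   # 100 rounds up to 128
assert pick_block_n(4096) == 128  # large N is capped at 128
```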
torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_intel.py ADDED
@@ -0,0 +1,41 @@
+import torch
+import triton
+
+
+def compute_grid_size(routing_data, m, n, block_m, block_n):
+    if routing_data is not None:
+        grid_m = routing_data.n_blocks(m, block_m)
+    else:
+        grid_m = triton.cdiv(m, block_m)
+    grid_n = (n + block_n - 1) // block_n
+    return grid_m * grid_n
+
+
+def compute_block_n(n: int):
+    # block_n:
+    return max(16, min(128, triton.next_power_of_2(n)))
+
+
+def compute_block_k(k: int | None, is_persistent: bool, precision_config):
+    if k is not None:
+        block_k = max(32, min(128, triton.next_power_of_2(k)))
+    has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
+    if is_persistent and has_mx_weight_scale:
+        block_k = min(block_k, 128)
+    return block_k
+
+
+def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
+    device_props = torch.xpu.get_device_properties(0)
+    n_sms = device_props.gpu_subslice_count
+    split_k = n_sms // grid_size
+    if k is not None:
+        # avoid split_k for small k
+        num_block_k = triton.cdiv(k, block_k)
+        split_k = min(split_k, num_block_k // 4)
+    split_k = max(split_k, 1)
+    return split_k
+
+
+def compute_num_warps(block_m, block_n):
+    return max(block_m * block_n // 4096, 4)
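
`compute_split_k` is the only heuristic here that queries the device: it divides the XPU's subslice count (`gpu_subslice_count` from `torch.xpu.get_device_properties`) by the launch grid size, then caps the result so a small K is never split. A worked example with the device query replaced by a fixed subslice count:

```python
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

def split_k_heuristic(block_k: int, k: int, grid_size: int, n_sms: int) -> int:
    split_k = n_sms // grid_size              # how many copies of the grid fit on the device
    num_block_k = cdiv(k, block_k)
    split_k = min(split_k, num_block_k // 4)  # avoid split_k for small k
    return max(split_k, 1)

# 64 subslices, an 8-tile output grid, k=4096 in 128-wide chunks -> split 8 ways:
assert split_k_heuristic(block_k=128, k=4096, grid_size=8, n_sms=64) == 8
# a small k disables splitting entirely:
assert split_k_heuristic(block_k=128, k=256, grid_size=8, n_sms=64) == 1
```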
torch-ext/triton_kernels/numerics_details/flexpoint.py CHANGED
@@ -1,5 +1,6 @@
 from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
 from .. import target_info
+from ..matmul_ogs_details._common import _constexpr_function
 import triton
 import triton.language as tl
 
@@ -52,7 +53,7 @@ def rcp_max_finite(dtype):
     tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")
 
 
-@tl.constexpr_function
+@_constexpr_function
 def cuda_capability_geq(major, minor):
     return target_info.cuda_capability_geq(major, minor)
 
torch-ext/triton_kernels/swiglu.py CHANGED
@@ -35,7 +35,7 @@ class SwiGLU(torch.autograd.Function):
         # optimization hyperparameters
         BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
         num_warps = 4
-        kwargs = {'maxnreg': 64} if not target_info.is_hip() else {}
+        kwargs = {'maxnreg': 64} if not target_info.is_hip() and not target_info.is_xpu() else {}
         # launch semi-persistent kernel
        N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
        num_sms = target_info.num_sms()
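
`maxnreg` is a CUDA-specific launch option (it caps registers per thread), so it was already omitted on HIP and is now omitted on XPU as well. The selection in sketch form:

```python
def swiglu_launch_kwargs(backend: str) -> dict:
    # maxnreg only exists as a launch option on the CUDA backend.
    return {} if backend in ("hip", "xpu") else {"maxnreg": 64}

assert swiglu_launch_kwargs("cuda") == {"maxnreg": 64}
assert swiglu_launch_kwargs("xpu") == {}
```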
torch-ext/triton_kernels/target_info.py CHANGED
@@ -1,54 +1,70 @@
 import torch
 import triton
 
-cached_capabilities = {}
+from .matmul_ogs_details._common import _constexpr_function
+from triton.runtime import driver
 
+def current_target():
+    try:
+        active_driver = driver.active
+    except RuntimeError:
+        # If there is no active driver, return None
+        return None
+    return active_driver.get_current_target()
 
+current_target.__triton_builtin__ = True
+
+
+@_constexpr_function
 def is_cuda():
-    if "is_cuda" not in cached_capabilities:
-        target = triton.runtime.driver.active.get_current_target()
-        cached_capabilities["is_cuda"] = False if target is None else target.backend == "cuda"
-    return cached_capabilities["is_cuda"]
+    target = current_target()
+    return target is not None and target.backend == "cuda"
 
 
+@_constexpr_function
 def is_hip():
-    if "is_hip" not in cached_capabilities:
-        cached_capabilities["is_hip"] = torch.cuda.is_available() and bool(torch.version.hip)
-    return cached_capabilities["is_hip"]
+    target = current_target()
+    return target is not None and target.backend == "hip"
+
 
+@_constexpr_function
+def is_xpu():
+    target = current_target()
+    return target is not None and target.backend == "xpu"
 
+
+@_constexpr_function
 def is_hip_cdna3():
-    if "is_hip_cdna3" not in cached_capabilities:
-        target = triton.runtime.driver.active.get_current_target()
-        cached_capabilities["is_hip_cdna3"] = (target is not None and target.backend == 'hip'
-                                               and target.arch == 'gfx942')
-    return cached_capabilities["is_hip_cdna3"]
+    target = current_target()
+    return target is not None and target.arch == "gfx942"
 
 
+@_constexpr_function
 def is_hip_cdna4():
-    if "is_hip_cdna4" not in cached_capabilities:
-        target = triton.runtime.driver.active.get_current_target()
-        cached_capabilities["is_hip_cdna4"] = (target is not None and target.backend == 'hip'
-                                               and target.arch == 'gfx950')
-    return cached_capabilities["is_hip_cdna4"]
+    target = current_target()
+    return target is not None and target.arch == "gfx950"
 
 
+@_constexpr_function
 def cuda_capability_geq(major, minor=0):
     """
     Determines whether we have compute capability >= (major, minor) and
     returns this as a constexpr boolean. This can be used for guarding
     inline asm implementations that require a certain compute capability.
     """
-    if is_hip():
+    """
+    Determines whether we have compute capability >= (major, minor) and
+    returns this as a constexpr boolean. This can be used for guarding
+    inline asm implementations that require a certain compute capability.
+    """
+    target = current_target()
+    if target is None or target.backend != "cuda":
         return False
-    if "cuda" not in cached_capabilities:
-        if torch.cuda.is_available():
-            cached_capabilities["cuda"] = torch.cuda.get_device_capability()
-        else:
-            cached_capabilities["cuda"] = (0, 0)
-    return cached_capabilities["cuda"] >= (major, minor)
+    assert isinstance(target.arch, int)
+    return target.arch >= major * 10 + minor
 
 
+@_constexpr_function
 def get_cdna_version():
     """
     Gets the AMD architecture version, i.e. CDNA3 or CDNA4, currently
@@ -65,13 +81,18 @@ def get_cdna_version():
     return -1
 
 
+@_constexpr_function
 def has_tma_gather():
     return cuda_capability_geq(10, 0)
 
 
+@_constexpr_function
 def has_native_mxfp():
     return cuda_capability_geq(10, 0)
 
 
 def num_sms():
-    return torch.cuda.get_device_properties(0).multi_processor_count
+    if is_cuda():
+        return torch.cuda.get_device_properties(0).multi_processor_count
+    if is_xpu():
+        return torch.xpu.get_device_properties(0).max_compute_units
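
The rewritten helpers drop the `cached_capabilities` dict and the `torch.cuda` queries in favour of the active Triton target reported by the driver, and every predicate is now wrapped in `_constexpr_function`. On CUDA targets Triton reports `target.arch` as a plain integer (e.g. 80 for Ampere, 90 for Hopper), which is why the new `cuda_capability_geq` compares it against `major * 10 + minor`. A small standalone illustration of that comparison:

```python
def capability_geq(arch: int | None, major: int, minor: int = 0) -> bool:
    # arch is the integer compute capability from the Triton CUDA target,
    # or None when no CUDA target is active.
    if arch is None:
        return False
    return arch >= major * 10 + minor

assert capability_geq(90, 9)        # Hopper satisfies >= (9, 0)
assert not capability_geq(80, 9)    # Ampere does not
assert not capability_geq(None, 9)  # non-CUDA backend / no target
```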