kernel
build/torch-universal/rotary/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .triton_rotary import apply_rotary
+
+ __all__ = ["apply_rotary"]
build/torch-universal/rotary/_ops.py ADDED
@@ -0,0 +1,8 @@
+ import torch
+ ops = torch.ops._rotary_202507301320
+
+ def add_op_namespace_prefix(op_name: str) -> str:
+     """
+     Prefix the op name with the kernel's private namespace.
+     """
+     return f"_rotary_202507301320::{op_name}"
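A brief, hypothetical sketch (not part of the diff) of how `add_op_namespace_prefix` could be used when registering an op under the same private namespace. It assumes `build/torch-universal` is on `sys.path` and PyTorch >= 2.4 for `torch.library.custom_op`; the `noop` op below is purely illustrative and not defined anywhere in this repo.

    import torch
    from rotary._ops import add_op_namespace_prefix

    # add_op_namespace_prefix("noop") -> "_rotary_202507301320::noop"
    @torch.library.custom_op(add_op_namespace_prefix("noop"), mutates_args=())
    def noop(x: torch.Tensor) -> torch.Tensor:
        # Illustrative op body; a real kernel would do actual work here.
        return x.clone()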
build/torch-universal/rotary/triton_rotary.py ADDED
@@ -0,0 +1,144 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ @triton.jit
+ def _rotary_kernel(
+     X1_ptr, X2_ptr, COS_ptr, SIN_ptr, OUT1_ptr, OUT2_ptr,
+     stride_x1_b, stride_x1_s, stride_x1_h, stride_x1_d,
+     stride_x2_b, stride_x2_s, stride_x2_h, stride_x2_d,
+     stride_cos_s, stride_cos_d,
+     stride_sin_s, stride_sin_d,
+     stride_o1_b, stride_o1_s, stride_o1_h, stride_o1_d,
+     stride_o2_b, stride_o2_s, stride_o2_h, stride_o2_d,
+     seq_len, num_heads, headdim,
+     IS_CONJ: tl.constexpr,
+     BLOCK_SIZE_D: tl.constexpr,
+     BLOCK_M: tl.constexpr,
+     BLOCK_H: tl.constexpr,
+ ):
+     """
+     Triton kernel for applying rotary position embedding.
+     """
+     # Get program IDs
+     pid_b = tl.program_id(0)
+     pid_s_block = tl.program_id(1)
+     pid_h_block = tl.program_id(2)
+
+     # Create block offsets
+     offs_d = tl.arange(0, BLOCK_SIZE_D)
+     offs_s = pid_s_block * BLOCK_M + tl.arange(0, BLOCK_M)
+     offs_h = pid_h_block * BLOCK_H + tl.arange(0, BLOCK_H)
+
+     # Pointers for x1, x2, out1, out2
+     x1_ptrs = X1_ptr + pid_b * stride_x1_b + \
+         (offs_s[:, None, None] * stride_x1_s + \
+          offs_h[None, :, None] * stride_x1_h + \
+          offs_d[None, None, :] * stride_x1_d)
+
+     x2_ptrs = X2_ptr + pid_b * stride_x2_b + \
+         (offs_s[:, None, None] * stride_x2_s + \
+          offs_h[None, :, None] * stride_x2_h + \
+          offs_d[None, None, :] * stride_x2_d)
+
+     o1_ptrs = OUT1_ptr + pid_b * stride_o1_b + \
+         (offs_s[:, None, None] * stride_o1_s + \
+          offs_h[None, :, None] * stride_o1_h + \
+          offs_d[None, None, :] * stride_o1_d)
+
+     o2_ptrs = OUT2_ptr + pid_b * stride_o2_b + \
+         (offs_s[:, None, None] * stride_o2_s + \
+          offs_h[None, :, None] * stride_o2_h + \
+          offs_d[None, None, :] * stride_o2_d)
+
+     # Pointers for cos, sin
+     cos_ptrs = COS_ptr + \
+         (offs_s[:, None, None] * stride_cos_s + \
+          offs_d[None, None, :] * stride_cos_d)
+     sin_ptrs = SIN_ptr + \
+         (offs_s[:, None, None] * stride_sin_s + \
+          offs_d[None, None, :] * stride_sin_d)
+
+     # Masks for the last blocks when dimensions are not multiples of the block sizes
+     mask_s = offs_s < seq_len
+     mask_h = offs_h < num_heads
+     mask_d = offs_d < headdim
+
+     # Combined mask for all tensors: [BLOCK_M, BLOCK_H, BLOCK_SIZE_D]
+     mask = mask_s[:, None, None] & mask_h[None, :, None] & mask_d[None, None, :]
+     mask_cs = mask_s[:, None, None] & mask_d[None, None, :]
+
+     # Load data (compute in float32 for accuracy)
+     x1 = tl.load(x1_ptrs, mask=mask, other=0.0).to(tl.float32)
+     x2 = tl.load(x2_ptrs, mask=mask, other=0.0).to(tl.float32)
+     cos = tl.load(cos_ptrs, mask=mask_cs, other=0.0).to(tl.float32)
+     sin = tl.load(sin_ptrs, mask=mask_cs, other=0.0).to(tl.float32)
+
+     # Perform the rotary transformation
+     if IS_CONJ:
+         out1 = x1 * cos + x2 * sin
+         out2 = -x1 * sin + x2 * cos
+     else:
+         out1 = x1 * cos - x2 * sin
+         out2 = x1 * sin + x2 * cos
+
+     # Store results (values are cast to the output dtype on store)
+     tl.store(o1_ptrs, out1, mask=mask)
+     tl.store(o2_ptrs, out2, mask=mask)
+
+
+ def apply_rotary(
+     x1: torch.Tensor,
+     x2: torch.Tensor,
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+     out1: torch.Tensor,
+     out2: torch.Tensor,
+     conj: bool = False,
+ ):
+     """
+     Applies rotary position embedding to the input tensors.
+
+     Args:
+         x1, x2: Input tensors of shape [batch_size, seq_len, num_heads, headdim] or [num_tokens, num_heads, headdim].
+         cos, sin: Cosine and sine tensors of shape [seq_len, 1, rotary_dim] (or [num_tokens, 1, rotary_dim] for 3D inputs).
+         out1, out2: Output tensors. May alias x1, x2 for an in-place update.
+         conj: If True, applies the conjugate (inverse) rotation.
+     """
+     # Shape checks
+     assert x1.shape == x2.shape and out1.shape == out2.shape and x1.shape == out1.shape
+     assert cos.shape == sin.shape
+     assert cos.dim() == 3
+     assert x1.device == x2.device == cos.device == sin.device == out1.device == out2.device
+
+     # Reshape to 4D if necessary
+     if x1.dim() == 3:  # (num_tokens, num_heads, headdim)
+         x1, x2 = x1.unsqueeze(0), x2.unsqueeze(0)
+         out1, out2 = out1.unsqueeze(0), out2.unsqueeze(0)
+     elif x1.dim() != 4:
+         raise ValueError("Input tensors must be 3D or 4D")
+
+     batch_size, seq_len, num_heads, headdim = x1.shape
+
+     # Triton grid
+     BLOCK_M = 8 if headdim <= 128 else 4
+     BLOCK_H = 2
+     grid = (batch_size, triton.cdiv(seq_len, BLOCK_M), triton.cdiv(num_heads, BLOCK_H))
+
+     # Use the smallest power of 2 that is >= headdim as BLOCK_SIZE_D
+     BLOCK_SIZE_D = triton.next_power_of_2(headdim)
+
+     _rotary_kernel[grid](
+         x1, x2, cos, sin, out1, out2,
+         x1.stride(0), x1.stride(1), x1.stride(2), x1.stride(3),
+         x2.stride(0), x2.stride(1), x2.stride(2), x2.stride(3),
+         cos.stride(0), cos.stride(2),
+         sin.stride(0), sin.stride(2),
+         out1.stride(0), out1.stride(1), out1.stride(2), out1.stride(3),
+         out2.stride(0), out2.stride(1), out2.stride(2), out2.stride(3),
+         seq_len, num_heads, headdim,
+         IS_CONJ=conj,
+         BLOCK_SIZE_D=BLOCK_SIZE_D,
+         BLOCK_M=BLOCK_M,
+         BLOCK_H=BLOCK_H,
+     )
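For reference, a minimal usage sketch (not part of the diff) of calling `apply_rotary` directly. It assumes the built package is importable as `rotary` and that a CUDA device is available; shapes follow the docstring above, and passing `x1`/`x2` as the outputs performs the rotation in place, mirroring how the tests invoke the kernel.

    import torch
    from rotary import apply_rotary

    batch, seqlen, nheads, rotary_dim = 2, 128, 8, 32
    x1 = torch.randn(batch, seqlen, nheads, rotary_dim, device="cuda")
    x2 = torch.randn(batch, seqlen, nheads, rotary_dim, device="cuda")
    theta = torch.randn(seqlen, 1, rotary_dim, device="cuda")
    cos, sin = torch.cos(theta), torch.sin(theta)

    # In-place rotation: outputs alias the inputs.
    apply_rotary(x1, x2, cos, sin, x1, x2, conj=False)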
tests/__init__.py ADDED
(empty file)
tests/test_rotary.py ADDED
@@ -0,0 +1,126 @@
+ import pytest
+ import torch
+
+ from tests.utils import infer_device, supports_bfloat16
+ from kernels import get_local_kernel
+ from pathlib import Path
+
+ # from transformers.trainer_utils import set_seed
+ # set_seed(42)
+
+ # Path to the local kernel repo (relative to this file)
+ repo_path = Path(__file__).parent.parent
+
+ def apply_rotary_torch(x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool = False):
+     assert x1.shape == x2.shape, "x1 and x2 must have the same shape"
+
+     if not conj:
+         out1 = x1 * cos - x2 * sin
+         out2 = x1 * sin + x2 * cos
+     else:
+         out1 = x1 * cos + x2 * sin
+         out2 = -x1 * sin + x2 * cos
+     return out1, out2
+
+
+ def apply_rotary_torch_wrapper(q, k, cos, sin, conj: bool = False):
+     """Reference wrapper around apply_rotary_torch."""
+     rotary_dim = cos.shape[-1]
+
+     # Apply the rotary encoding to Q
+     q1 = q[..., :rotary_dim]
+     q2 = q[..., rotary_dim : 2 * rotary_dim]
+     q_out_1, q_out_2 = apply_rotary_torch(q1, q2, cos, sin, conj)
+     q_out = torch.cat([q_out_1, q_out_2, q[..., 2 * rotary_dim:]], dim=-1)
+
+     # Apply the rotary encoding to K
+     k1 = k[..., :rotary_dim]
+     k2 = k[..., rotary_dim : 2 * rotary_dim]
+     k_out_1, k_out_2 = apply_rotary_torch(k1, k2, cos, sin, conj)
+     k_out = torch.cat([k_out_1, k_out_2, k[..., 2 * rotary_dim:]], dim=-1)
+
+     return q_out, k_out
+
+
+ def apply_rotary_kernel_wrapper(q, k, cos, sin, conj: bool = False):
+     """Wrapper around the Triton kernel; rotates q and k in place."""
+     rotary = get_local_kernel(repo_path=repo_path, package_name="rotary")
+     rotary_dim = cos.shape[-1]
+
+     # Apply the rotary encoding to Q
+     q1 = q[..., :rotary_dim]
+     q2 = q[..., rotary_dim : 2 * rotary_dim]
+     rotary.apply_rotary(q1, q2, cos, sin, q1, q2, conj)
+
+     # Apply the rotary encoding to K
+     k1 = k[..., :rotary_dim]
+     k2 = k[..., rotary_dim : 2 * rotary_dim]
+     rotary.apply_rotary(k1, k2, cos, sin, k1, k2, conj)
+
+
+ @pytest.mark.parametrize("batch_size", [1, 2])
+ @pytest.mark.parametrize("nheads", [8, 16])
+ @pytest.mark.parametrize("seqlen", [128, 256])
+ @pytest.mark.parametrize("headdim, rotary_dim", [(64, 32), (128, 64), (64, 30)])
+ @pytest.mark.parametrize("qk_dim", [3, 4])
+ @pytest.mark.parametrize(
+     "dtype, atol, rtol",
+     [
+         (torch.float32, 1e-5, 1e-5),
+         pytest.param(
+             torch.bfloat16,
+             1e-1,
+             1e-5,
+             marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+         ),
+     ],
+ )
+ @pytest.mark.parametrize("conj", [False, True])
+ @pytest.mark.flaky(max_runs=2, min_passes=1)
+ def test_rotary_equivalence(batch_size, nheads, seqlen, headdim, rotary_dim, qk_dim, dtype, atol, rtol, conj):
+     device = infer_device()
+     if device is None:
+         pytest.skip("No suitable device found for testing")
+
+     if qk_dim == 4:
+         q_shape = (batch_size, seqlen, nheads, headdim)
+         cos_sin_shape = (seqlen, 1, rotary_dim)
+     elif qk_dim == 3:
+         q_shape = (batch_size * seqlen, nheads, headdim)
+         cos_sin_shape = (batch_size * seqlen, 1, rotary_dim)
+
+     q_orig = torch.randn(q_shape, device=device, dtype=dtype)
+     k_orig = torch.randn(q_shape, device=device, dtype=dtype)
+     cos = torch.randn(cos_sin_shape, device=device, dtype=dtype)
+     sin = torch.randn(cos_sin_shape, device=device, dtype=dtype)
+
+     q_kernel, k_kernel = q_orig.clone(), k_orig.clone()
+     q_torch, k_torch = q_orig.clone(), k_orig.clone()
+
+     q_torch_out, k_torch_out = apply_rotary_torch_wrapper(q_torch, k_torch, cos, sin, conj)
+     apply_rotary_kernel_wrapper(q_kernel, k_kernel, cos, sin, conj)
+
+     # Verify that the rotated Q and K match the reference implementation
+     try:
+         assert torch.allclose(q_torch_out, q_kernel, atol=atol, rtol=rtol), "Rotary transformation results for Q do not match"
+     except AssertionError:
+         diff_q = torch.abs(q_torch_out - q_kernel)
+         max_diff_q = torch.max(diff_q)
+         print(f"Max difference for Q: {max_diff_q}")
+         raise
+     try:
+         assert torch.allclose(k_torch_out, k_kernel, atol=atol, rtol=rtol), "Rotary transformation results for K do not match"
+     except AssertionError:
+         diff_k = torch.abs(k_torch_out - k_kernel)
+         max_diff_k = torch.max(diff_k)
+         print(f"Max difference for K: {max_diff_k}")
+         raise
+
+     # Verify that the non-rotated part of Q and K is unchanged
+     if (2 * rotary_dim) < headdim:
+         assert torch.equal(
+             q_kernel[..., 2 * rotary_dim:], q_orig[..., 2 * rotary_dim:]
+         ), "Non-rotated part of Q should be unchanged"
+         assert torch.equal(
+             k_kernel[..., 2 * rotary_dim:], k_orig[..., 2 * rotary_dim:]
+         ), "Non-rotated part of K should be unchanged"
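One way to read the reference implementation above (a sketch, not part of the diff): the pairwise rotation is complex multiplication of x1 + i*x2 by cos + i*sin, and the conjugate mode multiplies by cos - i*sin instead. A quick check of that identity for the non-conjugate case:

    import torch

    # Illustration only: the reference rotation equals complex multiplication.
    x1, x2 = torch.randn(8), torch.randn(8)
    theta = torch.randn(8)
    cos, sin = torch.cos(theta), torch.sin(theta)

    z = torch.complex(x1, x2) * torch.complex(cos, sin)
    out1, out2 = x1 * cos - x2 * sin, x1 * sin + x2 * cos
    assert torch.allclose(z.real, out1) and torch.allclose(z.imag, out2)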
tests/utils.py ADDED
@@ -0,0 +1,23 @@
+ import torch
+
+
+ def infer_device():
+     """
+     Return the name of the first available accelerator device.
+     """
+     if torch.cuda.is_available():  # Works for both Nvidia and AMD GPUs
+         return "cuda"
+     elif torch.xpu.is_available():
+         return "xpu"
+     else:
+         return None
+
+
+ def supports_bfloat16():
+     device = infer_device()
+     if device == "cuda":
+         return torch.cuda.get_device_capability() >= (8, 0)  # Ampere and newer
+     elif device == "xpu":
+         return True
+     else:
+         return False