Add triton support
#2
by YangKai0616 · opened
- build/torch-universal/rotary/__init__.py +3 -0
- build/torch-universal/rotary/_ops.py +8 -0
- build/torch-universal/rotary/triton_rotary.py +144 -0
- tests/__init__.py +0 -0
- tests/test_rotary.py +126 -0
- tests/utils.py +23 -0
build/torch-universal/rotary/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .triton_rotary import apply_rotary
+
+__all__ = ["apply_rotary"]
build/torch-universal/rotary/_ops.py
ADDED
@@ -0,0 +1,8 @@
+import torch
+ops = torch.ops._rotary_202507301320
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_rotary_202507301320::{op_name}"
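For orientation, add_op_namespace_prefix only builds the fully-qualified op name for this kernel's namespace. A minimal sketch of how it could be used with torch.library follows; the import path, schema, and registration call are illustrative assumptions, not part of this PR:

import torch
from rotary._ops import add_op_namespace_prefix

qualname = add_op_namespace_prefix("apply_rotary")  # "_rotary_202507301320::apply_rotary"

# Illustrative only: registering a schema under this name would make the op
# reachable as torch.ops._rotary_202507301320.apply_rotary once an impl is attached.
torch.library.define(
    qualname,
    "(Tensor x1, Tensor x2, Tensor cos, Tensor sin, Tensor(a!) out1, Tensor(b!) out2, bool conj) -> ()",
)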
build/torch-universal/rotary/triton_rotary.py
ADDED
@@ -0,0 +1,144 @@
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _rotary_kernel(
+    X1_ptr, X2_ptr, COS_ptr, SIN_ptr, OUT1_ptr, OUT2_ptr,
+    stride_x1_b, stride_x1_s, stride_x1_h, stride_x1_d,
+    stride_x2_b, stride_x2_s, stride_x2_h, stride_x2_d,
+    stride_cos_s, stride_cos_d,
+    stride_sin_s, stride_sin_d,
+    stride_o1_b, stride_o1_s, stride_o1_h, stride_o1_d,
+    stride_o2_b, stride_o2_s, stride_o2_h, stride_o2_d,
+    seq_len, num_heads, headdim,
+    IS_CONJ: tl.constexpr,
+    BLOCK_SIZE_D: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+):
+    """
+    Triton kernel for applying rotary position embedding.
+    """
+    # Get program IDs
+    pid_b = tl.program_id(0)
+    pid_s_block = tl.program_id(1)
+    pid_h_block = tl.program_id(2)
+
+    # Create block offsets
+    offs_d = tl.arange(0, BLOCK_SIZE_D)
+    offs_s = pid_s_block * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_h = pid_h_block * BLOCK_H + tl.arange(0, BLOCK_H)
+
+    # Pointers for x1, x2, out1, out2
+    x1_ptrs = X1_ptr + pid_b * stride_x1_b + \
+        (offs_s[:, None, None] * stride_x1_s + \
+         offs_h[None, :, None] * stride_x1_h + \
+         offs_d[None, None, :] * stride_x1_d)
+
+    x2_ptrs = X2_ptr + pid_b * stride_x2_b + \
+        (offs_s[:, None, None] * stride_x2_s + \
+         offs_h[None, :, None] * stride_x2_h + \
+         offs_d[None, None, :] * stride_x2_d)
+
+    o1_ptrs = OUT1_ptr + pid_b * stride_o1_b + \
+        (offs_s[:, None, None] * stride_o1_s + \
+         offs_h[None, :, None] * stride_o1_h + \
+         offs_d[None, None, :] * stride_o1_d)
+
+    o2_ptrs = OUT2_ptr + pid_b * stride_o2_b + \
+        (offs_s[:, None, None] * stride_o2_s + \
+         offs_h[None, :, None] * stride_o2_h + \
+         offs_d[None, None, :] * stride_o2_d)
+
+    # Pointers for cos, sin
+    cos_ptrs = COS_ptr + \
+        (offs_s[:, None, None] * stride_cos_s + \
+         offs_d[None, None, :] * stride_cos_d)
+    sin_ptrs = SIN_ptr + \
+        (offs_s[:, None, None] * stride_sin_s + \
+         offs_d[None, None, :] * stride_sin_d)
+
+    # Mask the last blocks when dimensions are not multiples of the block sizes
+    mask_s = offs_s < seq_len
+    mask_h = offs_h < num_heads
+    mask_d = offs_d < headdim
+
+    # Combined mask for all tensors: [BLOCK_M, BLOCK_H, BLOCK_SIZE_D]
+    mask = mask_s[:, None, None] & mask_h[None, :, None] & mask_d[None, None, :]
+    mask_cs = mask_s[:, None, None] & mask_d[None, None, :]
+
+    # Load data
+    x1 = tl.load(x1_ptrs, mask=mask, other=0.0).to(tl.float32)
+    x2 = tl.load(x2_ptrs, mask=mask, other=0.0).to(tl.float32)
+    cos = tl.load(cos_ptrs, mask=mask_cs, other=0.0).to(tl.float32)
+    sin = tl.load(sin_ptrs, mask=mask_cs, other=0.0).to(tl.float32)
+
+    # Perform rotary transformation
+    if IS_CONJ:
+        out1 = x1 * cos + x2 * sin
+        out2 = -x1 * sin + x2 * cos
+    else:
+        out1 = x1 * cos - x2 * sin
+        out2 = x1 * sin + x2 * cos
+
+    # Store results
+    tl.store(o1_ptrs, out1, mask=mask)
+    tl.store(o2_ptrs, out2, mask=mask)
+
+
+def apply_rotary(
+    x1: torch.Tensor,
+    x2: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    out1: torch.Tensor,
+    out2: torch.Tensor,
+    conj: bool = False
+):
+    """
+    Applies rotary position embedding to the input tensors.
+
+    Args:
+        x1, x2: Input tensors. Shape [batch_size, seq_len, num_heads, headdim] or [num_tokens, num_heads, headdim]
+        cos, sin: Tensors with cosine and sine values. Shape [seq_len, 1, rotary_dim] (or [num_tokens, 1, rotary_dim] for 3D inputs)
+        out1, out2: Output tensors. Can be the same as x1, x2 for in-place operation.
+        conj: If True, applies the conjugate (inverse) transformation.
+    """
+    # Shape checks
+    assert x1.shape == x2.shape and out1.shape == out2.shape and x1.shape == out1.shape
+    assert cos.shape == sin.shape
+    assert cos.dim() == 3
+    assert x1.device == x2.device == cos.device == sin.device == out1.device == out2.device
+
+    # Reshape to 4D if necessary
+    if x1.dim() == 3:  # (num_tokens, num_heads, headdim)
+        x1, x2 = x1.unsqueeze(0), x2.unsqueeze(0)
+        out1, out2 = out1.unsqueeze(0), out2.unsqueeze(0)
+    elif x1.dim() != 4:
+        raise ValueError("Input tensors must be 3D or 4D")
+
+    batch_size, seq_len, num_heads, headdim = x1.shape
+
+    # Triton grid
+    BLOCK_M = 8 if headdim <= 128 else 4
+    BLOCK_H = 2
+    grid = (batch_size, triton.cdiv(seq_len, BLOCK_M), triton.cdiv(num_heads, BLOCK_H))
+
+    # Use the smallest power of 2 that is >= headdim as BLOCK_SIZE_D
+    BLOCK_SIZE_D = triton.next_power_of_2(headdim)
+
+    _rotary_kernel[grid](
+        x1, x2, cos, sin, out1, out2,
+        x1.stride(0), x1.stride(1), x1.stride(2), x1.stride(3),
+        x2.stride(0), x2.stride(1), x2.stride(2), x2.stride(3),
+        cos.stride(0), cos.stride(2),
+        sin.stride(0), sin.stride(2),
+        out1.stride(0), out1.stride(1), out1.stride(2), out1.stride(3),
+        out2.stride(0), out2.stride(1), out2.stride(2), out2.stride(3),
+        seq_len, num_heads, headdim,
+        IS_CONJ=conj,
+        BLOCK_SIZE_D=BLOCK_SIZE_D,
+        BLOCK_M=BLOCK_M,
+        BLOCK_H=BLOCK_H,
+    )
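For reference, a minimal usage sketch of the new kernel. The device, shapes, and cos/sin values below are illustrative, and it assumes the built package is importable as rotary:

import torch
from rotary import apply_rotary

batch, seqlen, nheads, headdim, rotary_dim = 2, 128, 8, 64, 32
device = "cuda"  # or "xpu"; any device Triton can target

q = torch.randn(batch, seqlen, nheads, headdim, device=device)
# In real use cos/sin are computed from the position angles; random values
# are enough to show the call signature.
cos = torch.randn(seqlen, 1, rotary_dim, device=device)
sin = torch.randn(seqlen, 1, rotary_dim, device=device)

# Rotate the first 2*rotary_dim channels of q in place by aliasing outputs with inputs.
q1 = q[..., :rotary_dim]
q2 = q[..., rotary_dim : 2 * rotary_dim]
apply_rotary(q1, q2, cos, sin, q1, q2, conj=False)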
tests/__init__.py
ADDED
File without changes
tests/test_rotary.py
ADDED
@@ -0,0 +1,126 @@
+import pytest
+import torch
+
+from tests.utils import infer_device, supports_bfloat16
+from kernels import get_local_kernel
+from pathlib import Path
+
+# from transformers.trainer_utils import set_seed
+# set_seed(42)
+
+# Set the local repo path, relative path
+repo_path = Path(__file__).parent.parent
+
+def apply_rotary_torch(x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool = False):
+    assert x1.shape == x2.shape, "x1 and x2 must have the same shape"
+
+    if not conj:
+        out1 = x1 * cos - x2 * sin
+        out2 = x1 * sin + x2 * cos
+    else:
+        out1 = x1 * cos + x2 * sin
+        out2 = -x1 * sin + x2 * cos
+    return out1, out2
+
+
+def apply_rotary_torch_wrapper(q, k, cos, sin, conj: bool = False):
+    """the wrapper for apply_rotary_torch"""
+    rotary_dim = cos.shape[-1]
+
+    # apply rotation encoding to Q
+    q1 = q[..., :rotary_dim]
+    q2 = q[..., rotary_dim : 2 * rotary_dim]
+    q_out_1, q_out_2 = apply_rotary_torch(q1, q2, cos, sin, conj)
+    q_out = torch.cat([q_out_1, q_out_2, q[..., 2 * rotary_dim:]], dim=-1)
+
+    # apply rotation encoding to K
+    k1 = k[..., :rotary_dim]
+    k2 = k[..., rotary_dim : 2 * rotary_dim]
+    k_out_1, k_out_2 = apply_rotary_torch(k1, k2, cos, sin, conj)
+    k_out = torch.cat([k_out_1, k_out_2, k[..., 2 * rotary_dim:]], dim=-1)
+
+    return q_out, k_out
+
+
+def apply_rotary_kernel_wrapper(q, k, cos, sin, conj: bool = False):
+    """the wrapper for apply_rotary_kernel"""
+    rotary = get_local_kernel(repo_path=repo_path, package_name="rotary")
+    rotary_dim = cos.shape[-1]
+
+    # apply rotation encoding to Q
+    q1 = q[..., :rotary_dim]
+    q2 = q[..., rotary_dim : 2 * rotary_dim]
+    rotary.apply_rotary(q1, q2, cos, sin, q1, q2, conj)
+
+    # apply rotation encoding to K
+    k1 = k[..., :rotary_dim]
+    k2 = k[..., rotary_dim : 2 * rotary_dim]
+    rotary.apply_rotary(k1, k2, cos, sin, k1, k2, conj)
+
+
+@pytest.mark.parametrize("batch_size", [1, 2])
+@pytest.mark.parametrize("nheads", [8, 16])
+@pytest.mark.parametrize("seqlen", [128, 256])
+@pytest.mark.parametrize("headdim, rotary_dim", [(64, 32), (128, 64), (64, 30)])
+@pytest.mark.parametrize("qk_dim", [3, 4])
+@pytest.mark.parametrize(
+    "dtype, atol, rtol",
+    [
+        (torch.float32, 1e-5, 1e-5),
+        pytest.param(
+            torch.bfloat16,
+            1e-1,
+            1e-5,
+            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+        ),
+    ],
+)
+@pytest.mark.parametrize("conj", [False, True])
+@pytest.mark.flaky(max_runs=2, min_passes=1)
+def test_rotary_equivalence(batch_size, nheads, seqlen, headdim, rotary_dim, qk_dim, dtype, atol, rtol, conj):
+    device = infer_device()
+    if device is None:
+        pytest.skip("No suitable device found for testing")
+
+    if qk_dim == 4:
+        q_shape = (batch_size, seqlen, nheads, headdim)
+        cos_sin_shape = (seqlen, 1, rotary_dim)
+    elif qk_dim == 3:
+        q_shape = (batch_size * seqlen, nheads, headdim)
+        cos_sin_shape = (batch_size * seqlen, 1, rotary_dim)
+
+    q_orig = torch.randn(q_shape, device=device, dtype=dtype)
+    k_orig = torch.randn(q_shape, device=device, dtype=dtype)
+    cos = torch.randn(cos_sin_shape, device=device, dtype=dtype)
+    sin = torch.randn(cos_sin_shape, device=device, dtype=dtype)
+
+    q_kernel, k_kernel = q_orig.clone(), k_orig.clone()
+    q_torch, k_torch = q_orig.clone(), k_orig.clone()
+
+    q_torch_out, k_torch_out = apply_rotary_torch_wrapper(q_torch, k_torch, cos, sin, conj)
+    apply_rotary_kernel_wrapper(q_kernel, k_kernel, cos, sin, conj)
+
+    # verify the rotation results of Q and K are consistent
+    try:
+        assert torch.allclose(q_torch_out, q_kernel, atol=atol, rtol=rtol), "Rotary transformation results for Q do not match"
+    except AssertionError:
+        diff_q = torch.abs(q_torch_out - q_kernel)
+        max_diff_q = torch.max(diff_q)
+        print(f"Max difference for Q: {max_diff_q}")
+        raise
+    try:
+        assert torch.allclose(k_torch_out, k_kernel, atol=atol, rtol=rtol), "Rotary transformation results for K do not match"
+    except AssertionError:
+        diff_k = torch.abs(k_torch_out - k_kernel)
+        max_diff_k = torch.max(diff_k)
+        print(f"Max difference for K: {max_diff_k}")
+        raise
+
+    # verify the non-rotated part of Q and K remains unchanged
+    if (2 * rotary_dim) < headdim:
+        assert torch.equal(
+            q_kernel[..., 2 * rotary_dim:], q_orig[..., 2 * rotary_dim:]
+        ), "Non-rotated part of Q should be unchanged"
+        assert torch.equal(
+            k_kernel[..., 2 * rotary_dim:], k_orig[..., 2 * rotary_dim:]
+        ), "Non-rotated part of K should be unchanged"
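With the package built, the suite above can be run from the repository root. The command is illustrative; it assumes pytest, the flaky plugin (used by @pytest.mark.flaky), and the kernels package are installed:

python -m pytest tests/test_rotary.py -v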
tests/utils.py
ADDED
@@ -0,0 +1,23 @@
+import torch
+
+
+def infer_device():
+    """
+    Get current device name based on available devices
+    """
+    if torch.cuda.is_available():  # Works for both Nvidia and AMD
+        return "cuda"
+    elif torch.xpu.is_available():
+        return "xpu"
+    else:
+        return None
+
+
+def supports_bfloat16():
+    device = infer_device()
+    if device == "cuda":
+        return torch.cuda.get_device_capability() >= (8, 0)  # Ampere and newer
+    elif device == "xpu":
+        return True
+    else:
+        return False
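One note on infer_device: older PyTorch builds may not expose a torch.xpu module at all, so a slightly more defensive variant could guard the attribute. A sketch of that suggestion (not part of this PR):

import torch

def infer_device_safe():
    """Like infer_device, but tolerates PyTorch builds without torch.xpu."""
    if torch.cuda.is_available():
        return "cuda"
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return "xpu"
    return None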