from pathlib import Path

import pytest
import torch
from kernels import get_local_kernel

from tests.utils import infer_device, supports_bfloat16

repo_path = Path(__file__).parent.parent


def apply_rotary_torch(x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool = False):
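    """Rotate the pair (x1, x2), treated as real/imaginary parts, by the angle whose
    cosine and sine are given; conj=True applies the conjugate (inverse) rotation."""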
    assert x1.shape == x2.shape, "x1 and x2 must have the same shape"

    if not conj:
        out1 = x1 * cos - x2 * sin
        out2 = x1 * sin + x2 * cos
    else:
        out1 = x1 * cos + x2 * sin
        out2 = -x1 * sin + x2 * cos
    return out1, out2


def apply_rotary_torch_wrapper(q, k, cos, sin, conj: bool = False):
    """Apply the PyTorch reference rotation to q and k, rotating the first
    2 * rotary_dim channels of the head dimension and passing the rest through."""
    rotary_dim = cos.shape[-1]

    q1 = q[..., :rotary_dim]
    q2 = q[..., rotary_dim : 2 * rotary_dim]
    q_out_1, q_out_2 = apply_rotary_torch(q1, q2, cos, sin, conj)
    q_out = torch.cat([q_out_1, q_out_2, q[..., 2 * rotary_dim :]], dim=-1)

    k1 = k[..., :rotary_dim]
    k2 = k[..., rotary_dim : 2 * rotary_dim]
    k_out_1, k_out_2 = apply_rotary_torch(k1, k2, cos, sin, conj)
    k_out = torch.cat([k_out_1, k_out_2, k[..., 2 * rotary_dim :]], dim=-1)

    return q_out, k_out


def apply_rotary_kernel_wrapper(q, k, cos, sin, conj: bool = False):
    """Apply the locally built rotary kernel to q and k. Unlike the reference
    wrapper, this mutates q and k in place and returns nothing."""
    rotary = get_local_kernel(repo_path=repo_path, package_name="rotary")
    rotary_dim = cos.shape[-1]

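    # apply_rotary takes the output tensors as its fifth and sixth arguments;
    # passing the q1/q2 (and k1/k2) views back in applies the rotation in place.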
    q1 = q[..., :rotary_dim]
    q2 = q[..., rotary_dim : 2 * rotary_dim]
    rotary.apply_rotary(q1, q2, cos, sin, q1, q2, conj)

    k1 = k[..., :rotary_dim]
    k2 = k[..., rotary_dim : 2 * rotary_dim]
    rotary.apply_rotary(k1, k2, cos, sin, k1, k2, conj)


@pytest.mark.parametrize("batch_size", [1, 2]) |
|
@pytest.mark.parametrize("nheads", [8, 16]) |
|
@pytest.mark.parametrize("seqlen", [128, 256]) |
|
@pytest.mark.parametrize("headdim, rotary_dim", [(64, 32), (128, 64), (64, 30)]) |
|
@pytest.mark.parametrize("qk_dim", [3, 4]) |
|
@pytest.mark.parametrize( |
|
"dtype, atol, rtol", |
|
[ |
|
(torch.float32, 1e-5, 1e-5), |
|
pytest.param( |
|
torch.bfloat16, |
|
1e-1, |
|
1e-5, |
|
marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), |
|
), |
|
], |
|
) |
|
@pytest.mark.parametrize("conj", [False, True]) |
|
@pytest.mark.flaky(max_runs=2, min_passes=1) |
|
def test_rotary_equivalence(batch_size, nheads, seqlen, headdim, rotary_dim, qk_dim, dtype, atol, rtol, conj): |
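    """Check that the local rotary kernel matches the PyTorch reference within
    tolerance and leaves the non-rotated tail of the head dimension untouched."""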
    device = infer_device()
    if device is None:
        pytest.skip("No suitable device found for testing")

    if qk_dim == 4:
        # (batch, seqlen, nheads, headdim) layout; cos/sin broadcast over batch and heads.
        q_shape = (batch_size, seqlen, nheads, headdim)
        cos_sin_shape = (seqlen, 1, rotary_dim)
    elif qk_dim == 3:
        # Packed (total_tokens, nheads, headdim) layout.
        q_shape = (batch_size * seqlen, nheads, headdim)
        cos_sin_shape = (batch_size * seqlen, 1, rotary_dim)
|
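    # Random values stand in for real cos/sin tables: the equivalence check only
    # requires that both implementations see identical inputs.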
    q_orig = torch.randn(q_shape, device=device, dtype=dtype)
    k_orig = torch.randn(q_shape, device=device, dtype=dtype)
    cos = torch.randn(cos_sin_shape, device=device, dtype=dtype)
    sin = torch.randn(cos_sin_shape, device=device, dtype=dtype)

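    # Separate clones: the kernel wrapper mutates its inputs in place, while the
    # reference wrapper returns new tensors.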
    q_kernel, k_kernel = q_orig.clone(), k_orig.clone()
    q_torch, k_torch = q_orig.clone(), k_orig.clone()

    q_torch_out, k_torch_out = apply_rotary_torch_wrapper(q_torch, k_torch, cos, sin, conj)
    apply_rotary_kernel_wrapper(q_kernel, k_kernel, cos, sin, conj)

    # Surface the maximum elementwise difference in the failure message.
    max_diff_q = torch.max(torch.abs(q_torch_out - q_kernel))
    assert torch.allclose(
        q_torch_out, q_kernel, atol=atol, rtol=rtol
    ), f"Rotary transformation results for Q do not match (max diff: {max_diff_q})"

    max_diff_k = torch.max(torch.abs(k_torch_out - k_kernel))
    assert torch.allclose(
        k_torch_out, k_kernel, atol=atol, rtol=rtol
    ), f"Rotary transformation results for K do not match (max diff: {max_diff_k})"

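    # When the rotation covers only part of the head dimension, the tail past
    # 2 * rotary_dim must come through bit-identical.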
    if 2 * rotary_dim < headdim:
        assert torch.equal(
            q_kernel[..., 2 * rotary_dim :], q_orig[..., 2 * rotary_dim :]
        ), "Non-rotated part of Q should be unchanged"
        assert torch.equal(
            k_kernel[..., 2 * rotary_dim :], k_orig[..., 2 * rotary_dim :]
        ), "Non-rotated part of K should be unchanged"