File size: 4,701 Bytes
2e98b65 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import pytest
import torch
from tests.utils import infer_device, supports_bfloat16
from kernels import get_local_kernel
from pathlib import Path
# from transformers.trainer_utils import set_seed
# set_seed(42)
# Set the local repo path, relative path
repo_path = Path(__file__).parent.parent
def apply_rotary_torch(x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool = False):
    """Pure-PyTorch reference rotation of the pair (x1, x2) by angle (cos, sin).

    When ``conj`` is True the conjugate (inverse) rotation is applied instead.
    Returns the rotated pair ``(out1, out2)``; inputs are not modified.
    """
    assert x1.shape == x2.shape, "x1 and x2 must have the same shape"
    # Fold the forward/conjugate distinction into a single sign on sin:
    # forward uses +sin, conjugate uses -sin.
    sign = -1.0 if conj else 1.0
    out1 = x1 * cos - sign * (x2 * sin)
    out2 = sign * (x1 * sin) + x2 * cos
    return out1, out2
def apply_rotary_torch_wrapper(q, k, cos, sin, conj: bool = False):
    """Apply the reference rotary embedding to both Q and K.

    Only the first ``2 * rotary_dim`` channels of the head dimension are
    rotated; any remaining channels are passed through unchanged. Inputs are
    not modified; new tensors are returned as ``(q_out, k_out)``.
    """
    rotary_dim = cos.shape[-1]

    def _rotate(t):
        # Split the head dim into the two rotated halves plus the untouched tail.
        first = t[..., :rotary_dim]
        second = t[..., rotary_dim : 2 * rotary_dim]
        if conj:
            r1 = first * cos + second * sin
            r2 = second * cos - first * sin
        else:
            r1 = first * cos - second * sin
            r2 = first * sin + second * cos
        return torch.cat([r1, r2, t[..., 2 * rotary_dim:]], dim=-1)

    return _rotate(q), _rotate(k)
def apply_rotary_kernel_wrapper(q, k, cos, sin, conj: bool = False):
    """Apply the local rotary kernel to Q and K, rotating them in place.

    The kernel's output pointers are the same views as its inputs, so the
    first ``2 * rotary_dim`` channels of ``q`` and ``k`` are overwritten;
    the remaining channels are untouched. Returns ``None``.
    """
    rotary = get_local_kernel(repo_path=repo_path, package_name="rotary")
    rotary_dim = cos.shape[-1]
    for tensor in (q, k):
        half_a = tensor[..., :rotary_dim]
        half_b = tensor[..., rotary_dim : 2 * rotary_dim]
        # Output views alias the input views -> in-place update of `tensor`.
        rotary.apply_rotary(half_a, half_b, cos, sin, half_a, half_b, conj)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("nheads", [8, 16])
@pytest.mark.parametrize("seqlen", [128, 256])
@pytest.mark.parametrize("headdim, rotary_dim", [(64, 32), (128, 64), (64, 30)])
@pytest.mark.parametrize("qk_dim", [3, 4])
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-5, 1e-5),
        pytest.param(
            torch.bfloat16,
            1e-1,
            1e-5,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
    ],
)
@pytest.mark.parametrize("conj", [False, True])
@pytest.mark.flaky(max_runs=2, min_passes=1)
def test_rotary_equivalence(batch_size, nheads, seqlen, headdim, rotary_dim, qk_dim, dtype, atol, rtol, conj):
    """Verify the local rotary kernel matches the pure-PyTorch reference.

    Runs both implementations on identical random inputs across layouts,
    dtypes, and rotation directions, and also checks that the kernel leaves
    the non-rotated tail of the head dimension untouched.
    """
    device = infer_device()
    if device is None:
        pytest.skip("No suitable device found for testing")

    # Two supported layouts: 4-D batched (b, s, h, d) with per-position
    # cos/sin, or 3-D packed (b*s, h, d) with per-token cos/sin.
    if qk_dim == 4:
        q_shape = (batch_size, seqlen, nheads, headdim)
        cos_sin_shape = (seqlen, 1, rotary_dim)
    elif qk_dim == 3:
        q_shape = (batch_size * seqlen, nheads, headdim)
        cos_sin_shape = (batch_size * seqlen, 1, rotary_dim)

    q_orig = torch.randn(q_shape, device=device, dtype=dtype)
    k_orig = torch.randn(q_shape, device=device, dtype=dtype)
    cos = torch.randn(cos_sin_shape, device=device, dtype=dtype)
    sin = torch.randn(cos_sin_shape, device=device, dtype=dtype)

    # Independent copies: the kernel rotates in place, the reference does not.
    q_kernel, k_kernel = q_orig.clone(), k_orig.clone()
    q_ref_out, k_ref_out = apply_rotary_torch_wrapper(q_orig.clone(), k_orig.clone(), cos, sin, conj)
    apply_rotary_kernel_wrapper(q_kernel, k_kernel, cos, sin, conj)

    def _check(name, expected, actual):
        # On mismatch, report the worst-case element before failing to ease debugging.
        if not torch.allclose(expected, actual, atol=atol, rtol=rtol):
            print(f"Max difference for {name}: {torch.max(torch.abs(expected - actual))}")
            raise AssertionError(f"Rotary transformation results for {name} do not match")

    # verify the rotation results of Q and K are consistent
    _check("Q", q_ref_out, q_kernel)
    _check("K", k_ref_out, k_kernel)

    # verify the non-rotated part of Q and K remains unchanged
    if (2 * rotary_dim) < headdim:
        assert torch.equal(
            q_kernel[..., 2 * rotary_dim:], q_orig[..., 2 * rotary_dim:]
        ), "Non-rotated part of Q should be unchanged"
        assert torch.equal(
            k_kernel[..., 2 * rotary_dim:], k_orig[..., 2 * rotary_dim:]
        ), "Non-rotated part of K should be unchanged"
|