from typing import Tuple

import torch
import math
import pytest

from adam_atan2 import _adam_atan2
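
# Compares the fused `_adam_atan2` backend against a step-by-step pure-PyTorch
# reference implementation of the AdamATan2 update, sweeping parameter shape,
# learning rate, and dtype.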


def _adam_atan2_reference_impl(
    param: torch.Tensor,
    grad: torch.Tensor,
    exp_avg: torch.Tensor,
    exp_avg_sq: torch.Tensor,
    # Constant
    step_size: float,
    wd_step_size: float,
    bias_correction2_sqrt: float,
    beta1: float,
    beta2: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
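    # One AdamATan2 step on a single tensor. The math mirrors the AdamW recipe
    # (decoupled weight decay, EMAs of grad and grad**2), except the final
    # update uses atan2(exp_avg, denom) instead of exp_avg / (denom + eps),
    # so no eps hyperparameter is needed.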
    orig_dtype = param.dtype

    # Cast to math type, fp32.
    if orig_dtype != torch.float64:
        param = param.to(torch.float32)
        grad = grad.to(torch.float32)
        exp_avg = exp_avg.to(torch.float32)
        exp_avg_sq = exp_avg_sq.to(torch.float32)

    # Math
    # Reference implementation (PyTorch):
    # https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py
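    # Step by step: decoupled weight decay shrinks the param, lerp_ updates the
    # first-moment EMA, mul_/addcmul_ update the second-moment EMA, and the
    # final update is atan2(exp_avg, sqrt(exp_avg_sq) / bias_correction2_sqrt)
    # scaled by step_size (which already folds in the first-moment bias correction).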
    param.mul_(1 - wd_step_size)
    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = exp_avg_sq.sqrt() / bias_correction2_sqrt
    param.add_(torch.atan2(exp_avg, denom), alpha=-step_size)

    return param.to(orig_dtype), exp_avg.to(orig_dtype), exp_avg_sq.to(orig_dtype)


@pytest.mark.parametrize("params_shape", [(1,), (4096,), (4096, 14336)])
@pytest.mark.parametrize("lr", [1e-3, 1e-4, 5e-4])
@pytest.mark.parametrize(
    "dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16]
)
def test_adam_atan2_backend(
    params_shape,
    lr,
    dtype,
    weight_decay=0.1,
    beta1=0.9,
    beta2=0.95,
    init_std=0.02,
    grad_std=0.001,
    steps=100,
    atol={
        torch.float64: 1e-15,
        torch.float32: 1e-6,
        torch.float16: 0.002,
        torch.bfloat16: 0.005,
    },
):
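    # Run `steps` optimizer steps with fresh random gradients each iteration,
    # updating one copy of the state with the PyTorch reference above and one
    # with the fused backend, then compare them with a per-dtype absolute
    # tolerance after every step.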
    torch.random.manual_seed(0)

    # Reference
    ref_param = torch.empty(params_shape, dtype=dtype, device="cuda").normal_(
        std=init_std
    )
    ref_exp_avg = torch.zeros_like(ref_param, dtype=dtype)
    ref_exp_avg_sq = torch.zeros_like(ref_param, dtype=dtype)
    ref_steps = 0

    # Test
    test_param = ref_param.clone()
    test_exp_avg = ref_exp_avg.clone()
    test_exp_avg_sq = ref_exp_avg_sq.clone()
    test_steps = torch.zeros((), dtype=torch.float32, device="cuda")
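    # The test state starts as an exact copy of the reference state; the step
    # counter is a scalar CUDA tensor because the fused backend is expected to
    # read and increment it on-device (checked against ref_steps below).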

    for _ in range(steps):
        grad = torch.empty(params_shape, dtype=dtype, device="cuda").normal_(
            std=grad_std
        )

        # Reference
        ref_steps += 1
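        # Bias corrections are recomputed from the step count each iteration:
        # step_size folds in 1 / (1 - beta1**t) and bias_correction2_sqrt is
        # sqrt(1 - beta2**t), as in standard Adam/AdamW.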
        ref_param, ref_exp_avg, ref_exp_avg_sq = _adam_atan2_reference_impl(
            ref_param,
            grad,
            ref_exp_avg,
            ref_exp_avg_sq,
            step_size=lr / (1 - beta1**ref_steps),
            wd_step_size=lr * weight_decay,
            bias_correction2_sqrt=math.sqrt(1 - beta2**ref_steps),
            beta1=beta1,
            beta2=beta2,
        )

        # Test
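        # The fused backend takes foreach-style lists of tensors plus scalar
        # hyperparameters and updates params, moments, and step counters
        # in place (no return value is used here).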
        _adam_atan2(
            [test_param],
            [grad],
            [test_exp_avg],
            [test_exp_avg_sq],
            [test_steps],
            beta1,
            beta2,
            lr,
            weight_decay,
        )

        # Check
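        # Compare after every step with rtol=0, so only the per-dtype absolute
        # tolerance applies; the step counter must match the Python-side count.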
        assert torch.allclose(test_param, ref_param, rtol=0, atol=atol[dtype])
        assert torch.allclose(test_exp_avg, ref_exp_avg, rtol=0, atol=atol[dtype])
        assert torch.allclose(test_exp_avg_sq, ref_exp_avg_sq, rtol=0, atol=atol[dtype])
        assert test_steps.item() == ref_steps