import math
from typing import Tuple

import pytest
import torch

from adam_atan2 import _adam_atan2


def _adam_atan2_reference_impl(
    param: torch.Tensor,
    grad: torch.Tensor,
    exp_avg: torch.Tensor,
    exp_avg_sq: torch.Tensor,
    # Per-step scalars, precomputed by the caller.
    step_size: float,
    wd_step_size: float,
    bias_correction2_sqrt: float,
    beta1: float,
    beta2: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    orig_dtype = param.dtype

    # Cast to the math dtype (fp32); fp64 inputs keep full precision.
    if orig_dtype != torch.float64:
        param = param.to(torch.float32)
        grad = grad.to(torch.float32)
        exp_avg = exp_avg.to(torch.float32)
        exp_avg_sq = exp_avg_sq.to(torch.float32)

    # Update math, following PyTorch's AdamW implementation:
    # https://github.com/pytorch/pytorch/blob/main/torch/optim/adamw.py
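    # The only change relative to AdamW is the final parameter update: AdamW
    # uses exp_avg / (sqrt(exp_avg_sq) / bias_correction2_sqrt + eps), while
    # the atan2 form below removes the eps hyperparameter and bounds each
    # element's update magnitude by (pi / 2) * step_size, since the second
    # atan2 argument (denom) is non-negative.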
    param.mul_(1 - wd_step_size)

    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    denom = exp_avg_sq.sqrt() / bias_correction2_sqrt
    param.add_(torch.atan2(exp_avg, denom), alpha=-step_size)

    return param.to(orig_dtype), exp_avg.to(orig_dtype), exp_avg_sq.to(orig_dtype)


@pytest.mark.parametrize("params_shape", [(1,), (4096,), (4096, 14336)])
@pytest.mark.parametrize("lr", [1e-3, 1e-4, 5e-4])
@pytest.mark.parametrize(
    "dtype", [torch.float64, torch.float32, torch.float16, torch.bfloat16]
)
def test_adam_atan2_backend(
    params_shape,
    lr,
    dtype,
    weight_decay=0.1,
    beta1=0.9,
    beta2=0.95,
    init_std=0.02,
    grad_std=0.001,
    steps=100,
    atol={
        torch.float64: 1e-15,
        torch.float32: 1e-6,
        torch.float16: 0.002,
        torch.bfloat16: 0.005,
    },
):
    torch.random.manual_seed(0)

    # Reference
    ref_param = torch.empty(params_shape, dtype=dtype, device="cuda").normal_(
        std=init_std
    )
    ref_exp_avg = torch.zeros_like(ref_param, dtype=dtype)
    ref_exp_avg_sq = torch.zeros_like(ref_param, dtype=dtype)
    ref_steps = 0

    # Test
    test_param = ref_param.clone()
    test_exp_avg = ref_exp_avg.clone()
    test_exp_avg_sq = ref_exp_avg_sq.clone()
    test_steps = torch.zeros((), dtype=torch.float32, device="cuda")
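    # The backend's step counter lives on-device as a scalar tensor;
    # _adam_atan2 is expected to advance it in place on every call
    # (verified against ref_steps at the end of the test).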

    for _ in range(steps):
        grad = torch.empty(params_shape, dtype=dtype, device="cuda").normal_(
            std=grad_std
        )

        # Reference
        ref_steps += 1
        ref_param, ref_exp_avg, ref_exp_avg_sq = _adam_atan2_reference_impl(
            ref_param,
            grad,
            ref_exp_avg,
            ref_exp_avg_sq,
            step_size=lr / (1 - beta1**ref_steps),
            wd_step_size=lr * weight_decay,
            bias_correction2_sqrt=math.sqrt(1 - beta2**ref_steps),
            beta1=beta1,
            beta2=beta2,
        )

        # Test
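        # _adam_atan2 takes lists of tensors (a single parameter group entry
        # here) and updates param, exp_avg, exp_avg_sq and the step counter
        # in place, so no return value is captured.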
        _adam_atan2(
            [test_param],
            [grad],
            [test_exp_avg],
            [test_exp_avg_sq],
            [test_steps],
            beta1,
            beta2,
            lr,
            weight_decay,
        )

    # Check
    assert torch.allclose(test_param, ref_param, rtol=0, atol=atol[dtype])
    assert torch.allclose(test_exp_avg, ref_exp_avg, rtol=0, atol=atol[dtype])
    assert torch.allclose(test_exp_avg_sq, ref_exp_avg_sq, rtol=0, atol=atol[dtype])
    assert test_steps.item() == ref_steps