# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch
from ._ops import ops
from typing import List, Tuple, Union
from torch import Tensor
from torch.optim.optimizer import Optimizer, ParamsT
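
# AdamATan2 follows Adam/AdamW but, per the atan2 formulation its name and
# kernel refer to, replaces the usual eps-stabilized update direction
# exp_avg_hat / (sqrt(exp_avg_sq_hat) + eps) with an atan2-based one, so
# there is no `eps` hyperparameter at all. The per-parameter math runs in
# the fused CUDA kernel exposed as `ops.adam_atan2_cuda_impl_`; a hedged
# pure-PyTorch sketch of that update appears further down for reference.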
class AdamATan2(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        weight_decay: float = 1e-2,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)
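
    # Gathers the parameters that have gradients, together with their
    # optimizer state, into flat lists so a whole param group can be handed
    # to the fused kernel in one call.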
    def _init_group(
        self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
    ):
        for p in group["params"]:
            if p.grad is None:
                continue
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("AdamATan2 does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]
            # State initialization
            if len(state) == 0:
                # Host `step` on the same device as the parameter so the
                # fused kernel can consume it without a host/device sync.
                state["step"] = torch.zeros((), dtype=torch.float32, device=p.device)
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )

            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])
            state_steps.append(state["step"])
    @torch.no_grad()
    def step(self):
        """Perform a single optimization step."""
        self._cuda_graph_capture_health_check()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group["betas"]

            self._init_group(
                group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps
            )

            _adam_atan2(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                state_steps,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
            )
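
# Illustrative only: a pure-PyTorch sketch of the update the fused kernel is
# assumed to perform (decoupled AdamW-style weight decay, bias-corrected
# moments, and an atan2 step in place of the eps-stabilized division). The
# kernel's exact math may differ, e.g. in scaling constants; this helper is
# not called anywhere and exists purely as documentation.
def _adam_atan2_reference(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
) -> None:
    for p, g, m, v, t in zip(params, grads, exp_avgs, exp_avg_sqs, state_steps):
        t += 1
        step = t.item()
        # Decoupled (AdamW-style) weight decay.
        p.mul_(1.0 - lr * weight_decay)
        # Update biased first and second moment estimates.
        m.lerp_(g, 1.0 - beta1)
        v.mul_(beta2).addcmul_(g, g, value=1.0 - beta2)
        # Bias correction.
        m_hat = m / (1.0 - beta1**step)
        v_hat = v / (1.0 - beta2**step)
        # atan2 replaces m_hat / (sqrt(v_hat) + eps); no eps is needed.
        p.add_(torch.atan2(m_hat, torch.sqrt(v_hat)), alpha=-lr)
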
def _adam_atan2(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
) -> None:
    if not params:
        return

    # The fused kernel only supports a scalar (Python float) lr.
    assert not isinstance(lr, Tensor)

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, state_steps]
    )
    for (device, _), (
        (
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_state_steps,
        ),
        _,
    ) in grouped_tensors.items():
        # Increment the per-parameter step counters in one fused call, then
        # hand the whole device/dtype group to the CUDA kernel.
        torch._foreach_add_(device_state_steps, 1)
        ops.adam_atan2_cuda_impl_(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_state_steps,
            lr,
            beta1,
            beta2,
            weight_decay,
        )
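
# Example usage (a minimal sketch; assumes the parameters live on a CUDA
# device, since the update is dispatched to a CUDA kernel):
#
#   model = torch.nn.Linear(10, 10).cuda()
#   opt = AdamATan2(model.parameters(), lr=3e-4, weight_decay=1e-2)
#   loss = model(torch.randn(2, 10, device="cuda")).sum()
#   loss.backward()
#   opt.step()
#   opt.zero_grad()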