import logging
from typing import List, Optional, Union

import torch
from torch.optim import Optimizer


class LRScheduler(object):
    """
    Base class for learning rate schedulers where the learning rate depends on both the
    batch and the epoch.
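
    A subclass only needs to override :meth:`get_lr`, which must return one
    learning rate per optimizer param group.  A minimal sketch (``HalfLR`` is
    purely illustrative)::

        class HalfLR(LRScheduler):
            def get_lr(self):
                # halve every base learning rate, ignoring epoch/batch
                return [0.5 * x for x in self.base_lrs]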
    """
|
    def __init__(self, optimizer: Optimizer, verbose: bool = False):
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer
        self.verbose = verbose

        # Record each param group's initial lr as "base_lr" (unless the group
        # already has one); get_lr() computes the actual learning rates from these.
        for group in optimizer.param_groups:
            group.setdefault("base_lr", group["lr"])

        self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]

        self.epoch = 0
        self.batch = 0
|
    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains the ``epoch`` and ``batch`` counters; ``base_lrs`` is not
        included and is preserved across :meth:`load_state_dict`.
        """
        return {
            "epoch": self.epoch,
            "batch": self.batch,
        }
|
    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
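
        Example (a sketch of checkpointing; ``ckpt_path`` is a placeholder)::

            torch.save(scheduler.state_dict(), ckpt_path)
            # ... later, after re-creating the optimizer and scheduler ...
            scheduler.load_state_dict(torch.load(ckpt_path))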
        """
        # Keep this scheduler's base_lrs rather than any value that might be
        # present in the loaded state.
        base_lrs = self.base_lrs
        self.__dict__.update(state_dict)
        self.base_lrs = base_lrs
|
    def get_last_lr(self) -> List[float]:
        """Return the last computed learning rates of the current scheduler.
        Will be a list of floats."""
        return self._last_lr

    def get_lr(self):
        # Compute a list of learning rates (one per param group) from
        # self.epoch, self.batch and self.base_lrs; subclasses must override this.
        raise NotImplementedError
|
    def step_batch(self, batch: Optional[int] = None) -> None:
        # Step the batch index, or just set it if `batch` is given.  The batch
        # index counts from the start of training and is never reset between
        # epochs.
        if batch is not None:
            self.batch = batch
        else:
            self.batch = self.batch + 1
        self._set_lrs()

    def step_epoch(self, epoch: Optional[int] = None):
        # Step the epoch index, or just set it if `epoch` is given, then update
        # the learning rates.
        if epoch is not None:
            self.epoch = epoch
        else:
            self.epoch = self.epoch + 1
        self._set_lrs()
|
    def _set_lrs(self):
        values = self.get_lr()
        assert len(values) == len(self.optimizer.param_groups)

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group["lr"] = lr
            self.print_lr(self.verbose, i, lr)
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, is_verbose, group, lr):
        """Display the current learning rate."""
        if is_verbose:
            logging.warning(
                f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
                f" of group {group} to {lr:.4e}."
            )
|
|
class Eden(LRScheduler):
    """
    Eden scheduler.
    The basic formula (before warmup) is:
      lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
                      (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
    where `warmup` increases linearly from `warmup_start` (default 0.5) to 1 over
    `warmup_batches` batches and then stays constant at 1.
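
    For example, once warmup is complete, at batch == lr_batches and
    epoch == lr_epochs the factor multiplying base_lr is
    (2 ** -0.25) * (2 ** -0.25) = 2 ** -0.5, i.e. about 0.71.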
|
    If you don't have the concept of epochs, or one epoch takes a very long time,
    you can replace the notion of 'epoch' with some measure of the amount of data
    processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to
    some measure representing "quite a lot of data": say, one fifth or one third
    of an entire training run, but it doesn't matter much.  You could also use
    Eden2, which has only the notion of batches.

    We suggest base_lr = 0.04 (passed to the optimizer) if used with ScaledAdam.

    Args:
        optimizer: the optimizer to change the learning rates on
        lr_batches: the number of batches after which we start significantly
              decreasing the learning rate, suggest 5000.
        lr_epochs: the number of epochs after which we start significantly
              decreasing the learning rate, suggest 6 if you plan to do e.g.
              20 to 40 epochs, but may need a smaller number if the dataset is
              huge and you will do few epochs.
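
    Example (a sketch of a typical training loop; `model`, `train_dl`,
    `num_epochs` and `compute_loss` are placeholders)::

        optimizer = ScaledAdam(model.parameters(), lr=0.04)
        scheduler = Eden(optimizer, lr_batches=5000, lr_epochs=6)
        for epoch in range(num_epochs):
            scheduler.step_epoch(epoch)
            for batch in train_dl:
                loss = compute_loss(model, batch)
                loss.backward()
                optimizer.step()
                scheduler.step_batch()
                optimizer.zero_grad()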
    """
|
    def __init__(
        self,
        optimizer: Optimizer,
        lr_batches: Union[int, float],
        lr_epochs: Union[int, float],
        warmup_batches: Union[int, float] = 500.0,
        warmup_start: float = 0.5,
        verbose: bool = False,
    ):
        super(Eden, self).__init__(optimizer, verbose)
        self.lr_batches = lr_batches
        self.lr_epochs = lr_epochs
        self.warmup_batches = warmup_batches

        assert 0.0 <= warmup_start <= 1.0, warmup_start
        self.warmup_start = warmup_start

    def get_lr(self):
        factor = (
            (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
        ) ** -0.25 * (
            ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
        )
        # Linear warmup from warmup_start to 1 over the first warmup_batches batches.
        warmup_factor = (
            1.0
            if self.batch >= self.warmup_batches
            else self.warmup_start
            + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
        )

        return [x * factor * warmup_factor for x in self.base_lrs]
|
|
class FixedLRScheduler(LRScheduler):
    """
    Fixed learning rate scheduler.

    Args:
        optimizer: the optimizer to change the learning rates on
    """

    def __init__(
        self,
        optimizer: Optimizer,
        verbose: bool = False,
    ):
        super(FixedLRScheduler, self).__init__(optimizer, verbose)

    def get_lr(self):
        # The learning rate stays fixed at each param group's base_lr.
        return [x for x in self.base_lrs]
|
|
def _test_eden():
    m = torch.nn.Linear(100, 100)
    from zipvoice.utils.optim import ScaledAdam

    optim = ScaledAdam(m.parameters(), lr=0.03)
    scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)

    for epoch in range(10):
        scheduler.step_epoch(epoch)

        for step in range(20):
            x = torch.randn(200, 100).detach()
            x.requires_grad = True
            y = m(x)
            dy = torch.randn(200, 100).detach()
            f = (y * dy).sum()
            f.backward()

            optim.step()
            scheduler.step_batch()
            optim.zero_grad()

    logging.info(f"last lr = {scheduler.get_last_lr()}")
    logging.info(f"state dict = {scheduler.state_dict()}")
|
|
if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    logging.getLogger().setLevel(logging.INFO)
    import subprocess

    s = subprocess.check_output(
        "git status -uno .; git log -1; git diff HEAD .", shell=True
    )
    logging.info(s)

    _test_eden()
|
|