import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR


def configure_optimizer(
    model: nn.Module,
    lr: float = 1e-3,
    weight_decay: float = 0.01,
    total_steps: int = 100,
) -> tuple[AdamW, OneCycleLR]:
    """Return an AdamW optimizer paired with a OneCycleLR scheduler.

    ``lr`` is passed to AdamW and reused as the scheduler's ``max_lr``
    (the peak of the one-cycle schedule); ``total_steps`` is the number
    of optimizer steps the schedule spans.
    """
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = OneCycleLR(optimizer, max_lr=lr, total_steps=total_steps)
    return optimizer, scheduler


def adjust_learning_rate(optimizer: torch.optim.Optimizer, factor: float) -> float:
    """Scale the learning rate of all param groups by ``factor``.

    Parameters
    ----------
    optimizer:
        The optimizer whose learning rate will be adjusted.
    factor:
        Multiplicative factor applied to the current learning rate.

    Returns
    -------
    float
        The updated learning rate of the first parameter group.
    """
    for param_group in optimizer.param_groups:
        param_group["lr"] *= factor
    return optimizer.param_groups[0]["lr"]