from typing import Optional

import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler


class DistributedSamplerWrapper(DistributedSampler):
    """Wrapper over ``Sampler`` for distributed training.

    It allows you to use any sampler in distributed mode. It is especially useful in
    conjunction with ``torch.nn.parallel.DistributedDataParallel``: in such a case, each
    process can pass a ``DistributedSamplerWrapper`` instance as a
    ``torch.utils.data.DataLoader`` sampler and load a subset of the original dataset
    that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        sampler: Sampler used for subsampling.
        num_replicas (int, optional): Number of processes participating in distributed
            training. By default, ``world_size`` is retrieved from the current distributed group.
        rank (int, optional): Rank of the current process within ``num_replicas``. By default,
            ``rank`` is retrieved from the current distributed group.
        shuffle (bool, optional): If ``True``, the sampler shuffles the indices. Default: ``True``.
        seed (int, optional): Random seed used to shuffle the sampler if ``shuffle=True``.
            This number should be identical across all processes in the distributed group.
            Default: 0.

    Reference: https://github.com/pytorch/pytorch/issues/23430
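
    Example:
        A minimal usage sketch; ``WeightedRandomSampler``, ``weights``, ``dataset`` and
        ``num_epochs`` are illustrative names, not part of this module::

            weighted = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights))
            sampler = DistributedSamplerWrapper(weighted)
            loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)
            for epoch in range(num_epochs):
                sampler.set_epoch(epoch)  # keep shuffling in sync across processes
                for batch in loader:
                    ...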
""" |
|
|
|
def __init__( |
|
self, |
|
sampler, |
|
num_replicas: int = None, |
|
rank: int = None, |
|
shuffle: bool = True, |
|
seed: int = 0, |
|
): |
|
super().__init__( |
|
sampler, |
|
num_replicas=num_replicas, |
|
rank=rank, |
|
shuffle=shuffle, |
|
seed=seed, |
|
) |
|
|
|
def __iter__(self): |
|
indices = list(self.dataset)[: self.total_size] |
|
|
|
|
|
indices += indices[: (self.total_size - len(indices))] |
|
assert len(indices) == self.total_size, f"{len(indices)} != {self.total_size}" |
|
|
|
|
|
offset = self.num_samples * self.rank |
|
indices = indices[offset : offset + self.num_samples] |
|
assert len(indices) == self.num_samples, f"{len(indices)} != {self.num_samples}" |
|
|
|
return iter(indices) |
|
|
|
def set_epoch(self, epoch): |
|
super().set_epoch(epoch) |
|
if hasattr(self.dataset, "set_epoch"): |
|
self.dataset.set_epoch(epoch) |
|
elif hasattr(self.dataset, "generator"): |
|
self.dataset.generator = torch.Generator().manual_seed(self.seed + epoch) |
|
|
|
def state_dict(self): |
|
return self.dataset.state_dict() |
|
|
|
def load_state_dict(self, state_dict): |
|
self.dataset.load_state_dict(state_dict) |
|
|
|
|
|
|
|
class NoamLR(torch.optim.lr_scheduler._LRScheduler): |
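    """Noam learning-rate schedule ("Attention Is All You Need", Vaswani et al., 2017).

    The learning rate increases linearly for the first ``warmup_steps`` steps and then
    decays with the inverse square root of the step (``step`` is clamped to at least 1)::

        lr = base_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5)
    """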
    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
        self.warmup_steps = float(warmup_steps)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        return [
            base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5)
            for base_lr in self.base_lrs
        ]


class NoamLRStepConstant(torch.optim.lr_scheduler._LRScheduler):
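    """Noam schedule whose step counter saturates at ``threshold_step``.

    Identical to :class:`NoamLR`, except that the effective step is clamped to
    ``threshold_step``, so the learning rate stays constant from that point on.
    """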
    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1, threshold_step=100):
        self.warmup_steps = float(warmup_steps)
        self.threshold_step = threshold_step
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = min(max(self.last_epoch, 1), self.threshold_step)
        return [
            base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5)
            for base_lr in self.base_lrs
        ]


class NoamLRStepDecay(torch.optim.lr_scheduler._LRScheduler):
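    """Noam schedule that walks the step counter backwards after ``threshold_step``.

    Identical to :class:`NoamLR` until ``threshold_step`` is reached; after that, every
    call to :meth:`get_lr` decrements ``threshold_step`` by one and uses it as the
    effective step in the Noam formula.
    """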
    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1, threshold_step=100):
        self.warmup_steps = float(warmup_steps)
        self.threshold_step = threshold_step
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        if step >= self.threshold_step:
            # Past the threshold, use a backwards-walking step counter instead of the real step.
            self.threshold_step -= 1
            step = max(self.threshold_step, 1)
        return [
            base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5)
            for base_lr in self.base_lrs
        ]


class StepwiseGradualLR(torch.optim.lr_scheduler._LRScheduler):
    """Hardcoded step-wise learning rate scheduling. Necessary for CapacitronVAE.

    ``gradual_learning_rates`` is a list of ``(step_threshold, learning_rate)`` pairs;
    the rate whose threshold was most recently passed is applied to every parameter group.
    """
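
    # A minimal construction sketch; the step/LR pairs below are illustrative values,
    # not taken from any particular CapacitronVAE recipe:
    #
    #     scheduler = StepwiseGradualLR(
    #         optimizer,
    #         gradual_learning_rates=[(0, 1e-3), (2000, 5e-4), (4000, 3e-4)],
    #     )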
    def __init__(self, optimizer, gradual_learning_rates, last_epoch=-1):
        self.gradual_learning_rates = gradual_learning_rates
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        step_thresholds = []
        rates = []
        for values in self.gradual_learning_rates:
            step_thresholds.append(values[0])
            rates.append(values[1])

        # Index of the last threshold that the current step has already passed.
        boolean_indices = np.less_equal(step_thresholds, step)
        try:
            last_true = np.where(boolean_indices)[0][-1]
        except IndexError:
            # No threshold reached yet; fall back to the first (initial) rate.
            last_true = 0
        lr = rates[max(last_true, 0)]

        # Use the last rate once the final threshold has been passed.
        lr = rates[-1] if step > step_thresholds[-1] else lr
        # Keep the initial rate until the second threshold is reached.
        lr = rates[0] if step < step_thresholds[1] else lr

        # Apply the same rate to every parameter group.
        return np.tile(lr, len(self.base_lrs))