from typing import Optional

import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
class DistributedSamplerWrapper(DistributedSampler):
    """Wrapper over ``Sampler`` for distributed training. It allows you to use any sampler in distributed mode.

    It is especially useful in conjunction with ``torch.nn.parallel.DistributedDataParallel``. In such a case, each
    process can pass a ``DistributedSamplerWrapper`` instance as a ``torch.utils.data.DataLoader`` sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        sampler: Sampler used for subsampling.
        num_replicas (int, optional): Number of processes participating in distributed training. By default,
            world_size is retrieved from the current distributed group.
        rank (int, optional): Rank of the current process within ``num_replicas``. By default, rank is retrieved
            from the current distributed group.
        shuffle (bool, optional): If ``True``, the sampler will shuffle the indices. Default: ``True``.
        seed (int, optional): Random seed used to shuffle the sampler if ``shuffle=True``. This number should be
            identical across all processes in the distributed group. Default: ``0``.

    Reference: https://github.com/pytorch/pytorch/issues/23430
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
    ):
        # The wrapped sampler is passed to `DistributedSampler` as its `dataset`,
        # so it is available as `self.dataset` below.
        super().__init__(
            sampler,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
        )
    def __iter__(self):
        # `self.dataset` is the wrapped sampler; iterating it yields the full list of sampled indices.
        indices = list(self.dataset)[: self.total_size]

        # Add extra samples to make the index list evenly divisible across replicas.
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size, f"{len(indices)} != {self.total_size}"

        # Subsample the shard that belongs to this rank.
        offset = self.num_samples * self.rank
        indices = indices[offset : offset + self.num_samples]
        assert len(indices) == self.num_samples, f"{len(indices)} != {self.num_samples}"

        return iter(indices)
    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
        elif hasattr(self.dataset, "generator"):
            self.dataset.generator = torch.Generator().manual_seed(self.seed + epoch)

    def state_dict(self):
        return self.dataset.state_dict()

    def load_state_dict(self, state_dict):
        self.dataset.load_state_dict(state_dict)

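
# Illustrative usage sketch (not part of the original module): wrapping a weighted sampler
# for DDP training. It assumes `torch.distributed` is already initialised; `dataset`,
# `sample_weights` and the batch size are hypothetical placeholders.
def _example_distributed_weighted_sampling(dataset, sample_weights, num_epochs=1):
    from torch.utils.data import DataLoader, WeightedRandomSampler

    base_sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset))
    sampler = DistributedSamplerWrapper(base_sampler)  # rank / world size come from the default process group
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # re-seeds the wrapped sampler so each epoch draws different indices
        for _batch in loader:
            pass  # training step would go here
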
# pylint: disable=protected-access
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
    """Noam learning rate schedule: linear warmup for `warmup_steps` steps, then inverse square-root decay."""

    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
        self.warmup_steps = float(warmup_steps)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        return [
            base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5)
            for base_lr in self.base_lrs
        ]

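
# Illustrative sketch (not part of the original module): the Noam schedule peaks at `base_lr`
# when `step == warmup_steps` (W**0.5 * W**-0.5 == 1) and then decays as step**-0.5.
# `warmup_steps` is usually set to a real step count (e.g. 4000), and `scheduler.step()` is
# typically called once per training iteration rather than once per epoch. The model and
# optimizer below are hypothetical placeholders.
def _example_noam_schedule(total_steps=10000):
    model = torch.nn.Linear(10, 10)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = NoamLR(optimizer, warmup_steps=4000)
    lrs = []
    for _ in range(total_steps):
        optimizer.step()
        scheduler.step()
        lrs.append(scheduler.get_last_lr()[0])
    return lrs  # ramps up to ~1e-3 at step 4000, then decays
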
class NoamLRStepConstant(torch.optim.lr_scheduler._LRScheduler):
    """Noam schedule that is frozen after `threshold_step`: the learning rate stays at the value
    reached at that step."""

    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1, threshold_step=100):
        self.warmup_steps = float(warmup_steps)
        self.threshold_step = threshold_step
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Clamp the step at `threshold_step` so the schedule flattens out afterwards.
        step = min(max(self.last_epoch, 1), self.threshold_step)
        return [
            base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5)
            for base_lr in self.base_lrs
        ]

class NoamLRStepDecay(torch.optim.lr_scheduler._LRScheduler):
    """Noam schedule that, once `threshold_step` is reached, decrements its effective step on
    every call and therefore retraces the schedule backwards."""

    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1, threshold_step=100):
        self.warmup_steps = float(warmup_steps)
        self.threshold_step = threshold_step
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        if step >= self.threshold_step:
            # Note: `get_lr` mutates `self.threshold_step` as a side effect, so after the
            # threshold the effective step walks back towards 1 on every call.
            self.threshold_step -= 1
            step = max(self.threshold_step, 1)
        return [
            base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5)
            for base_lr in self.base_lrs
        ]

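
# The three Noam variants above differ only after warmup: `NoamLR` keeps decaying as step**-0.5,
# `NoamLRStepConstant` holds the learning rate constant once `threshold_step` is reached, and
# `NoamLRStepDecay` walks its effective step back down, eventually settling at base_lr / warmup_steps.
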
# pylint: disable=protected-access
class StepwiseGradualLR(torch.optim.lr_scheduler._LRScheduler):
    """Hardcoded step-wise learning rate scheduling.

    Necessary for CapacitronVAE.
    """

    def __init__(self, optimizer, gradual_learning_rates, last_epoch=-1):
        self.gradual_learning_rates = gradual_learning_rates
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        step_thresholds = [values[0] for values in self.gradual_learning_rates]
        rates = [values[1] for values in self.gradual_learning_rates]

        # Find the last threshold that the current step has passed.
        boolean_indices = np.less_equal(step_thresholds, step)
        try:
            last_true = np.where(boolean_indices)[0][-1]
        except IndexError:
            # The step is smaller than every threshold; fall back to the first rate.
            last_true = 0
        lr = rates[last_true]

        # Return the last lr if the step is above the last threshold.
        lr = rates[-1] if step > step_thresholds[-1] else lr
        # Return the first lr if the step is below the second threshold - the first entry is the initial lr.
        lr = rates[0] if step < step_thresholds[1] else lr
        return np.tile(lr, len(self.base_lrs))  # hack?
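

# Illustrative usage sketch (not part of the original module): `gradual_learning_rates` is a list
# of `(step, lr)` pairs. The schedule, model and optimizer below are hypothetical placeholders.
def _example_stepwise_gradual_lr():
    model = torch.nn.Linear(10, 10)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    schedule = [
        (0, 1e-3),  # initial learning rate
        (20_000, 5e-4),  # halved after 20k steps
        (40_000, 2.5e-4),  # halved again after 40k steps
    ]
    scheduler = StepwiseGradualLR(optimizer, schedule)
    for _ in range(5):
        optimizer.step()
        scheduler.step()  # called once per training step
    return scheduler.get_last_lr()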