File size: 1,255 Bytes
287c28c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import os
import time
import torch
from tests.utils.mnist import MnistModel, MnistModelConfig
from trainer import Trainer, TrainerArgs
from trainer.generic_utils import KeepAverage
# Whether torch reports an available CUDA device; the test uses this to pick the `gpu` argument.
is_cuda = torch.cuda.is_available()
def test_train_mnist():
    """Check the learning rate reported while training with ``StepwiseGradualLR``.

    A trainer is built for the MNIST test model with
    ``scheduler_after_epoch=False`` and a two-entry ``gradual_learning_rates``
    schedule. The value of ``trainer.scheduler.get_lr()`` is read before any
    step, after step 0 and after step 1, and is expected to be ``1e-3``,
    ``1e-3`` and ``1e-4`` respectively.
    """
    mnist_model = MnistModel()

    # Test StepwiseGradualLR
    lr_schedule = [
        [0, 1e-3],
        [2, 1e-4],
    ]
    scheduler_config = MnistModelConfig(
        lr_scheduler="StepwiseGradualLR",
        lr_scheduler_params={"gradual_learning_rates": lr_schedule},
        scheduler_after_epoch=False,
    )

    device_id = 0 if is_cuda else None
    trainer = Trainer(
        TrainerArgs(),
        scheduler_config,
        model=mnist_model,
        output_path=os.getcwd(),
        gpu=device_id,
    )
    trainer.train_loader = trainer.get_train_dataloader(
        trainer.training_assets, trainer.train_samples, verbose=True
    )
    trainer.keep_avg_train = KeepAverage()

    def run_step(step_index):
        # A new iterator over the loader is created on every call, so each
        # step pulls its batch from a freshly started iteration.
        batch = next(iter(trainer.train_loader))
        trainer.train_step(batch, len(trainer.train_loader), step_index, time.time())

    lr_before = trainer.scheduler.get_lr()
    run_step(0)
    lr_after_first = trainer.scheduler.get_lr()
    run_step(1)
    lr_after_second = trainer.scheduler.get_lr()

    assert lr_before == 1e-3
    assert lr_after_first == 1e-3
    assert lr_after_second == 1e-4
|