|
import os |
|
import time |
|
|
|
import torch |
|
|
|
from tests.utils.mnist import MnistModel, MnistModelConfig |
|
from trainer import Trainer, TrainerArgs |
|
from trainer.generic_utils import KeepAverage |
|
|
|
# Evaluated once at import time; the test in this module uses it to choose
# between GPU index 0 and no GPU when constructing the Trainer.
is_cuda: bool = torch.cuda.is_available()
|
|
|
|
|
def test_train_mnist():
    """Check the scheduler's learning rate around two manual training steps.

    Builds a ``Trainer`` for the MNIST model configured with the
    ``StepwiseGradualLR`` scheduler and ``scheduler_after_epoch=False``, calls
    ``train_step`` twice, and asserts the value of
    ``trainer.scheduler.get_lr()`` before the first call, after the first call
    and after the second call.
    """
    model = MnistModel()

    # NOTE(review): presumably [step, learning_rate] pairs -- confirm against
    # the StepwiseGradualLR implementation.
    lr_milestones = [
        [0, 1e-3],
        [2, 1e-4],
    ]
    config = MnistModelConfig(
        lr_scheduler="StepwiseGradualLR",
        lr_scheduler_params={"gradual_learning_rates": lr_milestones},
        scheduler_after_epoch=False,
    )

    trainer = Trainer(
        TrainerArgs(),
        config,
        model=model,
        output_path=os.getcwd(),
        # GPU index 0 when CUDA is available, otherwise no GPU.
        gpu=0 if is_cuda else None,
    )

    # The loader and the running-average tracker are attached by hand before
    # train_step is invoked directly (presumably train_step relies on them --
    # verify against Trainer).
    trainer.train_loader = trainer.get_train_dataloader(
        trainer.training_assets, trainer.train_samples, verbose=True
    )
    trainer.keep_avg_train = KeepAverage()

    # Learning rate before any step, then after each of the two steps.
    observed_lrs = [trainer.scheduler.get_lr()]
    for call_index in range(2):
        # A new iterator is created on every pass, so each call receives the
        # first batch that a fresh iterator over the loader yields.
        batch = next(iter(trainer.train_loader))
        trainer.train_step(batch, len(trainer.train_loader), call_index, time.time())
        observed_lrs.append(trainer.scheduler.get_lr())

    expected_lrs = (1e-3, 1e-3, 1e-4)
    for observed, expected in zip(observed_lrs, expected_lrs):
        assert observed == expected
|
|