import os

import torch

from tests.utils.mnist import MnistModel, MnistModelConfig
from trainer import Trainer, TrainerArgs
# Queried once at import time; the test below uses it to choose between
# passing GPU index 0 and passing no GPU to the trainer.
is_cuda = torch.cuda.is_available()
def test_train_mnist():
    """Check that a second ``fit()`` call yields a lower average training loss.

    A ``Trainer`` is built around a fresh ``MnistModel`` with default
    ``TrainerArgs`` and ``MnistModelConfig``. ``gpu=0`` is passed when CUDA is
    available, otherwise ``gpu=None``. ``output_path`` is set to the current
    working directory.

    ``fit()`` is called twice on the same trainer; after each call the
    ``"avg_loss"`` entry of ``trainer.keep_avg_train`` is read, and the first
    value must be strictly greater than the second.
    """
    model = MnistModel()
    trainer = Trainer(
        TrainerArgs(),
        MnistModelConfig(),
        model=model,
        output_path=os.getcwd(),
        gpu=0 if is_cuda else None,
    )

    # First round: record the baseline average training loss.
    trainer.fit()
    loss1 = trainer.keep_avg_train["avg_loss"]

    # Second round on the same trainer instance: loss is expected to drop.
    trainer.fit()
    loss2 = trainer.keep_avg_train["avg_loss"]

    assert loss1 > loss2, (
        f"average train loss did not decrease: first fit={loss1}, second fit={loss2}"
    )