File size: 785 Bytes
287c28c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from distutils.command.config import config

from mnist import MnistModel, MnistModelConfig

from trainer import Trainer, TrainerArgs


def main():
    """Run `MNIST` model training from scratch or from previous checkpoint.

    Builds default ``TrainerArgs`` and ``MnistModelConfig`` instances, constructs
    an ``MnistModel``, creates training and evaluation inputs via
    ``model.get_data_loader``, hands everything to a ``Trainer`` and calls
    ``fit()``.

    NOTE(review): ``parse_command_line_args=True`` presumably lets the trainer
    override these defaults from the command line -- confirm against the
    ``Trainer`` documentation.
    """
    # Default training arguments and model configuration. The local is named
    # `model_config` (not `config`) so it does not shadow the module-level
    # `config` name imported at the top of the file.
    train_args = TrainerArgs()
    model_config = MnistModelConfig()

    # The model is constructed without arguments; `model_config` is only passed
    # to the trainer and to the data-loader factory below.
    model = MnistModel()

    # Build the trainer and launch training.
    trainer = Trainer(
        train_args,
        model_config,
        model_config.output_path,
        model=model,
        # The third positional argument (False for training, True for
        # evaluation) presumably selects the dataset split -- TODO confirm
        # against MnistModel.get_data_loader's signature.
        train_samples=model.get_data_loader(model_config, None, False, None, None, None),
        eval_samples=model.get_data_loader(model_config, None, True, None, None, None),
        parse_command_line_args=True,
    )
    trainer.fit()


if __name__ == "__main__":
    main()