File size: 785 Bytes
287c28c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from distutils.command.config import config
from mnist import MnistModel, MnistModelConfig
from trainer import Trainer, TrainerArgs
def main():
    """Run `MNIST` model training from scratch or from a previous checkpoint.

    Builds default ``TrainerArgs`` and ``MnistModelConfig``, instantiates
    ``MnistModel``, builds train/eval data loaders through the model's
    ``get_data_loader`` and hands everything to ``Trainer``, then calls
    ``Trainer.fit()``.

    ``parse_command_line_args=True`` is passed to ``Trainer``. Presumably this
    is how command-line overrides (such as resuming from a checkpoint) are
    picked up — TODO confirm against the trainer documentation.
    """
    # Init args and config. The local is named `model_config` rather than
    # `config` so it does not shadow the module-level `config` import.
    train_args = TrainerArgs()
    model_config = MnistModelConfig()

    # Init the model. Its constructor is called without the config.
    model = MnistModel()

    # Init the trainer and start training.
    # NOTE(review): the third positional argument of get_data_loader appears
    # to select the eval split (False -> train, True -> eval), judging by the
    # train_samples/eval_samples pairing below. The remaining Nones are passed
    # through unchanged — confirm against MnistModel.get_data_loader.
    trainer = Trainer(
        train_args,
        model_config,
        model_config.output_path,
        model=model,
        train_samples=model.get_data_loader(model_config, None, False, None, None, None),
        eval_samples=model.get_data_loader(model_config, None, True, None, None, None),
        parse_command_line_args=True,
    )
    trainer.fit()
if __name__ == "__main__":
    # Only start training when executed as a script, not when imported.
    main()
|