# 👟 Trainer

An opinionated general purpose model trainer on PyTorch with a simple code base.

## Installation

From GitHub:

```console
git clone https://github.com/coqui-ai/Trainer
cd Trainer
make install
```

From PyPI:

```console
pip install trainer
```

Prefer installing from GitHub, as it is more stable.

## Implementing a model

Subclass [`TrainerModel`](trainer/model.py) and overload its functions. A minimal sketch is given at the end of this README.

## Training a model

See the test script [here](tests/test_train_mnist.py), which trains a basic MNIST model.

## Training with DDP

```console
$ python -m trainer.distribute --script path/to/your/train.py --gpus "0,1"
```

We don't use `.spawn()` to initiate multi-GPU training, since it comes with certain limitations:

- Everything must be picklable.
- `.spawn()` trains the model in subprocesses, and the model in the main process is not updated.
- `DataLoader` with N processes gets really slow when N is large.

## Profiling example

- Create the torch profiler as you like and pass it to the trainer.

```python
import torch

profiler = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler/"),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
)
prof = trainer.profile_fit(profiler, epochs=1, small_run=64)
```

- Then run TensorBoard.

```console
tensorboard --logdir="./profiler/"
```

## Supported Experiment Loggers

- [TensorBoard](https://www.tensorflow.org/tensorboard) - actively maintained
- [ClearML](https://clear.ml/) - actively maintained
- [MLflow](https://mlflow.org/)
- [Aim](https://aimstack.io/)
- [WandB](https://wandb.ai/)

To add a new logger, subclass [BaseDashboardLogger](trainer/logging/base_dash_logger.py) and overload its functions, as in the sketch below.
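For illustration, here is a hedged sketch of a minimal custom logger. The method names shown (`add_scalar`, `flush`, `finish`) are assumptions about the abstract interface, and the remaining abstract methods must also be overloaded; consult `trainer/logging/base_dash_logger.py` for the actual set.

```python
# A hedged sketch, not the exact abstract interface: the methods shown
# (add_scalar, flush, finish) are assumptions, and any remaining abstract
# methods of BaseDashboardLogger must also be overloaded. See
# trainer/logging/base_dash_logger.py for the full set.
from trainer.logging.base_dash_logger import BaseDashboardLogger


class PrintLogger(BaseDashboardLogger):
    """Toy logger that writes every scalar to stdout."""

    def add_scalar(self, title, value, step):
        print(f"step {step} | {title}: {value}")

    def flush(self):
        pass  # nothing is buffered

    def finish(self):
        pass  # nothing to close
```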
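Finally, to accompany the "Implementing a model" section above, here is a hedged sketch of a `TrainerModel` subclass. The `train_step`/`eval_step` return contract and the `TinyClassifier` model are assumptions modeled on the MNIST test script; the authoritative interface is in `trainer/model.py`, and a complete working example (including data loading) is in `tests/test_train_mnist.py`.

```python
# A hedged sketch, not the authoritative interface: method names and the
# (outputs, losses) return contract of train_step/eval_step are assumptions
# modeled on tests/test_train_mnist.py. See trainer/model.py for the exact
# abstract methods; get_data_loader is also required but omitted here.
from torch import nn
from trainer import TrainerModel


class TinyClassifier(TrainerModel):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(28 * 28, 10)

    def forward(self, x):
        # Flatten images and project to class logits.
        return self.net(x.view(x.size(0), -1))

    def train_step(self, batch, criterion):
        x, y = batch
        logits = self.forward(x)
        loss = criterion(logits, y)
        # Return model outputs and a dict of losses to be logged.
        return {"model_outputs": logits}, {"loss": loss}

    def eval_step(self, batch, criterion):
        return self.train_step(batch, criterion)

    @staticmethod
    def get_criterion():
        return nn.CrossEntropyLoss()
```

Once the data loading methods are filled in, training follows the pattern in the MNIST test script (the `TrainerConfig` fields below are illustrative):

```python
from trainer import Trainer, TrainerArgs, TrainerConfig

config = TrainerConfig(epochs=1, batch_size=32)  # illustrative fields
trainer = Trainer(TrainerArgs(), config, output_path="runs/", model=TinyClassifier())
trainer.fit()
```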