Metadata-Version: 2.1
Name: trainer
Version: 0.0.12
Summary: General purpose model trainer for PyTorch that is more flexible than it should be, by 🐸Coqui.
Home-page: https://github.com/coqui-ai/Trainer
Author: Eren Gölge
Author-email: egolge@coqui.ai
License: Apache2
Project-URL: Documentation, https://github.com/coqui-ai/Trainer/
Project-URL: Tracker, https://github.com/coqui-ai/Trainer/issues
Project-URL: Repository, https://github.com/coqui-ai/Trainer
Project-URL: Discussions, https://github.com/coqui-ai/Trainer/discussions
Classifier: Environment :: Console
Classifier: Natural Language :: English
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Requires-Python: >=3.6.0, <3.11
Description-Content-Type: text/markdown
Provides-Extra: dev
Provides-Extra: test
Provides-Extra: all

<p align="center"><img src="https://user-images.githubusercontent.com/1402048/151947958-0bcadf38-3a82-4b4e-96b4-a38d3721d737.png" align="right" height="255px" /></p>

# 👟 Trainer
An opinionated, general-purpose model trainer for PyTorch with a simple code base.

## Installation

From GitHub:

```console
git clone https://github.com/coqui-ai/Trainer
cd Trainer
make install
```

From PyPI:

```console
pip install trainer
```

Prefer installing from GitHub, as it is more stable.

## Implementing a model
Subclass [```TrainerModel()```](trainer/model.py) and override its methods.

## Training a model
See the [test script](tests/test_train_mnist.py), which trains a basic MNIST model.

## Training with DDP

```console
$ python -m trainer.distribute --script path/to/your/train.py --gpus "0,1"
```
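
A launcher of this kind typically starts one independent Python process per GPU. The sketch below only illustrates that general idea; it is not the implementation of `trainer.distribute`. The helper `build_worker_commands` and the `--rank` flag are hypothetical names chosen for the example.

```python
import sys


def build_worker_commands(script, gpus):
    """Return one (env_overrides, argv) pair per GPU id.

    `gpus` is a comma-separated string such as "0,1".
    Hypothetical helper, not part of the trainer package.
    """
    gpu_ids = [g.strip() for g in gpus.split(",") if g.strip()]
    commands = []
    for rank, gpu_id in enumerate(gpu_ids):
        # Restrict each worker to a single device.
        env = {"CUDA_VISIBLE_DEVICES": gpu_id}
        # `--rank` is an assumed flag name used only for illustration.
        argv = [sys.executable, script, f"--rank={rank}"]
        commands.append((env, argv))
    return commands


commands = build_worker_commands("path/to/your/train.py", "0,1")
for env, argv in commands:
    print(env, argv[1:])
```

Each `argv` could then be passed to `subprocess.Popen`, with `env` merged into a copy of `os.environ`, so that every worker runs the training script as an ordinary top-level process.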

We don't use ```.spawn()``` to initiate multi-GPU training because it imposes several limitations:

- Everything must be picklable.
- ```.spawn()``` trains the model in subprocesses, so the model in the main process is not updated.
- A DataLoader with N worker processes becomes very slow when N is large.
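
The pickling constraint can be seen with the standard library alone. This is a minimal sketch, independent of the trainer package; `is_picklable` is a hypothetical helper written for this example. Arguments handed to spawned subprocesses are serialized with `pickle`, so objects such as lambdas cannot be passed.

```python
import pickle


def is_picklable(obj):
    """Return True if `obj` can be serialized with pickle."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        # Lambdas and locally defined functions typically fail here.
        return False
    return True


# A plain config dictionary serializes without problems.
print(is_picklable({"lr": 0.001, "batch_size": 32}))

# A lambda, for example an inline collate or transform function, does not.
print(is_picklable(lambda batch: batch))
```

Launching the training script as a separate process per GPU avoids this restriction, because nothing has to be serialized across the process boundary at startup.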

## Profiling example

- Create the torch profiler as you like and pass it to the trainer.
    ```python
    import torch
    profiler = torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler/"),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    )
    # `trainer` is an existing Trainer instance.
    prof = trainer.profile_fit(profiler, epochs=1, small_run=64)
    # Then run TensorBoard (see the next step).
    ```
- Run TensorBoard.
    ```console
    tensorboard --logdir="./profiler/"
    ```

## Supported Experiment Loggers
- [Tensorboard](https://www.tensorflow.org/tensorboard) - actively maintained
- [ClearML](https://clear.ml/) - actively maintained
- [MLFlow](https://mlflow.org/)
- [Aim](https://aimstack.io/)
- [WandB](https://wandb.ai/)

To add a new logger, subclass [BaseDashboardLogger](trainer/logging/base_dash_logger.py) and override its methods.