|
# Training |
|
|
|
Aside from [distillation](../../README.md#distillation), `model2vec` also supports training simple classifiers on top of static models, using [PyTorch](https://pytorch.org/), [Lightning](https://lightning.ai/), and [scikit-learn](https://scikit-learn.org/stable/index.html). |
|
|
|
We support both single-label and multi-label classification; the correct mode is selected automatically based on the labels you provide. |
|
|
|
# Installation |
|
|
|
To train, make sure you install the training extra: |
|
|
|
```bash
pip install model2vec[training]
```
|
|
|
# Quickstart |
|
|
|
To train a model, initialize it from a `StaticModel` or from a pre-trained model, as follows: |
|
|
|
```python
from model2vec.distill import distill
from model2vec.train import StaticModelForClassification

# From a distilled model
distilled_model = distill("BAAI/bge-base-en-v1.5")
classifier = StaticModelForClassification.from_static_model(model=distilled_model)

# From a pre-trained model: potion is the default
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")
```
|
|
|
This creates a very simple classifier: a `StaticModel` with a single 512-unit hidden layer on top. You can adjust the number of hidden layers and the number of units through parameters on both functions. Note that the default for `from_pretrained` is [potion-base-32M](https://huggingface.co/minishlab/potion-base-32M), our best model to date. This is our recommended path if you're working with general English data. |
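For instance, a deeper or narrower head might be configured as follows. This is a minimal sketch: the `n_layers` and `hidden_dim` keyword names are assumptions, so verify them against the signatures of `from_pretrained` and `from_static_model` in your installed version.

```python
from model2vec.train import StaticModelForClassification

# Hypothetical keyword arguments: `n_layers` and `hidden_dim` are
# assumptions; check the actual signature before using them.
classifier = StaticModelForClassification.from_pretrained(
    model_name="minishlab/potion-base-32M",
    n_layers=2,      # two hidden layers instead of the default one
    hidden_dim=256,  # 256 units per layer instead of the default 512
)
```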
|
|
|
Now that you have created the classifier, let's train it. The example below assumes you have the [`datasets`](https://github.com/huggingface/datasets) library installed. |
|
|
|
```python
from time import perf_counter

from datasets import load_dataset

# Load the subj dataset
ds = load_dataset("setfit/subj")
train = ds["train"]
test = ds["test"]

s = perf_counter()
classifier = classifier.fit(train["text"], train["label"])
print(f"Training took {int(perf_counter() - s)} seconds.")
# Training took 81 seconds

classification_report = classifier.evaluate(test["text"], test["label"])
print(classification_report)
# Achieved 91.0 test accuracy
```
|
|
|
As you can see, we got a pretty nice 91% accuracy, with only 81 seconds of training. |
|
|
|
The training loop is handled by [`lightning`](https://pypi.org/project/lightning/). By default, the data is split into a train set (90%) and a validation set (10%), and training runs with early stopping on validation-set accuracy, with a patience of 5. |
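If the defaults don't suit your data, these settings can be adjusted at fit time. A minimal sketch: the `test_size` and `early_stopping_patience` keyword names are assumptions, so check the `fit` signature in your installed version.

```python
# Hypothetical keyword arguments: `test_size` and `early_stopping_patience`
# are assumptions; verify them against fit() before using.
classifier = classifier.fit(
    train["text"],
    train["label"],
    test_size=0.1,              # fraction of the data held out for validation
    early_stopping_patience=5,  # stop after 5 epochs without improvement
)
```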
|
|
|
Note that this model is as fast as you're used to from us: |
|
|
|
```python
from time import perf_counter

s = perf_counter()
classifier.predict(test["text"])
print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} instances on CPU.")
# Took 67 milliseconds for 2000 instances on CPU.
```
|
|
|
## Multi-label classification |
|
|
|
Multi-label classification is supported out of the box. Just pass a list of lists to the `fit` function (e.g. `[[label1, label2], [label1, label3]]`), and a multi-label classifier will be trained. For example, the following code trains a multi-label classifier on the [go_emotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) dataset: |
|
|
|
```python
from datasets import load_dataset
from model2vec.train import StaticModelForClassification

# Initialize a classifier from a pre-trained model
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")

# Load a multi-label dataset
ds = load_dataset("google-research-datasets/go_emotions")

# Inspect some of the labels
print(ds["train"]["labels"][40:50])
# [[0, 15], [15, 18], [16, 27], [27], [7, 13], [10], [20], [27], [27], [27]]

# Train the classifier on text (X) and labels (y)
classifier.fit(ds["train"]["text"], ds["train"]["labels"])
```
|
|
|
Then, we can evaluate the classifier: |
|
|
|
```python
classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["labels"], threshold=0.3)
print(classification_report)
# Accuracy: 0.410
# Precision: 0.527
# Recall: 0.410
# F1: 0.439
```
|
|
|
The scores are competitive with the popular [roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) model, while our model is orders of magnitude faster. |
|
|
|
# Persistence |
|
|
|
You can turn a classifier into a scikit-learn compatible pipeline, as follows: |
|
|
|
```python
pipeline = classifier.to_pipeline()
```
|
|
|
This pipeline object can be persisted using standard pickle-based methods, such as [joblib](https://joblib.readthedocs.io/en/stable/). This makes it easy to use your model in inference pipelines (no installing torch!), although `joblib` and `pickle` should not be used to share models outside of your organization. |
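For example, with joblib (a small sketch; the file name is just an illustration):

```python
import joblib

# Persist the scikit-learn compatible pipeline to disk and load it back.
joblib.dump(pipeline, "classifier_pipeline.joblib")
loaded_pipeline = joblib.load("classifier_pipeline.joblib")

# The loaded pipeline predicts without a torch installation.
print(loaded_pipeline.predict(["This is a test sentence."]))
```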
|
|
|
If you want to persist your pipeline to the Hugging Face hub, you can use our built-in functions: |
|
|
|
```python
# `path` is a local directory of your choice
pipeline.save_pretrained(path)
pipeline.push_to_hub("my_cool/project")
```
|
|
|
Later, you can load these as follows: |
|
|
|
```python
from model2vec.inference import StaticModelPipeline

pipeline = StaticModelPipeline.from_pretrained("my_cool/project")
```
|
|
|
Loading pipelines in this way is _extremely_ fast. It takes only 30ms to load a pipeline from disk. |
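You can check this yourself with a quick timing run, reusing the pipeline from the previous snippet (after the artifacts have been downloaded once, subsequent loads come from the local cache):

```python
from time import perf_counter

from model2vec.inference import StaticModelPipeline

s = perf_counter()
pipeline = StaticModelPipeline.from_pretrained("my_cool/project")
print(f"Loading took {int((perf_counter() - s) * 1000)} milliseconds.")
```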
|
|
|
|
|
# Bring your own architecture |
|
|
|
Our training architecture is set up to be extensible, with each task having a specific class. Right now we only offer `StaticModelForClassification`, but in the future we'll also offer other tasks, such as regression. |
|
|
|
The core functionality of `StaticModelForClassification` is contained in a handful of functions: |
|
|
|
* `construct_head`: Constructs the classifier head on top of the `StaticModel`. For example, if you want to create a model that has LayerNorm, just subclass and replace this function (see the sketch after this list). This should be the main function to update if you want to change model behavior. |
* `train_test_split`: Governs the train/validation split before fitting. |
* `prepare_dataset`: Selects the `torch.Dataset` that will be used in the `DataLoader` during training. |
* `_encode`: The encoding function used in the model. |
* `fit`: Contains all the lightning-related fitting logic. |
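As an illustration, here is a minimal sketch of a subclass that adds LayerNorm to the head by wrapping the default one. It assumes `construct_head` takes no arguments and returns a `torch.nn.Module`, and `self.embed_dim` is a hypothetical name for the static embedding size; check the base class in your installed version before relying on either.

```python
from torch import nn

from model2vec.train import StaticModelForClassification


class NormalizedStaticModelForClassification(StaticModelForClassification):
    """A classifier whose head normalizes the pooled embeddings first."""

    def construct_head(self) -> nn.Module:
        # Reuse the default head and insert a LayerNorm in front of it.
        # NOTE: `self.embed_dim` is an assumption about the base class;
        # verify the real attribute name in your installed version.
        default_head = super().construct_head()
        return nn.Sequential(nn.LayerNorm(self.embed_dim), default_head)
```

Assuming the classmethod constructors instantiate the calling class, `NormalizedStaticModelForClassification.from_pretrained(...)` should then work exactly like the examples above.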
|
|
|
The training of the model is done in a `lightning.LightningModule`, which can be modified but is deliberately kept very basic. |
|
|
|
# Results |
|
|
|
We ran extensive benchmarks comparing our model to several well-known architectures. The results can be found in the [training results](https://github.com/MinishLab/model2vec/tree/main/results#training-results) documentation. |
|
|