# Training
Aside from [distillation](../../README.md#distillation), `model2vec` also supports training simple classifiers on top of static models, using [pytorch](https://pytorch.org/), [lightning](https://lightning.ai/) and [scikit-learn](https://scikit-learn.org/stable/index.html).
We support both single and multi-label classification, which work seamlessly based on the labels you provide.
# Installation
To train, make sure you install the training extra:
```
pip install model2vec[training]
```
# Quickstart
To train a model, simply initialize it from a `StaticModel` or from a pre-trained model, as follows:
```python
from model2vec.distill import distill
from model2vec.train import StaticModelForClassification
# From a distilled model
distilled_model = distill("baai/bge-base-en-v1.5")
classifier = StaticModelForClassification.from_static_model(model=distilled_model)
# From a pre-trained model: potion is the default
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32m")
```
This creates a very simple classifier: a `StaticModel` with a single 512-unit hidden layer on top. You can adjust the number of hidden layers and the number of units through parameters on both functions. Note that the default for `from_pretrained` is [potion-base-32m](https://huggingface.co/minishlab/potion-base-32M), our best model to date. This is our recommended path if you're working with general English data.
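For example, a deeper or narrower head can be requested at construction time. This is a minimal sketch: the keyword names `n_layers` and `hidden_dim` are assumptions based on the description above, so check the signature of `from_pretrained` in your installed version before relying on them.
```python
from model2vec.train import StaticModelForClassification

# Sketch only: `n_layers` and `hidden_dim` are assumed parameter names.
classifier = StaticModelForClassification.from_pretrained(
    model_name="minishlab/potion-base-32M",
    n_layers=2,      # two hidden layers instead of one
    hidden_dim=256,  # 256 units per hidden layer instead of 512
)
```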
Now that you have created the classifier, let's train it. The example below assumes you have the [`datasets`](https://github.com/huggingface/datasets) library installed.
```python
from time import perf_counter
from datasets import load_dataset
# Load the subj dataset
ds = load_dataset("setfit/subj")
train = ds["train"]
test = ds["test"]
s = perf_counter()
classifier = classifier.fit(train["text"], train["label"])
print(f"Training took {int(perf_counter() - s)} seconds.")
# Training took 81 seconds
classification_report = classifier.evaluate(test["text"], test["label"])
print(classification_report)
# Achieved 91.0 test accuracy
```
As you can see, we got a pretty nice 91% accuracy, with only 81 seconds of training.
The training loop is handled by [`lightning`](https://pypi.org/project/lightning/). By default the training loop splits the data into a train and validation split, with 90% of the data being used for training and 10% for validation. By default, it runs with early stopping on the validation set accuracy, with a patience of 5.
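The training loop can likely be tuned through keyword arguments to `fit`. The sketch below is an assumption-heavy illustration: the keyword names (`batch_size`, `max_epochs`, `early_stopping_patience`, `test_size`) are guesses based on the behavior described above, so verify them against the `fit` signature in your installed version.
```python
# Sketch only: these keyword names are assumptions, not guaranteed API.
classifier = classifier.fit(
    train["text"],
    train["label"],
    batch_size=32,               # size of each training batch
    max_epochs=50,               # upper bound on the number of epochs
    early_stopping_patience=5,   # stop after 5 epochs without validation improvement
    test_size=0.1,               # fraction of the data held out for validation
)
```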
Note that this model is as fast as you're used to from us:
```python
from time import perf_counter
s = perf_counter()
classifier.predict(test["text"])
print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} instances on CPU.")
# Took 67 milliseconds for 2000 instances on CPU.
```
## Multi-label classification
Multi-label classification is supported out of the box. Just pass a list of lists to the `fit` function (e.g. `[[label1, label2], [label1, label3]]`), and a multi-label classifier will be trained. For example, the following code trains a multi-label classifier on the [go_emotions](https://huggingface.co/datasets/google-research-datasets/go_emotions) dataset:
```python
from datasets import load_dataset
from model2vec.train import StaticModelForClassification
# Initialize a classifier from a pre-trained model
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")
# Load a multi-label dataset
ds = load_dataset("google-research-datasets/go_emotions")
# Inspect some of the labels
print(ds["train"]["labels"][40:50])
# [[0, 15], [15, 18], [16, 27], [27], [7, 13], [10], [20], [27], [27], [27]]
# Train the classifier on text (X) and labels (y)
classifier.fit(ds["train"]["text"], ds["train"]["labels"])
```
Then, we can evaluate the classifier:
```python
classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["labels"], threshold=0.3)
print(classification_report)
# Accuracy: 0.410
# Precision: 0.527
# Recall: 0.410
# F1: 0.439
```
The scores are competitive with the popular [roberta-base-go_emotions](https://huggingface.co/SamLowe/roberta-base-go_emotions) model, while our model is orders of magnitude faster.
# Persistence
You can turn a classifier into a scikit-learn compatible pipeline, as follows:
```python
pipeline = classifier.to_pipeline()
```
This pipeline object can be persisted using standard pickle-based methods, such as [joblib](https://joblib.readthedocs.io/en/stable/). This makes it easy to use your model in inference pipelines (no installing torch!), although `joblib` and `pickle` should not be used to share models outside of your organization.
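For example, with `joblib` (the file name here is just a placeholder):
```python
from joblib import dump, load

# Persist the scikit-learn compatible pipeline to disk...
dump(pipeline, "classifier_pipeline.joblib")

# ...and load it again later, e.g. in a lightweight inference service.
pipeline = load("classifier_pipeline.joblib")
```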
If you want to persist your pipeline to the Hugging Face hub, you can use our built-in functions:
```python
pipeline.save_pretrained(path)
pipeline.push_to_hub("my_cool/project")
```
Later, you can load these as follows:
```python
from model2vec.inference import StaticModelPipeline
pipeline = StaticModelPipeline.from_pretrained("my_cool/project")
```
Loading pipelines in this way is _extremely_ fast. It takes only 30ms to load a pipeline from disk.
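Once loaded, the pipeline can be used for prediction straight away. A minimal sketch, assuming the pipeline exposes a `predict` method:
```python
# Example texts; replace with your own data.
texts = ["This movie was a masterpiece.", "The plot summary is on page 3."]

predictions = pipeline.predict(texts)
print(predictions)
```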
# Bring your own architecture
Our training architecture is set up to be extensible, with each task having a specific class. Right now, we only offer `StaticModelForClassification`, but in the future we'll also offer regression, etc.
The core functionality of the `StaticModelForClassification` is contained in a couple of functions:
* `construct_head`: This function constructs the classifier head on top of the `StaticModel`. For example, if you want to create a model that has LayerNorm, just subclass and replace this function (see the sketch after this list). This should be the main function to update if you want to change model behavior.
* `train_test_split`: Governs the train/test split performed before training.
* `prepare_dataset`: Selects the `torch.Dataset` that will be used in the `DataLoader` during training.
* `_encode`: The encoding function used in the model.
* `fit`: Contains all the lightning-related fitting logic.
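As a rough illustration, here is what overriding `construct_head` could look like. This is a minimal sketch: the attribute names used for the input and output dimensions (`self.embed_dim`, `self.out_dim`) and the exact return type are assumptions, so check the base class before copying this.
```python
from torch import nn

from model2vec.train import StaticModelForClassification


class StaticModelForClassificationWithLayerNorm(StaticModelForClassification):
    """Hypothetical subclass that adds LayerNorm to the classification head."""

    def construct_head(self) -> nn.Module:
        # Assumed attributes: `self.embed_dim` (embedding size of the static model)
        # and `self.out_dim` (number of classes). Verify against the base class.
        return nn.Sequential(
            nn.Linear(self.embed_dim, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, self.out_dim),
        )
```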
The training of the model is done in a `lightning.LightningModule`, which can be modified but is very basic.
# Results
We ran extensive benchmarks where we compared our model to several well known architectures. The results can be found in the [training results](https://github.com/MinishLab/model2vec/tree/main/results#training-results) documentation.