Sarthak
chore: moved model2vec as in internal package
473c3a0

Training

Aside from distillation, model2vec also supports training simple classifiers on top of static models, using PyTorch, Lightning, and scikit-learn.

We support both single-label and multi-label classification. The task type is chosen automatically from the format of the labels you provide.
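The idea of choosing the task from the label format can be sketched with a hypothetical helper, `infer_task`. It assumes that a label list whose entries are themselves lists means multi-label, and anything else means single-label; this illustrates the convention described above and is not model2vec's internal check.

```python
from typing import Sequence, Union

Label = Union[str, int]


def infer_task(labels: Sequence[Union[Label, Sequence[Label]]]) -> str:
    """Hypothetical helper: guess the task type from the label format.

    A flat list such as [0, 1, 1] is treated as single-label; a list of
    lists such as [[0, 2], [1]] is treated as multi-label. Strings are
    deliberately not treated as lists, so ["pos", "neg"] stays single-label.
    """
    if labels and all(isinstance(y, (list, tuple)) for y in labels):
        return "multi-label"
    return "single-label"


print(infer_task([0, 1, 1, 0]))         # single-label
print(infer_task([["a", "b"], ["c"]]))  # multi-label
```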

Installation

To train, make sure you install the training extra:

pip install "model2vec[training]"

Quickstart

To train a classifier, initialize it either from an existing StaticModel (for example, one you distilled yourself) or from a pre-trained model, as follows:

from model2vec.distill import distill
from model2vec.train import StaticModelForClassification

# From a distilled model
distilled_model = distill("BAAI/bge-base-en-v1.5")
classifier = StaticModelForClassification.from_static_model(model=distilled_model)

# From a pre-trained model: potion is the default
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")

This creates a very simple classifier: a StaticModel with a single 512-unit hidden layer on top. You can adjust the number of hidden layers and the number of units through parameters on both functions. Note that the default for from_pretrained is potion-base-32M, our best model to date. This is our recommended path if you're working with general English data.
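To make the architecture concrete, here is a dependency-free sketch of a forward pass through such a classifier. It assumes the sentence embedding is the mean of the static token vectors and that the hidden layer uses a ReLU activation followed by a softmax over classes; the toy sizes, random weights, and helper names are ours, and the actual PyTorch head in model2vec may differ in its details.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def mean_pool(token_vectors: List[Vector]) -> Vector:
    """Average static token vectors into a single sentence embedding."""
    dim = len(token_vectors[0])
    return [sum(v[i] for v in token_vectors) / len(token_vectors) for i in range(dim)]


def linear(x: Vector, weights: Matrix, bias: Vector) -> Vector:
    """Dense layer; `weights` has shape (out_dim, in_dim)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


def softmax(logits: Vector) -> Vector:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def random_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


# Toy sizes: in the real model the embedding size comes from the static model
# and the hidden layer has 512 units by default.
embed_dim, hidden_dim, n_classes = 8, 16, 2
rng = random.Random(0)
W1, b1 = random_matrix(hidden_dim, embed_dim, rng), [0.0] * hidden_dim
W2, b2 = random_matrix(n_classes, hidden_dim, rng), [0.0] * n_classes

# Five "token" vectors standing in for a tokenized sentence.
tokens = [[rng.gauss(0, 1) for _ in range(embed_dim)] for _ in range(5)]
probs = softmax(linear(relu(linear(mean_pool(tokens), W1, b1)), W2, b2))
print(probs)  # two class probabilities that sum to 1
```

Only the two weight matrices and biases are learned here; the static token vectors feed in unchanged, which is what keeps this kind of head cheap to train.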

Now that you have created the classifier, let's train it. The example below assumes you have the datasets library installed.

from time import perf_counter

from datasets import load_dataset

# Load the subj dataset
ds = load_dataset("setfit/subj")
train = ds["train"]
test = ds["test"]

s = perf_counter()
classifier = classifier.fit(train["text"], train["label"])

print(f"Training took {int(perf_counter() - s)} seconds.")
# Training took 81 seconds
classification_report = classifier.evaluate(test["text"], test["label"])
print(classification_report)
# Achieved 91.0 test accuracy

As you can see, the classifier reaches 91% test accuracy after only 81 seconds of training.

The training loop is handled by Lightning. By default, the data is split into a training and a validation set, with 90% of the data used for training and 10% for validation. Training uses early stopping on validation accuracy with a patience of 5, so it stops once validation accuracy has not improved for 5 consecutive validation checks.
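The default split and early-stopping behavior can be sketched in plain Python. This is an illustration under our own assumptions (a seeded shuffle, accuracy as the monitored metric, and "improvement" meaning a strictly higher score), not model2vec's code; the real loop delegates early stopping to Lightning.

```python
import random
from typing import List, Sequence, Tuple


def train_val_split(
    X: Sequence[str], y: Sequence[int], val_fraction: float = 0.1, seed: int = 42
) -> Tuple[List[str], List[str], List[int], List[int]]:
    """Shuffle the indices and hold out `val_fraction` of them for validation."""
    indices = list(range(len(X)))
    random.Random(seed).shuffle(indices)
    n_val = int(len(indices) * val_fraction)
    val_idx, train_idx = indices[:n_val], indices[n_val:]
    return (
        [X[i] for i in train_idx],
        [X[i] for i in val_idx],
        [y[i] for i in train_idx],
        [y[i] for i in val_idx],
    )


class EarlyStopping:
    """Signal a stop once validation accuracy has not improved for `patience` checks."""

    def __init__(self, patience: int = 5) -> None:
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def should_stop(self, val_accuracy: float) -> bool:
        if val_accuracy > self.best:
            self.best = val_accuracy
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


texts = [f"text {i}" for i in range(100)]
labels = [i % 2 for i in range(100)]
X_train, X_val, y_train, y_val = train_val_split(texts, labels)
print(len(X_train), len(X_val))  # 90 10

# Hypothetical validation accuracies, one per epoch.
stopper = EarlyStopping(patience=5)
history = [0.80, 0.85, 0.86, 0.86, 0.85, 0.86, 0.84, 0.85]
for epoch, acc in enumerate(history):
    if stopper.should_stop(acc):
        print(f"Stopping after epoch {epoch}; best accuracy {stopper.best}")
        break
```

With these numbers, the best accuracy is reached at epoch 2, and the five non-improving epochs that follow trigger the stop.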

The resulting classifier is also as fast at inference as you're used to from us:

from time import perf_counter

s = perf_counter()
classifier.predict(test["text"])
print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} instances on CPU.")
# Took 67 milliseconds for 2000 instances on CPU.

Multi-label classification

Multi-label classification is supported out of the box. Just pass a list of lists to the fit function (e.g. [[label1, label2], [label1, label3]]), and a multi-label classifier will be trained. For example, the following code trains a multi-label classifier on the go_emotions dataset:

from datasets import load_dataset
from model2vec.train import StaticModelForClassification

# Initialize a classifier from a pre-trained model
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")

# Load a multi-label dataset
ds = load_dataset("google-research-datasets/go_emotions")

# Inspect some of the labels
print(ds["train"]["labels"][40:50])
# [[0, 15], [15, 18], [16, 27], [27], [7, 13], [10], [20], [27], [27], [27]]

# Train the classifier on text (X) and labels (y)
classifier.fit(ds["train"]["text"], ds["train"]["labels"])

Then, we can evaluate the classifier:

classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["labels"], threshold=0.3)
print(classification_report)
# Accuracy: 0.410
# Precision: 0.527
# Recall: 0.410
# F1: 0.439

The scores are competitive with the popular roberta-base-go_emotions model, while our model is orders of magnitude faster.
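The list-of-lists labels and the threshold argument can be illustrated with a small sketch. We assume the usual multi-label setup: each example's labels become a multi-hot vector, each label gets an independent sigmoid probability, and every label whose probability reaches the threshold is predicted. The helper names are hypothetical, and this is not model2vec's internal code.

```python
import math
from typing import List, Sequence


def multi_hot(label_lists: Sequence[Sequence[int]], n_labels: int) -> List[List[int]]:
    """Encode [[0, 15], [27]]-style labels as rows of 0/1 indicators."""
    rows = []
    for labels in label_lists:
        row = [0] * n_labels
        for label in labels:
            row[label] = 1
        rows.append(row)
    return rows


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def predict_labels(logits: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Return every label whose independent sigmoid probability reaches the threshold."""
    return [i for i, z in enumerate(logits) if sigmoid(z) >= threshold]


print(multi_hot([[0, 2], [1]], n_labels=3))  # [[1, 0, 1], [0, 1, 0]]

logits = [1.2, -0.6, -1.5, 0.1]  # hypothetical scores for one example
print(predict_labels(logits, threshold=0.5))  # [0, 3]
print(predict_labels(logits, threshold=0.3))  # [0, 1, 3]
```

Lowering the threshold from 0.5 to 0.3 lets more labels through per example, which generally trades precision for recall.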

Persistence

You can turn a classifier into a scikit-learn compatible pipeline, as follows:

pipeline = classifier.to_pipeline()

This pipeline object can be persisted using standard pickle-based methods, such as joblib. This makes it easy to use your model in inference pipelines (no need to install torch!). However, because loading a pickle can execute arbitrary code, joblib and pickle should not be used to share models outside of your organization.

To save your pipeline locally or share it on the Hugging Face Hub, you can use our built-in functions:

# Save to a local path
pipeline.save_pretrained(path)
# Upload to the Hugging Face Hub
pipeline.push_to_hub("my_cool/project")

Later, you can load these as follows:

from model2vec.inference import StaticModelPipeline

pipeline = StaticModelPipeline.from_pretrained("my_cool/project")

Loading pipelines in this way is extremely fast: it takes only 30 ms to load a pipeline from disk.

Bring your own architecture

Our training architecture is set up to be extensible, with a specific class for each task. Right now, we only offer StaticModelForClassification, but in the future we'll also support other tasks, such as regression.

The core functionality of StaticModelForClassification is contained in a handful of methods:

  • construct_head: constructs the classifier head on top of the StaticModel. For example, if you want a model that uses LayerNorm, subclass StaticModelForClassification and override this method. This is the main method to change if you want to modify model behavior.
  • train_test_split: governs the train/validation split performed before training.
  • prepare_dataset: selects the torch.utils.data.Dataset that the DataLoader uses during training.
  • _encode: the encoding function used in the model.
  • fit: contains all the Lightning-related fitting logic.

The model is trained in a lightning.LightningModule, which is deliberately basic but can be modified.
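The "subclass and override construct_head" pattern can be illustrated in plain Python. The classes below are hypothetical stand-ins, not model2vec code: the base class only ever runs the layers returned by construct_head, so a subclass can change model behavior by overriding that one method. In the real class you would return PyTorch modules (for example torch.nn.LayerNorm) instead; since its exact signature isn't documented here, we don't reproduce it.

```python
import math
from typing import Callable, List

Vector = List[float]
Layer = Callable[[Vector], Vector]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Normalize a vector to zero mean and unit variance (no learned scale or shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


class ToyClassifier:
    """Hypothetical base class: forward() only runs the layers from construct_head()."""

    def __init__(self) -> None:
        self.head = self.construct_head()

    def construct_head(self) -> List[Layer]:
        return [relu]

    def forward(self, embedding: Vector) -> Vector:
        for layer in self.head:
            embedding = layer(embedding)
        return embedding


class ToyClassifierWithLayerNorm(ToyClassifier):
    """Changes model behavior by overriding construct_head alone."""

    def construct_head(self) -> List[Layer]:
        # Prepend a normalization step to whatever the parent head contains.
        return [layer_norm] + super().construct_head()


x = [1.0, 2.0, 3.0, 6.0]
print(ToyClassifier().forward(x))               # [1.0, 2.0, 3.0, 6.0]
print(ToyClassifierWithLayerNorm().forward(x))  # values below the mean are clipped to 0
```

Because fitting and encoding never touch the head's internals directly, this kind of override leaves the rest of the training pipeline unchanged.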

Results

We ran extensive benchmarks comparing our model to several well-known architectures. The results can be found in the training results documentation.