Training
Aside from distillation, model2vec also supports training simple classifiers on top of static models, using PyTorch, Lightning and scikit-learn.
We support both single-label and multi-label classification; the right mode is chosen automatically based on the labels you provide.
Installation
To train, make sure you install the training extra:
pip install "model2vec[training]"
Quickstart
To train a model, simply initialize it from a StaticModel, or from a pre-trained model, as follows:
from model2vec.distill import distill
from model2vec.train import StaticModelForClassification
# From a distilled model
distilled_model = distill("BAAI/bge-base-en-v1.5")
classifier = StaticModelForClassification.from_static_model(model=distilled_model)
# From a pre-trained model: potion-base-32M is the default
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")
This creates a very simple classifier: a StaticModel with a single 512-unit hidden layer on top. You can adjust the number of hidden layers and the number of units through parameters on both functions. Note that the default for from_pretrained is potion-base-32M, our best model to date. This is our recommended path if you're working with general English data.
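To make the architecture concrete, here is a minimal pure-Python sketch of that forward pass, shrunk to toy sizes (3-dimensional embeddings and a 2-unit hidden layer instead of 512). It assumes the static embeddings of a text's tokens are mean-pooled and that the hidden layer uses a ReLU activation; the toy vocabulary, weights and function names are hypothetical, and the library's actual pooling and activation may differ.

```python
# Hypothetical toy vocabulary: 3-dimensional static token embeddings.
EMBEDDINGS = {
    "good": [1.0, 0.0, 2.0],
    "movie": [0.0, 1.0, 0.0],
    "bad": [-1.0, 0.0, -2.0],
}
EMBED_DIM = 3


def mean_pool(tokens):
    """Average the embeddings of known tokens; unknown tokens are skipped."""
    vectors = [EMBEDDINGS[t] for t in tokens if t in EMBEDDINGS]
    if not vectors:
        return [0.0] * EMBED_DIM
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(EMBED_DIM)]


def linear(x, weights, bias):
    """Compute W @ x + b, where W is given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def relu(x):
    return [max(0.0, v) for v in x]


def forward(tokens, w_hidden, b_hidden, w_out, b_out):
    """Pooled static embedding -> one hidden layer with ReLU -> output logits."""
    hidden = relu(linear(mean_pool(tokens), w_hidden, b_hidden))
    return linear(hidden, w_out, b_out)


# Toy weights: 2 hidden units (instead of 512) and 2 output classes.
W_HIDDEN = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
B_HIDDEN = [0.0, 0.0]
W_OUT = [[1.0, 0.0], [-1.0, 0.0]]
B_OUT = [0.0, 0.0]

logits = forward(["good", "movie"], W_HIDDEN, B_HIDDEN, W_OUT, B_OUT)
print(logits)  # [1.5, -1.5]
```

Only the small head on top is trained from scratch; the static embeddings provide the input features.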
Now that you have created the classifier, let's train it. The example below assumes you have the datasets library installed.
from time import perf_counter
from datasets import load_dataset
# Load the subj dataset
ds = load_dataset("setfit/subj")
train = ds["train"]
test = ds["test"]
s = perf_counter()
classifier = classifier.fit(train["text"], train["label"])
print(f"Training took {int(perf_counter() - s)} seconds.")
# Training took 81 seconds
classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["label"])
print(classification_report)
# Achieved 91.0 test accuracy
As you can see, we got a pretty nice 91% accuracy, with only 81 seconds of training.
The training loop is handled by Lightning. By default, it splits the data into a training and a validation set, using 90% of the data for training and 10% for validation, and applies early stopping on validation accuracy with a patience of 5.
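The split and early-stopping behaviour described above can be illustrated with a small, framework-free sketch. This is not Lightning's or model2vec's code: the seeded shuffle and the function names train_validation_split and stopping_epoch are assumptions made for illustration. Here, an improvement means strictly beating the best validation accuracy seen so far.

```python
import random


def train_validation_split(items, validation_fraction=0.1, seed=42):
    """Shuffle a copy of `items` and split it into (train, validation) lists."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_validation = max(1, int(len(shuffled) * validation_fraction))
    return shuffled[n_validation:], shuffled[:n_validation]


def stopping_epoch(validation_accuracies, patience=5):
    """Return the epoch index at which early stopping would halt training.

    Training stops once `patience` consecutive epochs fail to improve on the
    best validation accuracy seen so far; otherwise it runs to the last epoch.
    """
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch, accuracy in enumerate(validation_accuracies):
        if accuracy > best:
            best = accuracy
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(validation_accuracies) - 1


train, validation = train_validation_split(range(100))
print(len(train), len(validation))  # 90 10
print(stopping_epoch([0.80, 0.85, 0.86, 0.86, 0.85, 0.86, 0.84, 0.85]))  # 7
```

In the second call, the best accuracy (0.86) is reached at epoch 2, and the five following epochs never beat it, so training halts at epoch 7.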
Note that this model is as fast as you're used to from us:
from time import perf_counter
s = perf_counter()
classifier.predict(test["text"])
print(f"Took {int((perf_counter() - s) * 1000)} milliseconds for {len(test)} instances on CPU.")
# Took 67 milliseconds for 2000 instances on CPU.
Multi-label classification
Multi-label classification is supported out of the box. Just pass a list of lists to the fit function (e.g. [[label1, label2], [label1, label3]]), and a multi-label classifier will be trained. For example, the following code trains a multi-label classifier on the go_emotions dataset:
from datasets import load_dataset
from model2vec.train import StaticModelForClassification
# Initialize a classifier from a pre-trained model
classifier = StaticModelForClassification.from_pretrained(model_name="minishlab/potion-base-32M")
# Load a multi-label dataset
ds = load_dataset("google-research-datasets/go_emotions")
# Inspect some of the labels
print(ds["train"]["labels"][40:50])
# [[0, 15], [15, 18], [16, 27], [27], [7, 13], [10], [20], [27], [27], [27]]
# Train the classifier on text (X) and labels (y)
classifier.fit(ds["train"]["text"], ds["train"]["labels"])
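Under the hood, list-of-lists targets like the go_emotions labels have to become a 0/1 indicator matrix with one column per label. The sketch below shows one plausible way to detect multi-label input and build that encoding; the function names are hypothetical and this is not necessarily how model2vec does it internally (scikit-learn's MultiLabelBinarizer performs an equivalent encoding).

```python
def is_multilabel(labels):
    """Multi-label if every target is a list/tuple of labels, e.g. [[0, 15], [27]]."""
    return bool(labels) and all(isinstance(y, (list, tuple)) for y in labels)


def multi_hot(labels):
    """Encode list-of-lists targets as 0/1 rows over the sorted set of labels."""
    classes = sorted({label for row in labels for label in row})
    column = {label: i for i, label in enumerate(classes)}
    rows = []
    for row in labels:
        vector = [0] * len(classes)
        for label in row:
            vector[column[label]] = 1
        rows.append(vector)
    return classes, rows


classes, rows = multi_hot([[0, 15], [15, 18], [27]])
print(classes)  # [0, 15, 18, 27]
print(rows)  # [[1, 1, 0, 0], [0, 1, 1, 0], [0, 0, 0, 1]]
```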
Then, we can evaluate the classifier:
classification_report = classifier.evaluate(ds["test"]["text"], ds["test"]["labels"], threshold=0.3)
print(classification_report)
# Accuracy: 0.410
# Precision: 0.527
# Recall: 0.410
# F1: 0.439
The scores are competitive with the popular roberta-base-go_emotions model, while our model is orders of magnitude faster.
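The threshold argument passed to evaluate above controls which labels count as predicted. The following is a minimal sketch, assuming each label is scored independently with a sigmoid (the usual setup for multi-label classification) and that a label is predicted when its probability reaches the threshold; predict_labels and the example logits are hypothetical.

```python
import math


def sigmoid(x):
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def predict_labels(logits, threshold=0.5):
    """Indices of all labels whose sigmoid probability is at least `threshold`."""
    return [i for i, logit in enumerate(logits) if sigmoid(logit) >= threshold]


logits = [2.0, -1.0, -0.5, -3.0]  # hypothetical scores for 4 labels
print(predict_labels(logits, threshold=0.5))  # [0]
print(predict_labels(logits, threshold=0.3))  # [0, 2]
```

Lowering the threshold from 0.5 to 0.3 admits more labels per text, which generally raises recall at the cost of precision.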
Persistence
You can turn a classifier into a scikit-learn compatible pipeline, as follows:
pipeline = classifier.to_pipeline()
This pipeline object can be persisted using standard pickle-based methods, such as joblib. This makes it easy to use your model in inference pipelines (no need to install torch!), although joblib and pickle should not be used to share models outside of your organization.
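As a minimal sketch of a pickle round trip, the code below saves and reloads a stand-in object (a plain dict of hypothetical labels and weights). With a real pipeline from to_pipeline() you would pass that object instead, and joblib.dump(obj, filename) / joblib.load(filename) follow the same pattern. Unpickling can execute arbitrary code, which is why these files should only be loaded from sources you trust.

```python
import os
import pickle
import tempfile


def save_object(obj, path):
    """Serialize `obj` to `path` with pickle."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def load_object(path):
    """Load a pickled object. Only do this for files from trusted sources."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Hypothetical stand-in for a fitted pipeline's state.
pipeline_state = {"labels": ["objective", "subjective"], "weights": [[0.1, -0.2], [0.3, 0.4]]}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "pipeline.pkl")
    save_object(pipeline_state, path)
    restored = load_object(path)

print(restored == pipeline_state)  # True
```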
If you want to persist your pipeline to the Hugging Face hub, you can use our built-in functions:
path = "my_pipeline"  # example local directory name
pipeline.save_pretrained(path)
pipeline.push_to_hub("my_cool/project")
Later, you can load these as follows:
from model2vec.inference import StaticModelPipeline
pipeline = StaticModelPipeline.from_pretrained("my_cool/project")
Loading pipelines in this way is extremely fast. It takes only 30ms to load a pipeline from disk.
Bring your own architecture
Our training architecture is set up to be extensible, with each task having a specific class. Right now, we only offer StaticModelForClassification, but in the future we'll also offer regression, etc.
The core functionality of StaticModelForClassification is contained in a couple of functions:
- construct_head: constructs the classifier on top of the StaticModel. For example, if you want to create a model that has LayerNorm, just subclass and replace this function. This should be the main function to update if you want to change model behavior.
- train_test_split: governs the train/validation split performed before training.
- prepare_dataset: selects the torch Dataset that will be used by the DataLoader during training.
- _encode: the encoding function used in the model.
- fit: contains all the Lightning-related fitting logic.
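As an illustration of the construct_head extension point, here is a pure-Python sketch of the pattern: the base class builds its head through an overridable hook, and a subclass prepends a LayerNorm step. ToyClassifier and its methods are hypothetical stand-ins; in model2vec the head is presumably built from torch layers (for example torch.nn.LayerNorm), and the real method's signature may differ, so check the source before subclassing.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(x) / len(x)
    variance = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(variance + eps) for v in x]


class ToyClassifier:
    """Hypothetical stand-in for a classifier whose head comes from construct_head()."""

    def __init__(self):
        # The hook is resolved on the instance, so subclasses can override it.
        self.head = self.construct_head()

    def construct_head(self):
        # Default head: identity (a real head would hold trainable layers).
        return [lambda x: x]

    def forward(self, x):
        for layer in self.head:
            x = layer(x)
        return x


class ToyClassifierWithLayerNorm(ToyClassifier):
    def construct_head(self):
        # Override only the hook: prepend LayerNorm to the parent's head.
        return [layer_norm] + super().construct_head()


print(ToyClassifierWithLayerNorm().forward([1.0, 2.0, 3.0]))
```

Everything else (data handling, fitting) is inherited unchanged, which is why construct_head is the natural place to change model behavior.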
The training of the model is done in a lightning.LightningModule, which can be modified but is very basic.
Results
We ran extensive benchmarks comparing our model to several well-known architectures. The results can be found in the training results documentation.