Inspired by: https://huggingface.co/blog/fine-tune-vit

# Fine-Tuning Vision Transformers for Image Classification

Just as transformer-based models have revolutionized NLP, we're now seeing an explosion of papers applying them to all sorts of other domains. One of the most influential of these was the Vision Transformer (ViT), which was introduced in [October 2020](https://arxiv.org/abs/2010.11929) by a team of researchers at Google Brain.

This paper explored how you can tokenize images, just as you would tokenize sentences, so that they can be passed to transformer models for training. It's quite a simple concept, really:

1. Split an image into a grid of sub-image patches
1. Embed each patch with a linear projection
1. Each embedded patch becomes a token, and the resulting sequence of embedded patches is the sequence you pass to the model.

![vit_figure.png](https://raw.githubusercontent.com/google-research/vision_transformer/main/vit_figure.png)
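To make the three steps concrete, here is a minimal NumPy sketch of patch tokenization. It assumes a 224×224 RGB image and 16×16 patches, the sizes used by the checkpoint later in this notebook. The random matrix `W` is a placeholder for learned projection weights. ViT implementations typically do the same split-and-project with a single strided convolution. The real model also prepends a learnable [CLS] token and adds position embeddings, which this sketch leaves out.

```python
import numpy as np

def patchify(image, patch_size=16):
    """Split an (H, W, C) image into non-overlapping patch_size x patch_size patches,
    each flattened to a vector. Returns shape (num_patches, patch_size * patch_size * C)."""
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image must divide evenly into patches"
    grid = image.reshape(h // patch_size, patch_size, w // patch_size, patch_size, c)
    grid = grid.transpose(0, 2, 1, 3, 4)                  # (rows, cols, patch_size, patch_size, C)
    return grid.reshape(-1, patch_size * patch_size * c)  # patches ordered left-to-right, top-to-bottom

rng = np.random.default_rng(0)
toy_image = rng.random((224, 224, 3), dtype=np.float32)  # stand-in for a real RGB image

# Step 1: split into a 14 x 14 grid of 16 x 16 patches, each flattened to 768 values
patches = patchify(toy_image)                       # (196, 768)

# Step 2: linear projection to the hidden size (random weights stand in for learned ones)
hidden_size = 768
W = rng.normal(scale=0.02, size=(patches.shape[1], hidden_size)).astype(np.float32)
b = np.zeros(hidden_size, dtype=np.float32)

# Step 3: each projected patch becomes one token of the input sequence
tokens = patches @ W + b                            # (196, 768)
print(patches.shape, tokens.shape)
```

Each patch flattens to 16 × 16 × 3 = 768 values, and a 224×224 image yields (224 / 16)² = 196 patches. The model therefore sees 196 patch tokens, or 197 once the [CLS] token is added.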


It turns out that once you've done the above, you can pre-train and fine-tune transformers just as you're used to with NLP tasks. Pretty sweet 😎.

```python
%%capture
!pip install datasets transformers evaluate
```

## Load dataset

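```python
from datasets import load_dataset

dataset_name = "jonathan-roberts1/Satellite-Images-of-Hurricane-Damage"

def get_ds(seed=42):
    """Load the dataset and split it 50/25/25 into train/validation/test."""
    ds = load_dataset(dataset_name)
    # Start from the "train" split and hold out half of it (fixed seed for reproducible splits)...
    ds = ds["train"].train_test_split(test_size=0.5, seed=seed)
    # ...then divide the held-out half evenly into validation and test
    held_out = ds["test"].train_test_split(test_size=0.5, seed=seed)
    ds["validation"] = held_out["train"]
    ds["test"] = held_out["test"]
    return ds
```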
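`get_ds` calls `train_test_split` twice with `test_size=0.5`, which produces a 50/25/25 train/validation/test split. The sketch below reproduces that arithmetic in plain Python. `three_way_split` is a hypothetical helper, and the 10,000 ids are stand-ins rather than the real dataset. It illustrates the proportions only, not how 🤗 Datasets shuffles internally.

```python
import random

def three_way_split(items, seed=42):
    """Hypothetical helper mirroring get_ds(): shuffle, keep half for training,
    then split the held-out half evenly into validation and test."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    half = len(shuffled) // 2
    train, held_out = shuffled[:half], shuffled[half:]
    quarter = len(held_out) // 2
    return {
        "train": train,
        "validation": held_out[:quarter],
        "test": held_out[quarter:],
    }

splits = three_way_split(range(10_000))   # 10,000 stand-in ids, not the real dataset
print({name: len(rows) for name, rows in splits.items()})
# {'train': 5000, 'validation': 2500, 'test': 2500}
```

With an odd number of rows, the exact rounding may differ from 🤗 Datasets. Only the proportions matter here.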

```python
ds = get_ds()
```

```python
# Grab one example image; as the last expression in the cell, the notebook displays it
image = ds['train'][400]['image']
image
```

## Loading ViT Feature Extractor

Now that we know what our images look like and have a better understanding of the problem we're trying to solve, let's see how we can prepare these images for our model.

When ViT models are trained, specific transformations are applied to the images being fed into them. Use the wrong transformations on your image and the model won't be able to understand what it's seeing! 🖼 ➡️ 🔢

To make sure we apply the correct transformations, we will use a [`ViTFeatureExtractor`](https://huggingface.co/docs/transformers/model_doc/vit) initialized with a configuration that was saved along with the pretrained model we plan to use. In our case, we'll be using the [google/vit-base-patch16-224-in21k](https://huggingface.co/google/vit-base-patch16-224-in21k) model, so let's load its feature extractor from the 🤗 Hub. (In recent versions of 🤗 Transformers, `ViTFeatureExtractor` is deprecated in favor of `ViTImageProcessor`, which is used the same way.)

```python
from transformers import ViTFeatureExtractor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
```

If we print the feature extractor, we can see its configuration.

```python
feature_extractor
```

```
ViTFeatureExtractor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTFeatureExtractor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}
```
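The config shows what the extractor does after resizing to 224×224. First it multiplies by `rescale_factor` (1/255) to map pixel values into [0, 1]. Then it subtracts `image_mean` and divides by `image_std`, which maps them into [-1, 1]. Per channel, that is x' = (x / 255 − 0.5) / 0.5 = x / 127.5 − 1. Below is a minimal NumPy sketch of those two steps. It skips the resize, and `rescale_and_normalize` is a hypothetical helper name.

```python
import numpy as np

# Values taken from the printed ViTFeatureExtractor config
rescale_factor = 1 / 255                 # "rescale_factor": 0.00392156862745098
image_mean = np.array([0.5, 0.5, 0.5])
image_std = np.array([0.5, 0.5, 0.5])

def rescale_and_normalize(pixels):
    """Apply the do_rescale + do_normalize steps to an (H, W, 3) uint8 array (no resizing).

    Returns a (3, H, W) float array, matching the channel-first layout of pixel_values.
    """
    x = pixels.astype(np.float64) * rescale_factor   # [0, 255] -> [0, 1]
    x = (x - image_mean) / image_std                 # [0, 1]   -> [-1, 1]
    return x.transpose(2, 0, 1)                      # HWC -> CHW

pixels = np.array([[[0, 75, 255]]], dtype=np.uint8)  # one pixel with R=0, G=75, B=255
print(rescale_and_normalize(pixels).ravel().round(4))
```

An intensity of 0 maps to -1, 255 maps to 1, and 75 maps to about -0.4118. That last value appears at the start of the tensor printed below.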

To process an image, simply call the feature extractor on it. This returns a dict-like object containing `pixel_values`, the numeric representation of your image that we'll pass to the model.

We get NumPy arrays by default, but if we add the `return_tensors='pt'` argument, we'll get back `torch` tensors instead.

```python
feature_extractor(image, return_tensors='pt')
```

```
{'pixel_values': tensor([[[[-0.4118, -0.4039, -0.3882,  ..., -0.1137, -0.1137, -0.1137],
          [-0.4353, -0.4275, -0.4118,  ..., -0.1294, -0.1373, -0.1373],
          [-0.4745, -0.4667, -0.4431,  ..., -0.1608, -0.1686, -0.1686],
          ...,
          [-0.6078, -0.6000, -0.5765,  ..., -0.5294, -0.5373, -0.5451],
          [-0.6000, -0.6000, -0.5843,  ..., -0.5216, -0.5294, -0.5373],
          [-0.6000, -0.6000, -0.5922,  ..., -0.5216, -0.5294, -0.5373]],

         [[-0.3255, -0.3176, -0.3020,  ..., -0.0824, -0.0824, -0.0824],
          [-0.3490, -0.3412, -0.3255,  ..., -0.0980, -0.1059, -0.1059],
          [-0.3882, -0.3804, -0.3569,  ..., -0.1294, -0.1373, -0.1373],
          ...,
          [-0.6000, -0.5922, -0.5686,  ..., -0.4118, -0.4196, -0.4275],
          [-0.5922, -0.5922, -0.5765,  ..., -0.4039, -0.4118, -0.4196],
          [-0.5922, -0.5922, -0.5843,  ..., -0.4039, -0.4118, -0.4196]],

         [[-0.4510, -0.4431, -0.4275,  ..., -0.1922, -0.1922, -0.1922],
          ...
```

## Processing the Dataset

Now that we know how to read in images and transform them into inputs, let's write a function that will put those two things together to process a single example from the dataset.

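```python
def process_example(example):
    # Turn the PIL image into a (1, 3, 224, 224) pixel_values tensor
    inputs = feature_extractor(example['image'], return_tensors='pt')
    # Carry the label along with the model inputs
    inputs['label'] = example['label']
    return inputs
```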

```python
process_example(ds['train'][0])
```

```
{'pixel_values': tensor([[[[-0.3569, -0.3490, -0.3412,  ..., -0.5294, -0.5137, -0.5059],
          [-0.3412, -0.3333, -0.3255,  ..., -0.5294, -0.5216, -0.5137],
          [-0.3176, -0.3098, -0.3020,  ..., -0.5373, -0.5294, -0.5216],
          ...,
          [-0.5686, -0.5765, -0.5843,  ..., -0.4353, -0.4431, -0.4510],
          [-0.5686, -0.5686, -0.5765,  ..., -0.3882, -0.3882, -0.3961],
          [-0.5686, -0.5686, -0.5686,  ..., -0.3569, -0.3569, -0.3569]],

         [[-0.3961, -0.3882, -0.3804,  ..., -0.4196, -0.4039, -0.3961],
          [-0.3804, -0.3725, -0.3647,  ..., -0.4196, -0.4118, -0.4039],
          [-0.3569, -0.3490, -0.3412,  ..., -0.4275, -0.4196, -0.4118],
          ...,
          [-0.3882, -0.3961, -0.4039,  ..., -0.3490, -0.3569, -0.3647],
          [-0.3882, -0.3882, -0.3961,  ..., -0.3020, -0.3020, -0.3098],
          [-0.3882, -0.3882, -0.3882,  ..., -0.2706, -0.2706, -0.2706]],

         [[-0.5686, -0.5608, -0.5529,  ..., -0.6235, -0.6078, -0.6000],
          ...
```

While we could call `ds.map` and apply this to every example at once, this can be very slow, especially if you use a larger dataset. Instead, we'll apply a ***transform*** to the dataset. Transforms are only applied to examples as you index them.
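To show what "applied only as you index them" means, here is a toy, framework-free sketch of the idea. `LazyTransformDataset` and `toy_transform` are hypothetical names and not part of 🤗 Datasets. In this sketch, rows are stored untouched, and the transform runs on a batch (a dict of lists) only when an item or slice is requested. A single item is unwrapped from its batch of one afterwards.

```python
class LazyTransformDataset:
    """Hypothetical stand-in for ds.with_transform(transform): rows stay raw and
    the transform only runs when an item or slice is requested."""

    def __init__(self, rows, transform):
        self.rows = rows              # list of dicts, e.g. {"image": ..., "label": ...}
        self.transform = transform
        self.calls = 0                # how many times the transform has run

    def __getitem__(self, key):
        selected = self.rows[key] if isinstance(key, slice) else [self.rows[key]]
        if not selected:
            return {}
        # Turn a list of rows into a batch: a dict mapping column name -> list of values
        batch = {col: [row[col] for row in selected] for col in selected[0]}
        self.calls += 1
        out = self.transform(batch)
        if isinstance(key, slice):
            return out
        return {col: values[0] for col, values in out.items()}  # unwrap a single example


def toy_transform(batch):
    # Placeholder for real preprocessing: the "images" are just numbers here
    return {"pixel_values": [x * 2 for x in batch["image"]], "label": batch["label"]}


rows = [{"image": i, "label": i % 2} for i in range(5)]
lazy = LazyTransformDataset(rows, toy_transform)
print(lazy.calls)   # 0: nothing has been computed yet
print(lazy[1])      # {'pixel_values': 2, 'label': 1}
print(lazy[0:2])    # {'pixel_values': [0, 2], 'label': [0, 1]}
print(lazy.calls)   # 2
```

Because nothing is computed upfront, preprocessing cost is paid per batch during training. The trade-off is that the work is redone every time an example is read, for example once per epoch.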

First, though, we'll need to update our last function to accept a batch of data, as that's what `ds.with_transform` expects.

```python
def transform(example_batch):
    # Take a list of PIL images and turn them into a (batch, 3, 224, 224) tensor of pixel values
    inputs = feature_extractor(example_batch['image'], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['label'] = example_batch['label']
    return inputs
```

We can directly apply this to our dataset using `ds.with_transform(transform)`.

```python
prepared_ds = ds.with_transform(transform)
```

Now, whenever we get an example from the dataset, our transform will be applied in real time (to both single examples and slices, as shown below).

```python
prepared_ds['train'][0:2]
```

```
{'pixel_values': tensor([[[[-0.0275, -0.0118,  0.0118,  ...,  0.1216,  0.0980,  0.0824],
          [-0.0196, -0.0039,  0.0196,  ...,  0.1137,  0.0980,  0.0902],
          [-0.0039,  0.0118,  0.0431,  ...,  0.0980,  0.0980,  0.0980],
          ...,
          [-0.1765, -0.1059,  0.0196,  ...,  0.0275,  0.0118,  0.0039],
          [-0.2000, -0.1686, -0.1059,  ...,  0.0275,  0.0353,  0.0353],
          [-0.2157, -0.2078, -0.1843,  ...,  0.0275,  0.0510,  0.0588]],

         [[-0.0431, -0.0275, -0.0039,  ...,  0.1137,  0.0902,  0.0745],
          [-0.0353, -0.0196,  0.0039,  ...,  0.1059,  0.0902,  0.0824],
          [-0.0196, -0.0039,  0.0275,  ...,  0.0902,  0.0902,  0.0902],
          ...,
          [-0.1843, -0.1137,  0.0118,  ...,  0.0588,  0.0431,  0.0353],
          [-0.2078, -0.1765, -0.1137,  ...,  0.0588,  0.0667,  0.0667],
          [-0.2235, -0.2157, -0.1922,  ...,  0.0588,  0.0824,  0.0902]],

         [[-0.2235, -0.2078, -0.1843,  ..., -0.1294, -0.1529, -0.1686],
          ...
```

## Training and Evaluation

The data is processed and we are ready to start setting up the training pipeline. We will make use of 🤗's Trainer, but that'll require us to do a few things first:

- Define a collate function.

- Define an evaluation metric. During training, the model should be evaluated on its prediction accuracy, so we define a `compute_metrics` function accordingly.

- Load a pretrained checkpoint and configure it correctly for training.

- Define the training configuration.

After fine-tuning the model, we will evaluate it on the validation data to verify that it has indeed learned to classify our images correctly.

```python
from huggingface_hub import notebook_login

# Log in to Hugging Face so the model can be uploaded to your account
notebook_login()
```

### Define our data collator

Batches are coming in as lists of dicts, so we just unpack + stack those into batch tensors.

We return a batch `dict` from our `collate_fn` so we can simply `**unpack` the inputs to our model later. ✨

```python
import torch

def collate_fn(batch):
    # Stack per-example tensors into batch tensors; the model expects the key "labels"
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }
```
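A quick shape check with fake data shows what the `Trainer` will receive. The sketch below redefines `collate_fn` so it runs on its own. It feeds the function four dummy examples shaped like the items of `prepared_ds`: a `(3, 224, 224)` tensor plus an integer label. The zero tensors and alternating labels are placeholders. Note that the key changes from `label` to `labels`, which is the argument name `ViTForImageClassification` uses to compute the loss.

```python
import torch

def collate_fn(batch):
    # batch: a list of dicts, one per example
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),  # (B, 3, 224, 224)
        'labels': torch.tensor([x['label'] for x in batch]),              # (B,)
    }

# Four placeholder examples standing in for prepared_ds['train'][i]
fake_examples = [{'pixel_values': torch.zeros(3, 224, 224), 'label': i % 2} for i in range(4)]
batch = collate_fn(fake_examples)
print(batch['pixel_values'].shape)  # torch.Size([4, 3, 224, 224])
print(batch['labels'])              # tensor([0, 1, 0, 1])
```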

### Define an evaluation metric

Here, we load the [accuracy](https://huggingface.co/metrics/accuracy) metric from 🤗 `evaluate`, and then write a function that takes the model's predictions (logits), converts them to class indices with `argmax`, and computes the accuracy against the true labels.

```python
import numpy as np
import evaluate

metric = evaluate.load("accuracy")

def compute_metrics(p):
    # p.predictions are logits of shape (num_examples, num_labels)
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
```
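`compute_metrics` receives an `EvalPrediction`. Its `predictions` are the raw logits, with shape `(num_examples, num_labels)`, and its `label_ids` are the true labels. The sketch below does the same argmax-then-compare computation in plain NumPy without `evaluate`. It uses a hypothetical `namedtuple` stand-in for `EvalPrediction` and made-up logits.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for transformers' EvalPrediction (it exposes .predictions and .label_ids)
FakeEvalPrediction = namedtuple("FakeEvalPrediction", ["predictions", "label_ids"])

def accuracy_from_logits(p):
    preds = np.argmax(p.predictions, axis=1)          # logits -> predicted class index
    return {"accuracy": float((preds == p.label_ids).mean())}

p = FakeEvalPrediction(
    predictions=np.array([[2.0, -1.0], [0.1, 0.3], [1.5, 0.2], [-0.5, 0.5]]),  # made-up logits
    label_ids=np.array([0, 1, 1, 1]),
)
print(accuracy_from_logits(p))  # {'accuracy': 0.75}
```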

Now we can load our pretrained model. We'll pass `num_labels` on init to make sure the model creates a classification head with the right number of units. We'll also include the `id2label` and `label2id` mappings so we have human-readable labels in the 🤗 Hub widget if we choose to `push_to_hub`.

```python
from transformers import ViTForImageClassification

labels = ds['train'].features['label'].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)
```

```
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```


```python
inputs = feature_extractor(image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

# The classification head is freshly initialized (see the warning above),
# so before fine-tuning this prediction is essentially arbitrary
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[str(predicted_label)])
```

```
undamaged buildings
```
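`argmax` only tells us the winning class. If you also want a confidence score, a softmax over the logits turns them into probabilities. On the tensor itself, `torch.softmax(logits, dim=-1)` does this directly. Below is a minimal NumPy sketch of the same computation, using made-up logits for one image and two classes.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # subtracting the max avoids overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

toy_logits = np.array([[0.2, 1.1]])        # made-up logits: one image, two classes
probs = softmax(toy_logits)
predicted = int(probs.argmax(-1)[0])       # same index that argmax over the logits gives
print(predicted, probs.round(3))
```

Softmax is monotonic, so it never changes which class wins. It only rescales the scores so they sum to 1.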


We're almost ready to train! The last thing we'll do before that is set up the training configuration by defining [`TrainingArguments`](https://huggingface.co/docs/transformers/v4.16.2/en/main_classes/trainer#transformers.TrainingArguments).

Most of these are pretty self-explanatory, but one that is quite important here is `remove_unused_columns`. When it's `True` (the default), the `Trainer` drops any dataset columns that don't match an argument of the model's `forward` method. That's usually what you want, since it makes it easy to unpack inputs into the model. In our case, though, our transform needs a column the model never sees ('image' in particular) in order to create 'pixel_values', so we have to keep it by setting `remove_unused_columns=False`.

What I'm trying to say is that you'll have a bad time if you forget to set `remove_unused_columns=False`.
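Roughly speaking, the `Trainer` decides which columns to keep by inspecting the signature of the model's `forward` method, plus a few label column names. The sketch below mimics that check in plain Python. `drop_unused_columns` and `toy_forward` are hypothetical stand-ins, and the simplified signature is not the real `ViTForImageClassification.forward`. It shows why, with the default setting, the raw `image` column would be discarded before our transform ever sees it.

```python
import inspect

def drop_unused_columns(columns, forward_fn, label_columns=("label", "label_ids", "labels")):
    """Hypothetical mimic of the Trainer's column filtering: keep only columns that
    match a parameter of forward_fn or a known label column name."""
    accepted = set(inspect.signature(forward_fn).parameters) | set(label_columns)
    return [c for c in columns if c in accepted]

def toy_forward(pixel_values=None, labels=None, return_dict=None):
    """Simplified, hypothetical stand-in for an image classifier's forward signature."""
    return None

# The raw dataset columns are 'image' and 'label'; only 'label' survives the filter
print(drop_unused_columns(["image", "label"], toy_forward))  # ['label']
```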

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="hurricane_model",
    per_device_train_batch_size=16,
    eval_strategy="steps",         # named `evaluation_strategy` in older transformers versions
    num_train_epochs=4,
    fp16=True,                     # mixed-precision training; requires a CUDA GPU
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,   # keep the 'image' column for our transform
    push_to_hub=True,
    report_to='tensorboard',
    load_best_model_at_end=True,
)
```
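It helps to know roughly how many optimizer steps these settings produce, since `eval_steps`, `save_steps`, and `logging_steps` are all counted in steps. The size of the training split isn't printed anywhere in this notebook. The sketch below therefore estimates it from the training metrics reported after training (`train_samples_per_second` × `train_runtime` ÷ epochs), assuming a single device and no gradient accumulation.

```python
import math

# Figures copied from the "train metrics" printed after training
train_runtime_s = 8 * 60 + 53.40        # train_runtime = 0:08:53.40
train_samples_per_second = 37.495
num_train_epochs = 4
per_device_train_batch_size = 16

# Estimate of the training-split size (assumes one device, no gradient accumulation)
num_train_examples = round(train_samples_per_second * train_runtime_s / num_train_epochs)

# The last batch of each epoch can be partial, hence ceil
steps_per_epoch = math.ceil(num_train_examples / per_device_train_batch_size)
total_steps = steps_per_epoch * num_train_epochs
print(num_train_examples, steps_per_epoch, total_steps)  # 5000 313 1252
```

This total agrees with `train_steps_per_second` × `train_runtime` (2.347 × 533.4 ≈ 1,252). With `eval_steps=100`, the model is therefore evaluated and checkpointed roughly every third of an epoch.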



Now we can pass the model, training arguments, collator, metric function, datasets, and feature extractor to the `Trainer`, and we are ready to start training!



```python
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    processing_class=feature_extractor,  # called `tokenizer` in older transformers versions
)
```


```python
metrics = trainer.evaluate(prepared_ds['validation'])
metrics
```

```
{'eval_loss': 0.6793570518493652,
 'eval_model_preparation_time': 0.0051,
 'eval_accuracy': 0.5884,
 'eval_runtime': 13.2034,
 'eval_samples_per_second': 189.346,
 'eval_steps_per_second': 23.706}
```

```python
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()
```

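| Step | Training Loss | Validation Loss | Model Preparation Time | Accuracy |
|-----:|--------------:|----------------:|-----------------------:|---------:|
| 100  | 0.1118 | 0.148582 | 0.0051 | 0.9476 |
| 200  | 0.1112 | 0.070119 | 0.0051 | 0.9752 |
| 300  | 0.0694 | 0.060849 | 0.0051 | 0.9808 |
| 400  | 0.0048 | 0.091668 | 0.0051 | 0.9744 |
| 500  | 0.036  | 0.055198 | 0.0051 | 0.9836 |
| 600  | 0.0594 | 0.054691 | 0.0051 | 0.9808 |
| 700  | 0.0115 | 0.06273  | 0.0051 | 0.9844 |
| 800  | 0.0016 | 0.029573 | 0.0051 | 0.9936 |
| 900  | 0.004  | 0.032514 | 0.0051 | 0.9916 |
| 1000 | 0.0009 | 0.022371 | 0.0051 | 0.9948 |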


```
***** train metrics *****
  epoch                    =          4.0
  total_flos               = 1443400785GF
  train_loss               =       0.0559
  train_runtime            =   0:08:53.40
  train_samples_per_second =       37.495
  train_steps_per_second   =        2.347
```


```python
metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
```

```
***** eval metrics *****
  epoch                       =        4.0
  eval_accuracy               =     0.9948
  eval_loss                   =     0.0224
  eval_model_preparation_time =     0.0051
  eval_runtime                = 0:00:12.58
  eval_samples_per_second     =    198.586
  eval_steps_per_second       =     24.863
```


```python
kwargs = {
    "finetuned_from": model.config._name_or_path,
    "tasks": "image-classification",
    "dataset": dataset_name,
    "tags": ['image-classification'],
}

if training_args.push_to_hub:
    # The first argument of Trainer.push_to_hub is the commit message;
    # the remaining kwargs are used to build the model card
    trainer.push_to_hub(commit_message='🍻 cheers', **kwargs)
    print("pushed to hub")
else:
    trainer.create_model_card(**kwargs)
```

```
pushed to hub
```


# Done :)