# Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)

![VLM SFT training procedure](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/training_vlm_sft_training_procedure.png)

## Overview

This guide walks you through the process of fine-tuning a multimodal language model (e.g., **Gemma 3**) using **Supervised Fine-Tuning (SFT)**. We cover two cases:

- **Single Image + Text**
- **Multi-Image + Text**

This guide serves as a **detailed walkthrough** and complements the existing [VLM SFT script](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py). If you're already familiar with the concepts, you can use the script directly. We demonstrate the fine-tuning process using two datasets, but these principles extend to other **Vision-Language Models (VLMs)** and datasets.

## Understanding the Datasets

To address both **Single Image + Text** and **Multi-Image + Text** scenarios, we use two datasets that are well-suited for this task.

### HuggingFaceH4/llava-instruct-mix-vsft Dataset (Image + Text)

This dataset is a reformatted version of [LLaVA Instruct Mix](https://huggingface.co/datasets/theblackcat102/llava-instruct-mix). It consists of conversations where a user provides both **text** and a **single image** as input. The model (referred to as the **"assistant"**) responds based on both the **visual and textual information** shared by the user. This dataset is particularly useful for training multimodal models to **understand and generate responses based on images and text**.

### FanqingM/MMIU-Benchmark Dataset (Multi-Image + Text)

The **FanqingM/MMIU-Benchmark** dataset consists of:

- **Context:** Included in the system prompt.
- **Question:** Provided as part of the user's input.
- **Series of Images:** Multiple images related to the question.
- **Answer:** The model's expected response.

This dataset is designed for tasks where the model must reason over multiple images to generate an informed response based on both visual and textual inputs.

## Developing a Fine-Tuning Script for Multimodal SFT

In this section, we build the script needed to fine-tune a multimodal model for both **Single Image + Text** and **Multi-Image + Text** use cases.

### Setting Up the Environment

Before fine-tuning, we need to install the required dependencies. Let's start by setting up the environment:

```bash
# Install the required libraries. Further details: https://huggingface.co/docs/trl/installation
pip install -U -q trl bitsandbytes peft hf_xet tensorboard
```

Once all dependencies are installed, we need to log in to the **Hugging Face Hub**. Since **Gemma 3** is a gated model, access permissions are required. If you haven't requested access yet, visit the [Model Card](https://huggingface.co/google/gemma-3-4b-it) and request it.

To log in, you'll need to generate an [access token](https://huggingface.co/settings/tokens) from your Hugging Face account.

```bash
huggingface-cli login
```

### **Loading the Data**

As mentioned earlier, we cover two possible use cases. While the specific procedure varies with the dataset, the core principles remain consistent. This guide supports both, so follow the **Single Image + Text** or **Multi-Image + Text** section below depending on your scenario.
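Before diving into either case, it can help to peek at one sample and confirm the conversational format that SFT expects. The snippet below is a minimal sketch, not part of the original script; it assumes the single-image dataset exposes `messages` and `images` columns, which is also what the collator later in this guide relies on.

```python
from datasets import load_dataset

# Load a single example from the single-image dataset and inspect its structure
sample = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train[:1]")[0]

print(sample.keys())       # expected: 'messages', 'images' (column names assumed)
print(sample["messages"])  # a list of {"role": ..., "content": [...]} conversation turns
```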
#### **Single Image + Text**

![Single Image + Text](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/training_vlm_sft_training_procedure_single_image.png)

In this case, each sample in a batch consists of a **single image paired with text**. Since the dataset is already formatted for supervised fine-tuning (SFT), we can directly load it using `load_dataset`.

```python
from datasets import load_dataset

dataset_name = "HuggingFaceH4/llava-instruct-mix-vsft"

# Load Dataset
dataset = load_dataset(dataset_name)
```

#### **Multi-Image + Text (or Interleaving)**

![Multi-Image + Text](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/training_vlm_sft_training_procedure_multi_image.png)

Gemma 3 also supports **Multi-Image + Text** scenarios, where:

- The model receives a **list of images** alongside a user message.
- The model processes **interleaved images and text** within a conversation.

For this dataset, some preprocessing is required before training.

```python
from datasets import load_dataset

dataset_name = "FanqingM/MMIU-Benchmark"

# Load Dataset
dataset = load_dataset(dataset_name)
```

After loading the dataset, we need to preprocess and format it into a conversational structure. Here's an example of how the data might look:

```python
{"role": "system", "content": [{"type": "text", "text": "You are a judge in a photography competition, and now you are given the four images. Please examine the details and tell which one of them is most likely to be a real photograph.\nSelect from the following choices.\nA: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
{"role": "user", "content": images_list + [{"type": "text", "text": "Which image is most likely to be a real photograph?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "A: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
```

Here, `images_list` is a list of images:

```python
images_list = [
    {"type": "image", "image": <PIL.Image.Image object>},
    {"type": "image", "image": <PIL.Image.Image object>},
    {"type": "image", "image": <PIL.Image.Image object>},
    {"type": "image", "image": <PIL.Image.Image object>},
    {"type": "image", "image": <PIL.Image.Image object>},
]
```

This structure can be translated into code like this:

```python
import io
import os
import zipfile
from typing import Any

from datasets import DatasetDict
from huggingface_hub import hf_hub_download, list_repo_files
from PIL import Image

dataset_train_split = "test"


def format_data(samples: dict[str, Any]) -> dict[str, list]:
    formatted_samples = {"messages": []}
    for cont in range(len(samples["question"])):
        images = []
        for img_path in samples["input_image_path"][cont]:
            try:
                with open(img_path, "rb") as f:
                    img_bytes = f.read()
                image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                images.append({"type": "image", "image": image})
            except Exception as e:
                print(f"Error processing image {img_path}: {e}")
                continue

        formatted_samples["messages"].append(
            [
                {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
                {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
                {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
            ]
        )
    return formatted_samples


# For multi-image example
def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
    all_files = list_repo_files(dataset_name, repo_type="dataset")
    zip_files = [f for f in all_files if f.endswith(".zip")]

    for zip_filename in zip_files:
        zip_path = hf_hub_download(
            repo_id=dataset_name, filename=zip_filename,
repo_type="dataset") extract_folder = zip_filename.replace(".zip", "") os.makedirs(extract_folder, exist_ok=True) with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(extract_folder) dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16) return dataset dataset = prepare_dataset(dataset, dataset_name, dataset_train_split) ``` With this, your **Multi-Image + Text** dataset is now prepared for training. ### **Preparing for Training** We start by loading the model and processor. In this example, we use `google/gemma-3-4b-it`, but the same process applies to its other variants and similar models. To optimize memory usage, we configure `BitsAndBytes` to load the quantized version of the model. ```python import torch from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig model_id = "google/gemma-3-4b-it" # BitsAndBytesConfig int-4 config bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_storage=torch.bfloat16, ) # Load model and tokenizer model = AutoModelForImageTextToText.from_pretrained( model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager", # Important (Ref: https://github.com/huggingface/transformers/blob/c15a7adb283fa984a40558c7fe7bed30ae975cdd/src/transformers/models/gemma3/modeling_gemma3.py#L934) quantization_config=bnb_config ) processor = AutoProcessor.from_pretrained(model_id) processor.tokenizer.padding_side = "right" ``` Next, we set up [Quantized Low-Rank Adaptation (QLoRA)](https://huggingface.co/papers/2305.14314), an efficient fine-tuning technique for Large Language Models (LLMs) and Vision-Language Models (VLMs). ```python from peft import LoraConfig, get_peft_model # Configure QLoRA peft_config = LoraConfig( lora_alpha=16, lora_dropout=0.05, r=16, bias="none", target_modules="all-linear", task_type="CAUSAL_LM", modules_to_save=[ "lm_head", "embed_tokens", ], ) ``` With QLoRA now set up, we need to define the training arguments for SFT. The [`SFTConfig`] class simplifies this process, providing an easy way to adjust parameters based on our specific needs. ```python from trl import SFTConfig training_args = SFTConfig( output_dir="gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft", # Directory to save the model and push to the Hub. Use a specific repository id (e.g., gemma-3-4b-it-trl-sft-MMIU-Benchmark for multi-image datasets). num_train_epochs=1, # Set the number of epochs to train the model. per_device_train_batch_size=8, # Batch size for each device (e.g., GPU) during training. multi-image -> per_device_train_batch_size=1 gradient_accumulation_steps=4, # Number of steps before performing a backward/update pass to accumulate gradients. multi-image -> gradient_accumulation_steps=1 gradient_checkpointing=True, # Enable gradient checkpointing to reduce memory usage during training. optim="adamw_torch_fused", # Use the fused AdamW optimizer for better performance. logging_steps=10, # Frequency of logging training progress (log every 10 steps). save_strategy="epoch", # Save checkpoints at the end of each epoch. learning_rate=2e-05, # Learning rate for training. bf16=True, # Enable bfloat16 precision for training to save memory and speed up computations. push_to_hub=True, # Automatically push the fine-tuned model to Hugging Face Hub after training. report_to="tensorboard", # Automatically report metrics to tensorboard. 
    gradient_checkpointing_kwargs={"use_reentrant": False},  # Set gradient checkpointing to non-reentrant to avoid issues.
    dataset_kwargs={"skip_prepare_dataset": True},  # Skip dataset preparation to handle preprocessing manually.
    remove_unused_columns=False,  # Ensure unused columns are not removed in the collator (important for batch processing).
)
```

The `collate_fn` is responsible for processing and preparing individual examples to form a batch. Each example in the batch undergoes the following steps:

1. The **chat template** is applied to the text.
2. The **processor tokenizes** both `texts` and `images`, encoding them into tensors.
3. The **labels** for training are set as the `input_ids` of the example.
4. Certain **special tokens** are **masked (ignored)** during loss computation:
   - the `pad_token_id`
   - the begin-of-image token (`boi_token`)
   - the image soft token (token ID `262144`)

This process is similar across different dataset types, with a minor variation in how images are handled:

- **Single Image + Text** → A **list of images** is directly processed.
- **Multi-Image + Text** → A **list of lists of images** is used, where each batch element contains multiple images.

```python
import io

from PIL import Image


# For multi-image cases
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]
        for element in content:
            if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                if image is not None:
                    # After dataset.map, each image is stored as a dict holding its raw bytes
                    image = Image.open(io.BytesIO(image["bytes"]))
                    image_inputs.append(image.convert("RGB"))
    return image_inputs


def collate_fn(examples):
    texts = [
        processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip()
        for example in examples
    ]
    if "images" in examples[0]:  # single-image
        images = [[img.convert("RGB") for img in example["images"]] for example in examples]
    else:  # multi-image
        images = [process_vision_info(example["messages"]) for example in examples]

    # Tokenize the texts and process the images, encoding them into tensors
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids; padding and image tokens are masked in the loss computation
    labels = batch["input_ids"].clone()  # Clone input IDs for labels

    # Look up the begin-of-image token ID
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
    ]
    # Mask tokens so they are not used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100

    batch["labels"] = labels
    return batch  # Return the prepared batch
```

### **Training the Model**

With all the components set up, we now configure the `SFTTrainer` using the previously defined settings and start the training process.

```python
# Training
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=dataset["train"],  # multi-image -> train_dataset=dataset["test"]
    processing_class=processor,
    peft_config=peft_config,
)

trainer.train()

# Save the final model
trainer.save_model()
```

We save the fine-tuned model to the Hub, making it easily accessible for future use. Additionally, TRL automatically logs the training results to **Weights & Biases (Wandb)** or **TensorBoard**, depending on the chosen configuration.
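Once the adapter has been pushed, you can sanity-check the result with a quick generation. The following is a minimal sketch rather than part of the training script: the adapter repository id and example image URL are illustrative placeholders, and it assumes a recent `transformers` release whose chat template can load images referenced by URL.

```python
import torch
from peft import PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor

base_model_id = "google/gemma-3-4b-it"
adapter_id = "your-username/gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft"  # illustrative: your pushed repo id

# Reload the base model and attach the trained LoRA adapter
model = AutoModelForImageTextToText.from_pretrained(
    base_model_id, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager"
)
model = PeftModel.from_pretrained(model, adapter_id)
processor = AutoProcessor.from_pretrained(base_model_id)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},  # illustrative image
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)

with torch.inference_mode():
    generated = model.generate(**inputs, max_new_tokens=64)

# Decode only the newly generated tokens
print(processor.decode(generated[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))
```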
### Results

During and after training, we can inspect the results using **Weights & Biases (Wandb)** or **TensorBoard**. For example:

- [**gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft (Single Image + Text)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft)
- [**gemma-3-4b-it-trl-sft-MMIU-Benchmark (Multi-Image + Text or Interleaving)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-MMIU-Benchmark)

## Limitations

Currently, fine-tuning Gemma has some [known limitations](https://github.com/huggingface/trl/issues/3121). We recommend following the procedure outlined in this guide to ensure the best results.

## References

For further reading and complementary resources, check out the following:

- [Fine-Tuning Vision-Language Models with QLoRA](https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora)
- [Fine-Tuning a Vision Language Model (Qwen2-VL-7B) with the Hugging Face Ecosystem (TRL)](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl)