import os

import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset
from transformers import (
    Trainer,
    TrainingArguments,
    CLIPVisionModel,
    CLIPImageProcessor,
)

os.environ["WANDB_DISABLED"] = "true"

# --- 1. Configuration ---
# Define paths and the model name.
BASE_PATH = './'  # Assumes the script is run from the 'fairface' directory.
TRAIN_CSV = os.path.join(BASE_PATH, 'fairface_label_train.csv')
VAL_CSV = os.path.join(BASE_PATH, 'fairface_label_val.csv')
MODEL_NAME = "openai/clip-vit-large-patch14"
OUTPUT_DIR = "./clip-fairface-finetuned"

# --- 2. Load and Prepare Label Mappings ---
# Load the training data once to build consistent label-to-ID mappings.
train_df = pd.read_csv(TRAIN_CSV)

# Sort the unique label values so the mappings are deterministic across runs.
age_labels = sorted(train_df['age'].unique())
gender_labels = sorted(train_df['gender'].unique())
race_labels = sorted(train_df['race'].unique())

# Label-to-ID mappings for each task.
label_mappings = {
    'age': {label: i for i, label in enumerate(age_labels)},
    'gender': {label: i for i, label in enumerate(gender_labels)},
    'race': {label: i for i, label in enumerate(race_labels)},
}

NUM_LABELS = {
    'age': len(age_labels),
    'gender': len(gender_labels),
    'race': len(race_labels),
}

print(f"Number of labels: Age={NUM_LABELS['age']}, "
      f"Gender={NUM_LABELS['gender']}, Race={NUM_LABELS['race']}")

# --- 3. Custom Dataset ---
class FairFaceDataset(Dataset):
    def __init__(self, csv_file, image_processor, label_maps, base_path):
        self.df = pd.read_csv(csv_file)
        self.image_processor = image_processor
        self.label_maps = label_maps
        self.base_path = base_path

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Construct the full path to the image.
        image_path = os.path.join(self.base_path, row['file'])
        image = Image.open(image_path).convert("RGB")

        # Process the image into a pixel tensor, dropping the batch dimension.
        inputs = {}
        inputs['pixel_values'] = self.image_processor(
            images=image, return_tensors="pt"
        ).pixel_values.squeeze(0)

        # Encode the labels as a dictionary of tensors, one entry per task.
        inputs['labels'] = {
            'age': torch.tensor(self.label_maps['age'][row['age']], dtype=torch.long),
            'gender': torch.tensor(self.label_maps['gender'][row['gender']], dtype=torch.long),
            'race': torch.tensor(self.label_maps['race'][row['race']], dtype=torch.long),
        }
        return inputs

# --- 4. Custom Model Definition ---
class MultiTaskClipVisionModel(nn.Module):
    # Signals to the Trainer that this model supports gradient checkpointing.
    supports_gradient_checkpointing = True

    def __init__(self, num_labels):
        super().__init__()
        self.vision_model = CLIPVisionModel.from_pretrained(MODEL_NAME)

        # Freeze all parameters of the vision backbone first.
        for param in self.vision_model.parameters():
            param.requires_grad = False

        # Unfreeze the last few layers for fine-tuning.
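        # Keeping most of the backbone frozen shrinks the optimizer state and
        # limits how far fine-tuning can drift from the pretrained features;
        # only the last three encoder blocks (and the task heads below) train.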
        for layer in self.vision_model.vision_model.encoder.layers[-3:]:  # Last 3 transformer layers.
            for param in layer.parameters():
                param.requires_grad = True

        # One linear classification head per task, all reading the pooled embedding.
        hidden_size = self.vision_model.config.hidden_size
        self.age_head = nn.Linear(hidden_size, num_labels['age'])
        self.gender_head = nn.Linear(hidden_size, num_labels['gender'])
        self.race_head = nn.Linear(hidden_size, num_labels['race'])

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Called by the Trainer; activates checkpointing on the underlying vision model."""
        self.vision_model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
        )

    def forward(self, pixel_values, labels=None):
        # The forward pass works unchanged with gradient checkpointing enabled.
        outputs = self.vision_model(pixel_values=pixel_values)
        pooled_output = outputs.pooler_output

        age_logits = self.age_head(pooled_output)
        gender_logits = self.gender_head(pooled_output)
        race_logits = self.race_head(pooled_output)

        loss = None
        # If labels are provided, the loss is the sum of the per-task
        # cross-entropy losses, weighting the three tasks equally.
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            age_loss = loss_fct(age_logits, labels['age'])
            gender_loss = loss_fct(gender_logits, labels['gender'])
            race_loss = loss_fct(race_logits, labels['race'])
            loss = age_loss + gender_loss + race_loss

        return {
            'loss': loss,
            'logits': {
                'age': age_logits,
                'gender': gender_logits,
                'race': race_logits,
            },
        }

# --- 5. Data Collator and Metrics ---
def collate_fn(batch):
    # Stack pixel tensors and gather the per-task labels into flat tensors.
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    labels = {
        'age': torch.tensor([item['labels']['age'] for item in batch], dtype=torch.long),
        'gender': torch.tensor([item['labels']['gender'] for item in batch], dtype=torch.long),
        'race': torch.tensor([item['labels']['race'] for item in batch], dtype=torch.long),
    }
    return {'pixel_values': pixel_values, 'labels': labels}

def compute_metrics(p):
    # p is an EvalPrediction; predictions and label_ids mirror the dict
    # structures returned by the model's forward pass and the collator.
    logits = p.predictions
    labels = p.label_ids

    # Extract predictions and labels for each task.
    age_preds = logits['age'].argmax(-1)
    gender_preds = logits['gender'].argmax(-1)
    race_preds = logits['race'].argmax(-1)

    age_labels = labels['age']
    gender_labels = labels['gender']
    race_labels = labels['race']

    # Report accuracy separately for each task.
    return {
        'age_accuracy': accuracy_score(age_labels, age_preds),
        'gender_accuracy': accuracy_score(gender_labels, gender_preds),
        'race_accuracy': accuracy_score(race_labels, race_preds),
    }
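# A minimal, illustrative sanity check for the collator. The 224x224 shape is
# an assumption (CLIP ViT-L/14's default input size); the function is never
# called during training and exists only as a quick manual smoke test.
def _collate_smoke_test():
    dummy = [{
        'pixel_values': torch.zeros(3, 224, 224),
        'labels': {
            'age': torch.tensor(0, dtype=torch.long),
            'gender': torch.tensor(1, dtype=torch.long),
            'race': torch.tensor(2, dtype=torch.long),
        },
    } for _ in range(2)]
    batch = collate_fn(dummy)
    assert batch['pixel_values'].shape == (2, 3, 224, 224)
    assert batch['labels']['race'].tolist() == [2, 2]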
# --- 6. Trainer Setup and Execution ---
def main():
    # Initialize the image processor and the custom multi-task model.
    image_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
    model = MultiTaskClipVisionModel(num_labels=NUM_LABELS)

    # Initialize the training and validation datasets.
    train_dataset = FairFaceDataset(
        csv_file=TRAIN_CSV,
        image_processor=image_processor,
        label_maps=label_mappings,
        base_path=BASE_PATH,
    )
    val_dataset = FairFaceDataset(
        csv_file=VAL_CSV,
        image_processor=image_processor,
        label_maps=label_mappings,
        base_path=BASE_PATH,
    )

    # Define the training arguments.
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=5,
        per_device_train_batch_size=24,   # Sized to fit in GPU memory.
        per_device_eval_batch_size=32,    # Eval needs no accumulation, so it can be larger.
        gradient_accumulation_steps=22,   # Effective train batch: 24 * 22 = 528.
        gradient_checkpointing=True,      # Trade recomputation for memory.
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,                 # Log often to see progress within large effective batches.
        evaluation_strategy="steps",
        eval_steps=250,                   # Consider evaluating less often with larger batches.
        save_strategy="steps",
        save_steps=250,
        load_best_model_at_end=True,
        metric_for_best_model='gender_accuracy',
        save_total_limit=3,
        fp16=True,                        # Mixed precision is essential at this model size.
        remove_unused_columns=False,      # Keep the custom 'labels' dict intact.
        report_to="none",                 # Disables wandb logging.
    )

    # Initialize the Trainer.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
    )

    # Start training.
    print("Starting model training...")
    trainer.train()

    # Save the final model and processor.
    print("Saving the best model...")
    trainer.save_model(os.path.join(OUTPUT_DIR, "best_model"))
    image_processor.save_pretrained(os.path.join(OUTPUT_DIR, "best_model"))
    print("Training complete!")

if __name__ == "__main__":
    main()
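# Illustrative only: one way to reload the fine-tuned weights for inference.
# Trainer.save_model on a plain nn.Module writes a state dict; the file name
# assumed here ("pytorch_model.bin") depends on the transformers version, so
# adjust the path if your checkpoint was saved as safetensors instead.
def load_finetuned_model(checkpoint_dir):
    model = MultiTaskClipVisionModel(num_labels=NUM_LABELS)
    state_dict = torch.load(
        os.path.join(checkpoint_dir, "pytorch_model.bin"), map_location="cpu"
    )
    model.load_state_dict(state_dict)
    model.eval()  # Disable dropout etc. for inference.
    return model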