import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from sklearn.metrics import accuracy_score
from transformers import (
    Trainer,
    TrainingArguments,
    CLIPVisionModel,
    CLIPImageProcessor,
)
from torch.utils.data import Dataset
import os
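# Disable Weights & Biases logging for this run (report_to="none" below does the same).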
os.environ["WANDB_DISABLED"] = "true"
# --- 1. Configuration ---
# Define paths and model name
BASE_PATH = './' # Assumes the script is run from the 'fairface' directory
TRAIN_CSV = os.path.join(BASE_PATH, 'fairface_label_train.csv')
VAL_CSV = os.path.join(BASE_PATH, 'fairface_label_val.csv')
MODEL_NAME = "openai/clip-vit-large-patch14"
OUTPUT_DIR = "./clip-fairface-finetuned"
# --- 2. Load and Prepare Label Mappings ---
# Load training data to create consistent label-to-ID mappings
train_df = pd.read_csv(TRAIN_CSV)
# Create sorted unique label lists to ensure consistent mapping
age_labels = sorted(train_df['age'].unique())
gender_labels = sorted(train_df['gender'].unique())
race_labels = sorted(train_df['race'].unique())
# Create label-to-ID mappings for each task
label_mappings = {
    'age': {label: i for i, label in enumerate(age_labels)},
    'gender': {label: i for i, label in enumerate(gender_labels)},
    'race': {label: i for i, label in enumerate(race_labels)},
}
NUM_LABELS = {
    'age': len(age_labels),
    'gender': len(gender_labels),
    'race': len(race_labels),
}
print(f"Number of labels: Age={NUM_LABELS['age']}, Gender={NUM_LABELS['gender']}, Race={NUM_LABELS['race']}")
# --- 3. Custom Dataset ---
class FairFaceDataset(Dataset):
    def __init__(self, csv_file, image_processor, label_maps, base_path):
        self.df = pd.read_csv(csv_file)
        self.image_processor = image_processor
        self.label_maps = label_maps
        self.base_path = base_path

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Construct the full path to the image.
        image_path = os.path.join(self.base_path, row['file'])
        image = Image.open(image_path).convert("RGB")

        # Process the image into pixel values.
        inputs = {}
        inputs['pixel_values'] = self.image_processor(images=image, return_tensors="pt").pixel_values.squeeze(0)

        # Process labels into a dictionary of tensors.
        inputs['labels'] = {
            'age': torch.tensor(self.label_maps['age'][row['age']], dtype=torch.long),
            'gender': torch.tensor(self.label_maps['gender'][row['gender']], dtype=torch.long),
            'race': torch.tensor(self.label_maps['race'][row['race']], dtype=torch.long),
        }
        return inputs
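
# Quick sanity check (a sketch; uncomment to inspect a single example):
# _processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
# _sample = FairFaceDataset(TRAIN_CSV, _processor, label_mappings, BASE_PATH)[0]
# print(_sample['pixel_values'].shape)  # expected: torch.Size([3, 224, 224]) for this checkpoint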
# --- 4. Custom Model Definition ---
class MultiTaskClipVisionModel(nn.Module):
    # Class attribute that signals to the Trainer that this model supports
    # gradient checkpointing.
    supports_gradient_checkpointing = True
    def __init__(self, num_labels):
        super().__init__()
        self.vision_model = CLIPVisionModel.from_pretrained(MODEL_NAME)

        # Freeze all parameters of the vision model first.
        for param in self.vision_model.parameters():
            param.requires_grad = False

        # Unfreeze the last three transformer layers for fine-tuning.
        for layer in self.vision_model.vision_model.encoder.layers[-3:]:
            for param in layer.parameters():
                param.requires_grad = True

        # Define one classification head per task.
        hidden_size = self.vision_model.config.hidden_size
        self.age_head = nn.Linear(hidden_size, num_labels['age'])
        self.gender_head = nn.Linear(hidden_size, num_labels['gender'])
        self.race_head = nn.Linear(hidden_size, num_labels['race'])

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Called by the Trainer when gradient_checkpointing=True; delegates to the vision model."""
        self.vision_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
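
    # Note (version-dependent assumption): if checkpointing fails with
    # "element 0 of tensors does not require grad" because the early layers are
    # frozen, passing gradient_checkpointing_kwargs={"use_reentrant": False}
    # via TrainingArguments has been a common fix in recent transformers releases.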
    def forward(self, pixel_values, labels=None):
        # Works with gradient checkpointing enabled on the vision model.
        outputs = self.vision_model(pixel_values=pixel_values)
        pooled_output = outputs.pooler_output

        age_logits = self.age_head(pooled_output)
        gender_logits = self.gender_head(pooled_output)
        race_logits = self.race_head(pooled_output)

        loss = None
        # If labels are provided, calculate the combined loss.
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            age_loss = loss_fct(age_logits, labels['age'])
            gender_loss = loss_fct(gender_logits, labels['gender'])
            race_loss = loss_fct(race_logits, labels['race'])
            # Total loss is the unweighted sum of the per-task losses.
            loss = age_loss + gender_loss + race_loss

        return {
            'loss': loss,
            'logits': {
                'age': age_logits,
                'gender': gender_logits,
                'race': race_logits,
            },
        }
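
# Optional sanity check (a sketch): count how many parameters will actually train.
# _m = MultiTaskClipVisionModel(num_labels=NUM_LABELS)
# print(sum(p.numel() for p in _m.parameters() if p.requires_grad))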
# --- 5. Data Collator and Metrics ---
def collate_fn(batch):
    # Stack pixel values, and stack the scalar label tensors per task.
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    labels = {
        'age': torch.stack([item['labels']['age'] for item in batch]),
        'gender': torch.stack([item['labels']['gender'] for item in batch]),
        'race': torch.stack([item['labels']['race'] for item in batch]),
    }
    return {'pixel_values': pixel_values, 'labels': labels}
def compute_metrics(p):
    # p is an EvalPrediction object containing predictions and label_ids.
    logits = p.predictions
    labels = p.label_ids

    # Extract per-task predictions and labels.
    age_preds = logits['age'].argmax(-1)
    gender_preds = logits['gender'].argmax(-1)
    race_preds = logits['race'].argmax(-1)
    age_labels = labels['age']
    gender_labels = labels['gender']
    race_labels = labels['race']

    # Calculate accuracy for each task.
    return {
        'age_accuracy': accuracy_score(age_labels, age_preds),
        'gender_accuracy': accuracy_score(gender_labels, gender_preds),
        'race_accuracy': accuracy_score(race_labels, race_preds),
    }
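# Note (assumption): this relies on Trainer carrying the nested dict structure
# of logits and labels through its evaluation loop, which recent transformers
# versions do; worth re-verifying after a transformers upgrade.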
# --- 6. Trainer Setup and Execution ---
def main():
    # Initialize the image processor and the custom multi-task model.
    image_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
    model = MultiTaskClipVisionModel(num_labels=NUM_LABELS)

    # Initialize the training and validation datasets.
    train_dataset = FairFaceDataset(
        csv_file=TRAIN_CSV, image_processor=image_processor, label_maps=label_mappings, base_path=BASE_PATH
    )
    val_dataset = FairFaceDataset(
        csv_file=VAL_CSV, image_processor=image_processor, label_maps=label_mappings, base_path=BASE_PATH
    )

    # Define the training arguments.
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=5,
        # Per-device batch size that fits in GPU memory.
        per_device_train_batch_size=24,
        per_device_eval_batch_size=32,  # Evaluation needs no accumulation and can use a larger batch.
        # Accumulate gradients to reach the desired effective batch size (24 * 22 = 528).
        gradient_accumulation_steps=22,
        # Enable gradient checkpointing to save additional memory.
        gradient_checkpointing=True,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,  # Log frequently to see progress within a large effective batch.
        evaluation_strategy="steps",
        eval_steps=250,  # With large effective batches, evaluating less often is also reasonable.
        save_strategy="steps",
        save_steps=250,
        load_best_model_at_end=True,
        metric_for_best_model='gender_accuracy',
        save_total_limit=3,
        fp16=True,  # Mixed-precision training is essential for a model this large.
        remove_unused_columns=False,  # Keep the custom 'labels' dict in the batches.
        report_to="none",  # Disables wandb logging.
    )
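
    # Note (version-dependent assumption): newer transformers releases rename
    # `evaluation_strategy` to `eval_strategy`; adjust the argument above if
    # your installed version warns about it.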
    # Initialize the Trainer.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
    )

    # Start training.
    print("Starting model training...")
    trainer.train()

    # Save the final model and processor.
    print("Saving the best model...")
    trainer.save_model(os.path.join(OUTPUT_DIR, "best_model"))
    image_processor.save_pretrained(os.path.join(OUTPUT_DIR, "best_model"))
    print("Training complete!")


if __name__ == "__main__":
    main()
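
# Inference sketch (assumption: because the model is a plain nn.Module rather
# than a PreTrainedModel, Trainer.save_model writes a raw state_dict; the file
# name below may differ, e.g. model.safetensors on newer versions):
#
# model = MultiTaskClipVisionModel(num_labels=NUM_LABELS)
# state = torch.load(os.path.join(OUTPUT_DIR, "best_model", "pytorch_model.bin"))
# model.load_state_dict(state)
# model.eval()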