import os

import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset
from transformers import (
    Trainer,
    TrainingArguments,
    CLIPVisionModel,
    CLIPImageProcessor,
)

# Disable Weights & Biases logging for this run.
os.environ["WANDB_DISABLED"] = "true"

BASE_PATH = './'
TRAIN_CSV = os.path.join(BASE_PATH, 'fairface_label_train.csv')
VAL_CSV = os.path.join(BASE_PATH, 'fairface_label_val.csv')
MODEL_NAME = "openai/clip-vit-large-patch14"
OUTPUT_DIR = "./clip-fairface-finetuned"

# Build label vocabularies from the training CSV: each attribute (age bucket,
# gender, race) gets a stable string-to-index mapping.
train_df = pd.read_csv(TRAIN_CSV)

age_labels = sorted(train_df['age'].unique())
gender_labels = sorted(train_df['gender'].unique())
race_labels = sorted(train_df['race'].unique())

label_mappings = {
    'age': {label: i for i, label in enumerate(age_labels)},
    'gender': {label: i for i, label in enumerate(gender_labels)},
    'race': {label: i for i, label in enumerate(race_labels)},
}

NUM_LABELS = {
    'age': len(age_labels),
    'gender': len(gender_labels),
    'race': len(race_labels),
}

print(f"Number of labels: Age={NUM_LABELS['age']}, Gender={NUM_LABELS['gender']}, Race={NUM_LABELS['race']}")


class FairFaceDataset(Dataset):
    """Wraps a FairFace label CSV and returns processed pixel values plus label indices."""

    def __init__(self, csv_file, image_processor, label_maps, base_path):
        self.df = pd.read_csv(csv_file)
        self.image_processor = image_processor
        self.label_maps = label_maps
        self.base_path = base_path

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # The 'file' column holds an image path relative to the dataset root.
        image_path = os.path.join(self.base_path, row['file'])
        image = Image.open(image_path).convert("RGB")

        inputs = {}
        # The processor returns a [1, 3, H, W] tensor; drop the batch dimension.
        inputs['pixel_values'] = self.image_processor(images=image, return_tensors="pt").pixel_values.squeeze(0)

        # Map each attribute's string label to its integer index.
        inputs['labels'] = {
            'age': torch.tensor(self.label_maps['age'][row['age']], dtype=torch.long),
            'gender': torch.tensor(self.label_maps['gender'][row['gender']], dtype=torch.long),
            'race': torch.tensor(self.label_maps['race'][row['race']], dtype=torch.long),
        }
        return inputs
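
# Quick sanity check (optional, illustrative only):
#   processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
#   sample = FairFaceDataset(TRAIN_CSV, processor, label_mappings, BASE_PATH)[0]
#   sample['pixel_values'].shape  # torch.Size([3, 224, 224]) for this ViT-L/14 checkpoint
#   sample['labels']              # {'age': tensor(..), 'gender': tensor(..), 'race': tensor(..)}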


class MultiTaskClipVisionModel(nn.Module):
    """CLIP vision encoder with three classification heads (age, gender, race)."""

    # Advertises gradient-checkpointing support to HF utilities that check this flag.
    supports_gradient_checkpointing = True

    def __init__(self, num_labels):
        super().__init__()
        self.vision_model = CLIPVisionModel.from_pretrained(MODEL_NAME)

        # Freeze the entire vision backbone...
        for param in self.vision_model.parameters():
            param.requires_grad = False

        # ...then unfreeze only the last three transformer encoder layers.
        for layer in self.vision_model.vision_model.encoder.layers[-3:]:
            for param in layer.parameters():
                param.requires_grad = True

        # One linear classification head per attribute on top of the pooled output.
        hidden_size = self.vision_model.config.hidden_size
        self.age_head = nn.Linear(hidden_size, num_labels['age'])
        self.gender_head = nn.Linear(hidden_size, num_labels['gender'])
        self.race_head = nn.Linear(hidden_size, num_labels['race'])

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Activates gradient checkpointing for the underlying vision model."""
        self.vision_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    def forward(self, pixel_values, labels=None):
        outputs = self.vision_model(pixel_values=pixel_values)
        pooled_output = outputs.pooler_output

        age_logits = self.age_head(pooled_output)
        gender_logits = self.gender_head(pooled_output)
        race_logits = self.race_head(pooled_output)

        loss = None
        if labels is not None:
            # Sum the three cross-entropy losses with equal weight.
            loss_fct = nn.CrossEntropyLoss()
            age_loss = loss_fct(age_logits, labels['age'])
            gender_loss = loss_fct(gender_logits, labels['gender'])
            race_loss = loss_fct(race_logits, labels['race'])
            loss = age_loss + gender_loss + race_loss

        return {
            'loss': loss,
            'logits': {
                'age': age_logits,
                'gender': gender_logits,
                'race': race_logits,
            },
        }
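
# Optional sanity check (illustrative only): verify that just the last three
# encoder layers and the three heads are trainable.
#   m = MultiTaskClipVisionModel(num_labels=NUM_LABELS)
#   trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
#   total = sum(p.numel() for p in m.parameters())
#   print(f"trainable parameters: {trainable:,} / {total:,}")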


def collate_fn(batch):
    # Stack image tensors into [N, 3, H, W] and group the per-task label tensors.
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    labels = {
        'age': torch.stack([item['labels']['age'] for item in batch]),
        'gender': torch.stack([item['labels']['gender'] for item in batch]),
        'race': torch.stack([item['labels']['race'] for item in batch]),
    }
    return {'pixel_values': pixel_values, 'labels': labels}
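
# For a batch of N samples the collator returns:
#   {'pixel_values': FloatTensor [N, 3, H, W],
#    'labels': {'age': LongTensor [N], 'gender': LongTensor [N], 'race': LongTensor [N]}}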


def compute_metrics(p):
    # The Trainer passes through the model's nested logits dict and the labels
    # dict from collate_fn (converted to numpy arrays by the evaluation loop).
    logits = p.predictions
    labels = p.label_ids

    age_preds = logits['age'].argmax(-1)
    gender_preds = logits['gender'].argmax(-1)
    race_preds = logits['race'].argmax(-1)

    age_labels = labels['age']
    gender_labels = labels['gender']
    race_labels = labels['race']

    return {
        'age_accuracy': accuracy_score(age_labels, age_preds),
        'gender_accuracy': accuracy_score(gender_labels, gender_preds),
        'race_accuracy': accuracy_score(race_labels, race_preds),
    }


def main():
    # Image processor (resize/normalize) and the multi-task model.
    image_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
    model = MultiTaskClipVisionModel(num_labels=NUM_LABELS)

    train_dataset = FairFaceDataset(
        csv_file=TRAIN_CSV, image_processor=image_processor, label_maps=label_mappings, base_path=BASE_PATH
    )
    val_dataset = FairFaceDataset(
        csv_file=VAL_CSV, image_processor=image_processor, label_maps=label_mappings, base_path=BASE_PATH
    )

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=5,
        per_device_train_batch_size=24,
        per_device_eval_batch_size=32,
        # Effective batch per device: 24 * 22 = 528 examples per optimizer step.
        gradient_accumulation_steps=22,
        # Trade extra compute for lower activation memory in the vision backbone.
        gradient_checkpointing=True,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        evaluation_strategy="steps",
        eval_steps=250,
        save_strategy="steps",
        save_steps=250,
        load_best_model_at_end=True,
        # Checkpoint selection tracks validation gender accuracy.
        metric_for_best_model='gender_accuracy',
        save_total_limit=3,
        # Mixed-precision training (requires a CUDA GPU).
        fp16=True,
        # Don't let the Trainer prune the custom input keys our dataset produces.
        remove_unused_columns=False,
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
    )

    print("Starting model training...")
    trainer.train()

    # Persist the best checkpoint (by gender accuracy) and the processor config.
    print("Saving the best model...")
    trainer.save_model(os.path.join(OUTPUT_DIR, "best_model"))
    image_processor.save_pretrained(os.path.join(OUTPUT_DIR, "best_model"))

    print("Training complete!")


if __name__ == "__main__":
    main()
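
# --- Optional: minimal single-image inference sketch (illustrative, not executed) ---
# Assumes the run above has finished. trainer.save_model() stores a plain state
# dict for this custom nn.Module; depending on the transformers version the file
# is 'model.safetensors' (the default) or 'pytorch_model.bin'. 'some_face.jpg'
# below is a placeholder path.
#
#   from safetensors.torch import load_file
#
#   ckpt = os.path.join(OUTPUT_DIR, "best_model")
#   processor = CLIPImageProcessor.from_pretrained(ckpt)
#   model = MultiTaskClipVisionModel(num_labels=NUM_LABELS)
#   model.load_state_dict(load_file(os.path.join(ckpt, "model.safetensors")))
#   model.eval()
#
#   pixels = processor(images=Image.open("some_face.jpg").convert("RGB"), return_tensors="pt").pixel_values
#   with torch.no_grad():
#       logits = model(pixel_values=pixels)['logits']
#   inverse = {task: {i: name for name, i in mapping.items()} for task, mapping in label_mappings.items()}
#   print({task: inverse[task][int(logits[task].argmax(-1))] for task in ('age', 'gender', 'race')})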