import random
from typing import Literal, List
from collections import Counter
from dataclasses import dataclass, field
from functools import partial
import torch
import evaluate
import numpy as np
from torch import optim
from datasets import load_dataset
import transformers
from transformers import (
Trainer,
TrainingArguments,
HfArgumentParser,
AutoImageProcessor,
AutoModelForImageClassification,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from torchvision.transforms import (
Compose,
Normalize,
Resize,
ToTensor,
)
from peft import get_peft_model, VeraConfig, BOFTConfig, LoraConfig
import sys
sys.path.append("../")
from svft.svft_layers import *
##########################
# Metrics
##########################
metric = evaluate.load("accuracy")
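# Top-1 accuracy computed from the argmax of the logits.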
def compute_metrics(eval_pred):
predictions = np.argmax(eval_pred.predictions, axis=1)
return metric.compute(predictions=predictions, references=eval_pred.label_ids)
##########################
# Utils
##########################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
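# Seed Python, NumPy, PyTorch (CPU/CUDA), and transformers RNGs for reproducible runs.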
def reset_seed(SEED=0):
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
transformers.set_seed(SEED)
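# Count total vs. trainable parameters, split into classifier-head and other trainable params.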
def get_trainable_params_dict(model):
total_p = sum(p.numel() for p in model.parameters())
trainable_p = sum(p.numel() for p in model.parameters() if p.requires_grad)
clf_trainable_p = sum(
p.numel()
for n, p in model.named_parameters()
if p.requires_grad and "classifier" in n
)
other_p = trainable_p - clf_trainable_p
return {
"total_p": total_p,
"trainable_p": trainable_p,
"clf_trainable_p": clf_trainable_p,
"other_p": other_p,
}
def print_trainable_parameters(model):
params_dict = get_trainable_params_dict(model)
total_p = params_dict["total_p"]
trainable_p = params_dict["trainable_p"]
clf_trainable_p = params_dict["clf_trainable_p"]
other_p = params_dict["other_p"]
print(
f"Total params: {total_p} | Trainable params: {trainable_p} | Trainable%: {trainable_p/total_p*100:.2f}%"
)
print(
f"Clf Trainable params: {clf_trainable_p} | Clf Trainable%: {clf_trainable_p/total_p*100:.2f}%"
)
print(
f"FT Trainable params: {other_p} | FT Trainable%: {other_p/total_p*100:.2f}%"
)
print()
##########################
# Dataset Utilities
##########################
label_key = "label"
image_path_key = "image"
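# Stack pre-transformed images and labels into the batch dict the Trainer expects.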
def collate_fn(examples):
pixel_values = torch.stack(
[torch.Tensor(example["pixel_values"]) for example in examples]
)
labels = torch.tensor([example[label_key] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}
def preprocess(example_batch, transform_fn):
example_batch["pixel_values"] = [
transform_fn(image.convert("RGB")) for image in example_batch[image_path_key]
]
return example_batch
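# Resize + normalize using the processor's reported input size; some processors
# expose it under `size`, others under `crop_size`.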
def get_transforms(image_processor):
if "height" in image_processor.size:
return Compose(
[
Resize((image_processor.size["height"], image_processor.size["width"])),
ToTensor(),
Normalize(
mean=image_processor.image_mean, std=image_processor.image_std
),
]
)
elif "height" in image_processor.crop_size:
return Compose(
[
Resize(
(
image_processor.crop_size["height"],
image_processor.crop_size["width"],
)
),
ToTensor(),
Normalize(
mean=image_processor.image_mean, std=image_processor.image_std
),
]
)
else:
raise ValueError("Unknown image processor")
def get_ids_and_labels_from_dataset(dataset):
labels = set(dataset[label_key])
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
label2id[label] = i
id2label[i] = label
return label2id, id2label
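# Few-shot sampler: shuffle example indices, then keep the first `num_train_per_label`
# examples per class for training and the next `num_val_per_label` per class for validation.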
def sampled_balanced_train_val(dataset, num_train_per_label=10, num_val_per_label=2):
label_counter = Counter(dataset[label_key])
inst_train = num_train_per_label
inst_val = num_val_per_label
    assert (
        min(label_counter.values()) >= inst_train + inst_val
    ), "Every label needs at least num_train_per_label + num_val_per_label examples"
label_counter = Counter()
train_ids = []
val_ids = []
labels = list(enumerate(dataset[label_key]))
random.shuffle(labels)
for i, l in labels:
if label_counter[l] < inst_train:
train_ids.append(i)
elif label_counter[l] < inst_train + inst_val:
val_ids.append(i)
else:
continue
label_counter[l] += 1
return dataset.select(train_ids), dataset.select(val_ids)
DATASET_NAME_TO_URL = {
"cifar100": "cifar100",
"food101": "ethz/food101",
"flowers102": "dpdl-benchmark/oxford_flowers102",
"resisc45": "timm/resisc45",
}
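# Return few-shot train/val splits plus the full test split for the requested dataset.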
def get_dataset(dataset_name):
if dataset_name == "cifar100":
dataset_url = DATASET_NAME_TO_URL[dataset_name]
dataset = load_dataset(dataset_url, split="train")
dataset = dataset.rename_column("fine_label", label_key)
dataset = dataset.rename_column("img", image_path_key)
dataset_test = load_dataset(dataset_url, split="test")
dataset_test = dataset_test.rename_column("fine_label", label_key)
dataset_test = dataset_test.rename_column("img", image_path_key)
dataset_train, dataset_val = sampled_balanced_train_val(dataset)
return dataset_train, dataset_val, dataset_test
elif dataset_name == "food101":
dataset_url = DATASET_NAME_TO_URL[dataset_name]
dataset = load_dataset(dataset_url, split="train")
dataset_test = load_dataset(dataset_url, split="validation")
dataset_train, dataset_val = sampled_balanced_train_val(dataset)
return dataset_train, dataset_val, dataset_test
elif dataset_name in {"flowers102", "resisc45"}:
dataset_url = DATASET_NAME_TO_URL[dataset_name]
dataset_train = load_dataset(dataset_url, split="train")
dataset_val = load_dataset(dataset_url, split="validation")
dataset_test = load_dataset(dataset_url, split="test")
dataset_train, _ = sampled_balanced_train_val(
dataset_train, num_train_per_label=10, num_val_per_label=0
)
_, dataset_val = sampled_balanced_train_val(
dataset_val, num_train_per_label=0, num_val_per_label=2
)
return dataset_train, dataset_val, dataset_test
else:
raise ValueError("Unknown dataset name")
##########################
# Finetuning Config
##########################
MODEL_NAME_TO_URL = {
"dino-v2-large": "facebook/dinov2-large",
"vit-base": "google/vit-base-patch16-224-in21k",
"vit-large": "google/vit-large-patch16-224-in21k",
}
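# Default adapter target modules per backbone; overridden when --target_modules is passed.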
def get_target_modules(model_name, finetuning_method):
if model_name == "dino-v2-large":
if finetuning_method in {"vera", "svft"}:
return [
"query",
"key",
]
else:
return "all-linear"
elif model_name in {"vit-base", "vit-large"}:
if finetuning_method == "head":
return []
else:
return [
"query",
"value",
]
else:
raise ValueError("Unknown model name")
def get_classifier_modules(model_name):
if model_name in {"dino-v2-large", "vit-base", "vit-large"}:
return [
"classifier",
]
else:
raise ValueError("Unknown model name")
@dataclass
class ScriptArguments:
results_json: str = field(
default="results.json", metadata={"help": "Results json file"}
)
model_name: Literal["dino-v2-large", "vit-base", "vit-large"] = field(
default="vit-base", metadata={"help": "Model name"}
)
dataset_name: Literal[
"cifar100",
"food101",
"flowers102",
"resisc45",
] = field(default="cifar100", metadata={"help": "Dataset name"})
finetuning_method: Literal[
"vera", "boft", "lora", "dora", "svft", "head", "full"
] = field(default="head", metadata={"help": "Finetuning method"})
clf_learning_rate: float = field(
default=1e-3, metadata={"help": "Classifier learning rate"}
)
    other_learning_rate: float = field(
        default=1e-4,
        metadata={"help": "Learning rate for all non-classifier trainable parameters"},
    )
## BOFT
boft_block_size: int = field(default=0, metadata={"help": "BOFT block size (m)"})
boft_n_butterfly_factor: int = field(
default=0, metadata={"help": "BOFT n butterfly factor (b)"}
)
## VeRA
    vera_rank: int = field(default=0, metadata={"help": "VeRA rank"})
    ## LoRA and DoRA
    lora_rank: int = field(default=0, metadata={"help": "LoRA/DoRA rank"})
## SVFT rank
svft_rank: int = field(default=0, metadata={"help": "SVFT rank"})
## Target Modules
target_modules: List[str] = field(
default_factory=list,
metadata={"help": "Target modules for finetuning"},
)
def main():
    import os
    import json
import wandb
from pprint import pprint
wandb.init(mode="disabled")
parser = HfArgumentParser((ScriptArguments, TrainingArguments))
script_args, training_args = parser.parse_args_into_dataclasses()
reset_seed(training_args.seed)
## Load dataset
dataset_train, dataset_val, dataset_test = get_dataset(script_args.dataset_name)
label2id, id2label = get_ids_and_labels_from_dataset(dataset_train)
# Set image transforms
model_name = script_args.model_name
model_url = MODEL_NAME_TO_URL[model_name]
image_processor = AutoImageProcessor.from_pretrained(model_url)
transform_fn = get_transforms(image_processor)
dataset_train.set_transform(lambda x: preprocess(x, transform_fn))
dataset_val.set_transform(lambda x: preprocess(x, transform_fn))
dataset_test.set_transform(lambda x: preprocess(x, transform_fn))
# Load model
model = AutoModelForImageClassification.from_pretrained(
model_url,
label2id=label2id,
id2label=id2label,
ignore_mismatched_sizes=True,
).to(device)
print_trainable_parameters(model)
# Get Target Modules
if not script_args.target_modules:
script_args.target_modules = get_target_modules(
model_name, script_args.finetuning_method
)
# Set fine-tuning config
if script_args.finetuning_method == "vera":
config = VeraConfig(
r=script_args.vera_rank,
target_modules=script_args.target_modules,
modules_to_save=get_classifier_modules(model_name),
vera_dropout=0.1,
bias="none",
)
elif script_args.finetuning_method == "boft":
config = BOFTConfig(
boft_block_size=script_args.boft_block_size,
boft_n_butterfly_factor=script_args.boft_n_butterfly_factor,
target_modules=script_args.target_modules,
modules_to_save=get_classifier_modules(model_name),
boft_dropout=0.1,
bias="boft_only",
)
elif script_args.finetuning_method == "lora":
config = LoraConfig(
r=script_args.lora_rank,
target_modules=script_args.target_modules,
modules_to_save=get_classifier_modules(model_name),
bias="none",
lora_dropout=0.1,
)
elif script_args.finetuning_method == "dora":
config = LoraConfig(
r=script_args.lora_rank,
target_modules=script_args.target_modules,
modules_to_save=get_classifier_modules(model_name),
bias="none",
lora_dropout=0.1,
use_dora=True,
)
elif script_args.finetuning_method == "head":
classifier_modules = get_classifier_modules(model_name)
for n, p in model.named_parameters():
if all(c not in n for c in classifier_modules):
p.requires_grad = False
elif script_args.finetuning_method in ["svft", "full"]:
pass
else:
raise ValueError("Unknown finetuning method")
if script_args.finetuning_method == "svft":
peft_model = model
modules_to_save_list = get_target_modules_list(
peft_model, get_classifier_modules(model_name)
)
freeze_model(peft_model, modules_to_save_list)
target_modules_list = get_target_modules_list(
peft_model, script_args.target_modules
)
        create_and_replace_modules(
            peft_model,
            target_modules_list,
            partial(LinearWithSVFT, off_diag=script_args.svft_rank),
        )
elif script_args.finetuning_method in ["head", "full"]:
peft_model = model
else:
peft_model = get_peft_model(model, config)
print_trainable_parameters(peft_model)
params_dict = get_trainable_params_dict(peft_model)
# Setup Trainer
args = TrainingArguments(**training_args.to_dict())
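    # Two parameter groups so the classifier head and the remaining trainable
    # parameters get separate learning rates.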
classifier_group = [
p
for n, p in model.named_parameters()
if p.requires_grad
and any(cls_name in n for cls_name in get_classifier_modules(model_name))
]
other_parameters_group = [
p
for n, p in model.named_parameters()
if p.requires_grad
and all(cls_name not in n for cls_name in get_classifier_modules(model_name))
]
optimizer = optim.AdamW(
[
{
"params": classifier_group,
"lr": script_args.clf_learning_rate,
},
{
"params": other_parameters_group,
"lr": script_args.other_learning_rate,
},
],
lr=script_args.other_learning_rate,
weight_decay=training_args.weight_decay,
)
    # Approximate total optimizer steps for the LR schedule
    # (ignores gradient accumulation and multi-GPU data parallelism).
    num_train_steps = int(
        len(dataset_train)
        // training_args.per_device_train_batch_size
        * training_args.num_train_epochs
    )
if training_args.lr_scheduler_type == "cosine":
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=int(num_train_steps * training_args.warmup_ratio),
num_training_steps=num_train_steps,
)
    elif training_args.lr_scheduler_type == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(num_train_steps * training_args.warmup_ratio),
            num_training_steps=num_train_steps,
        )
    else:
        raise ValueError(
            f"Unsupported lr_scheduler_type: {training_args.lr_scheduler_type}; use 'cosine' or 'linear'"
        )
trainer = Trainer(
peft_model,
args,
optimizers=(optimizer, scheduler),
train_dataset=dataset_train,
eval_dataset=dataset_val,
tokenizer=image_processor,
compute_metrics=compute_metrics,
data_collator=collate_fn,
)
train_results = trainer.train()
    with open(
        os.path.join(
            training_args.output_dir, f"final_train_results_{training_args.seed}.json"
        ),
        "w",
    ) as f:
        json.dump(train_results, f, indent=4)
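    # Undo or merge the adapter wrappers into the base model before test-set evaluation.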
if script_args.finetuning_method == "svft":
create_and_replace_modules(peft_model, target_modules_list, reset_from_svft)
    elif script_args.finetuning_method not in {"head", "full"}:
peft_model = peft_model.merge_and_unload()
eval_results = trainer.evaluate(dataset_test)
print(eval_results)
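    # Attach the full run configuration (script + training args) to the metrics for bookkeeping.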
for key in script_args.__dataclass_fields__:
value = getattr(script_args, key)
eval_results[key] = value
for key in training_args.__dataclass_fields__:
if "accelerator" in key:
continue
value = getattr(training_args, key)
eval_results[key] = value
eval_results.update(params_dict)
pprint(eval_results, indent=4)
    with open(
        os.path.join(
            training_args.output_dir, f"final_eval_results_{training_args.seed}.json"
        ),
        "w",
    ) as f:
        json.dump(eval_results, f, indent=4)
# Save to results.json
try:
with open(script_args.results_json, "r") as f:
results = json.load(f)
except FileNotFoundError:
results = []
results.append(eval_results)
with open(script_args.results_json, "w") as f:
json.dump(results, f, indent=4)
if __name__ == "__main__":
main()