# Thesis_CLIP / train.py
# Chanlefe's picture
# Create train.py
# c66d381 verified
from transformers import AutoProcessor, AutoModelForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset
import torch
# Load the training split from the local 'dataset' folder.
# By default the "imagefolder" builder exposes the class of each image in a
# "label" column (inferred from the class sub-directories when no metadata
# file is present). It has no `label_column` option, and load_dataset()
# rejects builder-config keywords it does not know, so none is passed here.
dataset = load_dataset("imagefolder", data_dir="dataset", split="train")
# Load the classifier and its processor from the same pretrained checkpoint
checkpoint = "google/siglip2-base-patch16-naflex"
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=2,  # size of the classification head
)
processor = AutoProcessor.from_pretrained(checkpoint)
# Preprocess the dataset
def transform(example):
    """Turn a batch of images into model inputs and carry the labels along.

    Used with ``dataset.map(..., batched=True)``, so ``example`` is a dict of
    columns: ``example["image"]`` and ``example["label"]`` each hold one entry
    per row of the batch.

    Returns the processor output with an added ``"label"`` entry.
    """
    # Force 3-channel RGB so grayscale, palette and RGBA files all reach the
    # processor in the same format.
    # NOTE(review): assumes the "image" column holds PIL images, as the
    # imagefolder builder normally yields — confirm if the loader changes.
    images = [image.convert("RGB") for image in example["image"]]
    inputs = processor(images=images, return_tensors="pt")
    inputs["label"] = example["label"]
    return inputs


dataset = dataset.map(transform, batched=True)
# Training setup: hyper-parameters and output locations
training_args = TrainingArguments(
    output_dir="./siglip2-meme-classifier",  # checkpoints are written here
    per_device_train_batch_size=8,
    num_train_epochs=3,
    save_steps=100,  # checkpoint interval, in optimisation steps
    logging_dir="./logs",
)

# Bind the model, the arguments and the preprocessed dataset together
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)
# Start training
trainer.train()

# Save the fine-tuned model and the processor side by side in one directory
for artifact in (model, processor):
    artifact.save_pretrained("model")