# Spaces: Sleeping  (page-status residue captured with the script; not code)
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForImageClassification,
    AutoProcessor,
    Trainer,
    TrainingArguments,
)
# Load the training split from the local "dataset" folder with the
# "imagefolder" builder. The rest of the script reads the resulting "image"
# and "label" columns.
# NOTE: "imagefolder" has no `label_column` option; unknown keyword arguments
# are forwarded to the builder config and rejected, so none is passed here.
dataset = load_dataset("imagefolder", data_dir="dataset", split="train")

# Load the pretrained checkpoint configured for 2 labels, and the processor
# published under the same checkpoint name.
model = AutoModelForImageClassification.from_pretrained(
    "google/siglip2-base-patch16-naflex", num_labels=2
)
processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-naflex")
# Preprocess the dataset
def transform(example):
    """Run the processor over the images and carry the labels along.

    Used with ``dataset.map(..., batched=True)``, so ``example`` is a batch
    whose "image" and "label" entries hold one value per row.

    Args:
        example: Mapping with an "image" entry (the images to preprocess) and
            a "label" entry (their class labels).

    Returns:
        The processor output (requested as PyTorch tensors via
        ``return_tensors="pt"``) with the labels stored under "label".
    """
    inputs = processor(images=example["image"], return_tensors="pt")
    # Keep the labels next to the model inputs.
    inputs["label"] = example["label"]
    return inputs
# Apply the preprocessing in batches; `transform` receives a batch of rows.
dataset = dataset.map(transform, batched=True)

# Training setup
training_args = TrainingArguments(
    output_dir="./siglip2-meme-classifier",
    per_device_train_batch_size=8,
    num_train_epochs=3,
    save_steps=100,
    logging_dir="./logs",
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

# Start training
trainer.train()

# Save the fine-tuned model and the processor into the same "model" directory.
model.save_pretrained("model")
processor.save_pretrained("model")