"""Fine-tune a SigLIP2 image classifier on the images under ``dataset/``.

The script loads an image-folder dataset, preprocesses every image with the
checkpoint's processor, trains with ``Trainer`` and finally writes the
fine-tuned model and processor to ``model/``.
"""

import torch  # noqa: F401  (imported by the original script; not referenced directly)
from datasets import load_dataset
from transformers import (
    AutoModelForImageClassification,
    AutoProcessor,
    Trainer,
    TrainingArguments,
)

MODEL_ID = "google/siglip2-base-patch16-naflex"

# Load dataset from the 'dataset' folder. The imagefolder builder already
# exposes the class of each image in a column named "label", so no extra
# builder option is passed for it.
dataset = load_dataset("imagefolder", data_dir="dataset", split="train")

# Size the classification head from the dataset instead of assuming two
# classes. Fall back to the previous default of 2 when the label feature
# does not expose class names.
label_names = getattr(dataset.features["label"], "names", None)
if label_names:
    model_kwargs = {
        "num_labels": len(label_names),
        "id2label": dict(enumerate(label_names)),
        "label2id": {name: idx for idx, name in enumerate(label_names)},
    }
else:
    model_kwargs = {"num_labels": 2}

# Load model and processor
model = AutoModelForImageClassification.from_pretrained(MODEL_ID, **model_kwargs)
processor = AutoProcessor.from_pretrained(MODEL_ID)


# Preprocess the dataset
def transform(example):
    """Run the processor over a batch of images and attach the labels.

    Called through ``dataset.map(..., batched=True)``, so despite its name
    ``example`` is a batch: ``example["image"]`` and ``example["label"]``
    are parallel lists.

    Args:
        example: Batch mapping with an ``"image"`` list and a ``"label"`` list.

    Returns:
        The processor output for the batch with a ``"label"`` entry added.
    """
    # Normalise every image to 3-channel RGB so palette / RGBA / grayscale
    # files all reach the processor in the same layout.
    images = [image.convert("RGB") for image in example["image"]]
    inputs = processor(images=images, return_tensors="pt")
    inputs["label"] = example["label"]
    return inputs


dataset = dataset.map(transform, batched=True)

# Training setup
training_args = TrainingArguments(
    output_dir="./siglip2-meme-classifier",
    per_device_train_batch_size=8,
    num_train_epochs=3,
    save_steps=100,
    logging_dir="./logs",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

# Start training
trainer.train()

# Save the fine-tuned model and processor
model.save_pretrained("model")
processor.save_pretrained("model")