File size: 1,161 Bytes
c66d381
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from transformers import AutoProcessor, AutoModelForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset
import torch

# Load the image dataset from the local 'dataset' folder. The imagefolder
# builder infers class labels from sub-directory names into a "label" column
# by itself; it has no `label_column` option, and `load_dataset` rejects
# unknown builder-config keyword arguments with a ValueError.
dataset = load_dataset("imagefolder", data_dir="dataset", split="train")

# Size the classification head from the labels actually present on disk, and
# record their names in the model config so the saved model reports readable
# class names instead of LABEL_0 / LABEL_1.
label_feature = dataset.features.get("label")
label_names = getattr(label_feature, "names", None)
if label_names:
    id2label = dict(enumerate(label_names))
    label_kwargs = {
        "num_labels": len(label_names),
        "id2label": id2label,
        "label2id": {name: idx for idx, name in id2label.items()},
    }
else:
    # Labels are not a ClassLabel (e.g. plain integers from a metadata file),
    # so no names are available; keep the original binary head.
    label_kwargs = {"num_labels": 2}

# Load model and processor
model = AutoModelForImageClassification.from_pretrained(
    "google/siglip2-base-patch16-naflex",
    **label_kwargs,
)
processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-naflex")

# Preprocess the dataset
def transform(example):
    """Encode a batch of dataset rows into model inputs.

    This is applied with ``dataset.map(..., batched=True)``, so ``example`` is
    a batch: a dict mapping each column name to a list of values.

    Args:
        example: Batch with an ``"image"`` column (PIL images) and a
            ``"label"`` column.

    Returns:
        The processor's outputs for the batch, with the batch's ``"label"``
        list attached unchanged so labels stay aligned with the encodings.
    """
    # Images read from disk may be RGBA, palette ("P") or grayscale; convert
    # them so the processor always receives 3-channel RGB input.
    images = [image.convert("RGB") for image in example["image"]]
    inputs = processor(images=images, return_tensors="pt")
    inputs["label"] = example["label"]
    return inputs


# The raw images are not needed once encoded; dropping them keeps the mapped
# dataset from carrying the original image data alongside the pixel values.
dataset = dataset.map(transform, batched=True, remove_columns=["image"])

# Training setup: checkpoints are written under ./siglip2-meme-classifier
# (save_steps=100) and logs under ./logs.
training_args = TrainingArguments(
    output_dir="./siglip2-meme-classifier",
    logging_dir="./logs",
    save_steps=100,
    num_train_epochs=3,
    per_device_train_batch_size=8,
)

# Fine-tune on the preprocessed dataset.
trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
trainer.train()

# Write the fine-tuned model first, then its processor, into the same "model"
# directory so both can be loaded back from one place.
for artifact in (model, processor):
    artifact.save_pretrained("model")