# Block_Computer_Vision / zero_shot_classification.py
# Author: Monyrak — commit dafa0bc ("Upload zero_shot_classification.py", verified)
from transformers import pipeline
from datasets import load_dataset
from PIL import Image
import io
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score
import os
# Clear the Hugging Face `datasets` cache before loading data.
# The location is taken from `datasets.config`, which honours the
# HF_DATASETS_CACHE / HF_HOME environment variables; a hard-coded
# ~/.cache path would miss (and leave untouched) a relocated cache.
# WARNING: this deletes *every* cached dataset, not only Oxford Pets.
import shutil

from datasets import config as hf_datasets_config

cache_dir = str(hf_datasets_config.HF_DATASETS_CACHE)
# isdir (not exists): rmtree can only remove directories.
if os.path.isdir(cache_dir):
    shutil.rmtree(cache_dir)
# Build a zero-shot image-classification pipeline around OpenAI's
# CLIP ViT-L/14 checkpoint; `detector` scores an image against free-text labels.
print("Loading CLIP model...")
checkpoint = "openai/clip-vit-large-patch14"
detector = pipeline(
    task="zero-shot-image-classification",
    model=checkpoint,
)
# Load the Oxford Pets dataset.
# Only the first 100 training images are used so a test run finishes quickly.
import sys

print("Loading Oxford Pets dataset...")
try:
    dataset = load_dataset('pcuenq/oxford-pets', split='train[:100]')
except Exception as e:
    # Without the dataset there is nothing to evaluate: report and stop.
    # `sys.exit` is always available, unlike the `exit` builtin, which the
    # `site` module injects for interactive use only (absent under `python -S`).
    print(f"Error loading dataset: {e}", file=sys.stderr)
    sys.exit(1)
print(f"Loaded {len(dataset)} images")
# Candidate breed names passed to CLIP (37 breeds; capitalised entries are cat
# breeds, lower-case ones are dog breeds). Predictions are compared verbatim with
# the dataset's `label` values, so spelling and case must match them exactly.
labels_oxford_pets = [
    'Siamese',
    'Birman',
    'shiba inu',
    'staffordshire bull terrier',
    'basset hound',
    'Bombay',
    'japanese chin',
    'chihuahua',
    'german shorthaired',
    'pomeranian',
    'beagle',
    'english cocker spaniel',
    'american pit bull terrier',
    'Ragdoll',
    'Persian',
    'Egyptian Mau',
    'miniature pinscher',
    'Sphynx',
    'Maine Coon',
    'keeshond',
    'yorkshire terrier',
    'havanese',
    'leonberger',
    'wheaten terrier',
    'american bulldog',
    'english setter',
    'boxer',
    'newfoundland',
    'Bengal',
    'samoyed',
    'British Shorthair',
    'great pyrenees',
    'Abyssinian',
    'pug',
    'saint bernard',
    'Russian Blue',
    'scottish terrier',
]
# Run CLIP on every image and collect aligned (true, predicted) label pairs.
# Images that fail to load or classify are reported and skipped, so both lists
# stay the same length but may be shorter than the dataset.
true_labels = []
predicted_labels = []
print("Processing images...")
for i in tqdm(range(len(dataset)), desc="Processing images"):
    try:
        # Index inside the `try` so a row that fails to decode is skipped rather
        # than aborting the whole run; fetch the row only once per iteration.
        example = dataset[i]
        true_label = example['label']
        image = example['image']
        # When the column uses the `datasets` Image feature, rows already hold a
        # decoded PIL image; only raw {'bytes': ..., 'path': ...} dicts need decoding.
        if isinstance(image, dict):
            image = Image.open(io.BytesIO(image['bytes']))
        results = detector(image, candidate_labels=labels_oxford_pets)
        # The highest-scoring candidate label is the prediction.
        predicted_label = max(results, key=lambda r: r['score'])['label']
    except Exception as e:
        # Best-effort evaluation: report above the progress bar and move on.
        tqdm.write(f"Error processing image {i}: {e}")
        continue
    true_labels.append(true_label)
    predicted_labels.append(predicted_label)
    if (i + 1) % 10 == 0:
        tqdm.write(f"Processed {i + 1}/{len(dataset)} images")
# Calculate metrics over the images that were processed successfully.
# If every image failed, sklearn would be handed empty label lists and produce
# meaningless values, so stop with an explicit error instead.
if not true_labels:
    import sys

    sys.exit("No images were processed successfully; cannot compute metrics.")
accuracy = accuracy_score(true_labels, predicted_labels)
# Weighted averages over the candidate label set: each breed is weighted by its
# number of true instances. zero_division=0 yields the same numbers as the default
# ('warn' acts as 0) without a warning for every breed missing from the subset.
precision = precision_score(
    true_labels, predicted_labels,
    average='weighted', labels=labels_oxford_pets, zero_division=0,
)
recall = recall_score(
    true_labels, predicted_labels,
    average='weighted', labels=labels_oxford_pets, zero_division=0,
)
# Build a Markdown report of the evaluation, print it, and save it to disk.
# The title is derived from `checkpoint` so it always names the model actually used.
title = f"Zero-Shot Classification Results using CLIP ({checkpoint})"
# Metrics are a bullet list: consecutive plain lines would render as one
# paragraph in Markdown. The image count shows how many images were skipped.
results = f"""
{title}
{'=' * len(title)}

- Images evaluated: {len(true_labels)}
- Accuracy: {accuracy:.4f}
- Precision: {precision:.4f}
- Recall: {recall:.4f}
"""
print(results)
# Save results to a file
with open('zero_shot_results.md', 'w', encoding='utf-8') as f:
    f.write(results)