# Zero-shot pet-breed classification: CLIP ViT-L/14 evaluated on the Oxford Pets dataset.
import io
import os
import shutil

from datasets import load_dataset
from PIL import Image
from sklearn.metrics import accuracy_score, precision_score, recall_score
from tqdm import tqdm
from transformers import pipeline
# Clear the local HF datasets cache so the dataset is re-downloaded fresh.
# NOTE(review): this wipes *every* cached dataset, not just oxford-pets —
# load_dataset(..., download_mode="force_redownload") would be more targeted.
cache_dir = os.path.expanduser("~/.cache/huggingface/datasets")
if os.path.exists(cache_dir):
    import shutil
    # Best-effort cleanup: a locked or permission-restricted file must not
    # abort the whole run before anything has happened.
    shutil.rmtree(cache_dir, ignore_errors=True)
# Build the zero-shot image classifier backed by CLIP ViT-L/14.
print("Loading CLIP model...")
checkpoint = "openai/clip-vit-large-patch14"
detector = pipeline(task="zero-shot-image-classification", model=checkpoint)
# Load the Oxford Pets dataset (first 100 images only, for faster testing).
print("Loading Oxford Pets dataset...")
try:
    dataset = load_dataset('pcuenq/oxford-pets', split='train[:100]')
    print(f"Loaded {len(dataset)} images")
except Exception as e:
    print(f"Error loading dataset: {e}")
    # `exit()` is injected by the `site` module and is not guaranteed to
    # exist (e.g. under `python -S`); SystemExit is the portable way to
    # abort with status 1.
    raise SystemExit(1)
# The 37 Oxford-IIIT Pet breed names, used as the zero-shot candidate labels.
# Note the casing pattern in this list: cat breeds are capitalized, dog
# breeds are lowercase.
labels_oxford_pets = [
    'Siamese', 'Birman', 'shiba inu', 'staffordshire bull terrier',
    'basset hound', 'Bombay', 'japanese chin', 'chihuahua',
    'german shorthaired', 'pomeranian', 'beagle', 'english cocker spaniel',
    'american pit bull terrier', 'Ragdoll', 'Persian', 'Egyptian Mau',
    'miniature pinscher', 'Sphynx', 'Maine Coon', 'keeshond',
    'yorkshire terrier', 'havanese', 'leonberger', 'wheaten terrier',
    'american bulldog', 'english setter', 'boxer', 'newfoundland',
    'Bengal', 'samoyed', 'British Shorthair', 'great pyrenees',
    'Abyssinian', 'pug', 'saint bernard', 'Russian Blue',
    'scottish terrier',
]
# Classify every image and collect (true, predicted) label pairs.
true_labels = []
predicted_labels = []
print("Processing images...")
for i in tqdm(range(len(dataset)), desc="Processing images"):
    try:
        # Decode the raw image bytes stored in the dataset row into a PIL image.
        image = Image.open(io.BytesIO(dataset[i]['image']['bytes']))
        # Score every candidate breed name against the image.
        results = detector(image, candidate_labels=labels_oxford_pets)
        # Top-1 prediction: max() is O(n) — no need to fully sort the
        # candidate scores just to take the best one.
        predicted_label = max(results, key=lambda r: r['score'])['label']
        true_labels.append(dataset[i]['label'])
        predicted_labels.append(predicted_label)
        # Coarse progress report every 10 images (in addition to the tqdm bar).
        if (i + 1) % 10 == 0:
            print(f"Processed {i + 1}/{len(dataset)} images")
    except Exception as e:
        # Best-effort evaluation: skip unreadable or failing images rather
        # than aborting the whole run.
        print(f"Error processing image {i}: {e}")
        continue
# Aggregate metrics over the successfully processed images.
if not true_labels:
    # Every image failed: sklearn would raise an opaque error on empty
    # inputs, so fail with a clear message instead.
    print("No images were processed successfully; cannot compute metrics.")
    raise SystemExit(1)
accuracy = accuracy_score(true_labels, predicted_labels)
# zero_division=0 keeps the original 0-contribution for breeds that were
# never predicted, but without sklearn's UndefinedMetricWarning spam.
precision = precision_score(true_labels, predicted_labels, average='weighted',
                            labels=labels_oxford_pets, zero_division=0)
recall = recall_score(true_labels, predicted_labels, average='weighted',
                      labels=labels_oxford_pets, zero_division=0)
# Format the summary (renamed from `results` to avoid shadowing the
# per-image pipeline output used in the processing loop above).
report = f"""
Zero-Shot Classification Results using CLIP (openai/clip-vit-large-patch14)
====================================================================
Accuracy: {accuracy:.4f}
Precision: {precision:.4f}
Recall: {recall:.4f}
"""
print(report)
# Persist the summary for later inspection.
with open('zero_shot_results.md', 'w') as f:
    f.write(report)