# Skorm's picture
# Update label for zeroshot model output
# e1afdf8 verified
import gradio as gr
from transformers import pipeline, CLIPProcessor, CLIPModel
from PIL import Image
import torch
# Fine-tuned ViT food classifier, wrapped in a transformers pipeline.
classifier = pipeline(task="image-classification", model="Skorm/food11-vit")

# CLIP pre-processor and weights, both taken from the same checkpoint.
_CLIP_CHECKPOINT = "openai/clip-vit-large-patch14"
clip_processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
clip_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
# Candidate class names that CLIP scores against each image.
clip_labels = [
    "Bread",
    "Dairy product",
    "Dessert",
    "Egg",
    "Fried food",
    "Meat",
    "Noodles-Pasta",
    "Rice",
    "Seafood",
    "Soup",
    "Vegetable-Fruit",
]
def classify_food(image_path):
    """Classify a food image with the fine-tuned ViT and with zero-shot CLIP.

    Args:
        image_path: Filesystem path of the image to classify.

    Returns:
        A ``(vit_output, clip_output)`` tuple of dicts, each mapping a label
        to its score rounded to 4 decimals. ``vit_output`` contains the
        labels returned by the ViT pipeline; ``clip_output`` has one entry
        per item of ``clip_labels``, scored by a softmax over those labels.

    Raises:
        gr.Error: If ``image_path`` is empty or ``None``.
    """
    if not image_path:
        # NOTE(review): presumably Gradio passes None when the form is
        # submitted without an image; surface a readable message instead of
        # failing inside PIL.
        raise gr.Error("Please upload an image first.")

    # ----- ViT prediction -----
    vit_results = classifier(image_path)
    vit_output = {result["label"]: round(result["score"], 4) for result in vit_results}

    # ----- CLIP zero-shot prediction -----
    # Open via a context manager so the underlying file handle is released
    # once the processor has turned the image into tensors.
    with Image.open(image_path) as image:
        inputs = clip_processor(text=clip_labels, images=image, return_tensors="pt", padding=True)

    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = clip_model(**inputs)

    # logits_per_image has one row per image; softmax across the label axis
    # and take the row of the single input image.
    probs = outputs.logits_per_image.softmax(dim=1)[0]
    clip_output = {label: round(float(score), 4) for label, score in zip(clip_labels, probs)}

    return vit_output, clip_output
# Sample images offered in the UI; each row holds the single image-path input.
examples = [
    [sample_path]
    for sample_path in (
        "example_images/bread.jpg",
        "example_images/dessert.jpg",
        "example_images/fruits.jpg",
        "example_images/noodles.jpeg",
        "example_images/ramen.jpg",
        "example_images/seafood.jpg",
    )
]
# Gradio interface: one image input (passed to classify_food as a file path)
# and two label panels, in the same order as classify_food's return tuple
# (ViT first, CLIP second).
iface = gr.Interface(
    fn=classify_food,
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Label(num_top_classes=3, label="ViT (Fine-tuned) Prediction"),
        gr.Label(num_top_classes=3, label="CLIP Zero-Shot Prediction"),
    ],
    title="🍽️ Food Classification with ViT and Zero-Shot CLIP",
    description="Upload a food image. The app compares predictions between your fine-tuned ViT model and zero-shot CLIP.",
    examples=examples,
)

# Start the server only when executed as a script, so that importing this
# module (e.g. to reuse classify_food) does not launch the app.
if __name__ == "__main__":
    iface.launch()