Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from PIL import Image | |
from transformers import AutoImageProcessor, ViTForImageClassification | |
from transformers import pipeline | |
# Class names of the CIFAR-10 dataset. They are also passed to CLIP as the
# candidate labels for zero-shot classification, so both models are scored
# against the same label set.
labels_cifar10 = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
# Hugging Face Hub checkpoints used by the app.
_VIT_PROCESSOR_CHECKPOINT = "google/vit-base-patch16-224-in21k"
_VIT_MODEL_CHECKPOINT = "Fadri/results"
_CLIP_CHECKPOINT = "openai/clip-vit-large-patch14"

# The image processor and the fine-tuned ViT classifier are loaded separately.
# NOTE(review): the processor comes from the base checkpoint, not from
# Fadri/results — confirm the fine-tuned model was trained with the same
# preprocessing (resize/normalization).
processor = AutoImageProcessor.from_pretrained(_VIT_PROCESSOR_CHECKPOINT)
model = ViTForImageClassification.from_pretrained(_VIT_MODEL_CHECKPOINT)

# CLIP zero-shot classifier, used as a comparison baseline for the ViT model.
clip_detector = pipeline(
    task="zero-shot-image-classification",
    model=_CLIP_CHECKPOINT,
)
def predict_cifar10(image_path, top_k=3):
    """Classify an image file with the fine-tuned ViT model.

    Args:
        image_path: Path of the image file to classify (the Gradio input in
            this app delivers a file path).
        top_k: Number of highest-probability classes to return. Clamped to
            the range [0, number of model outputs]. Defaults to 3, which
            matches the previous fixed behavior.

    Returns:
        dict mapping label name (from ``model.config.id2label``) to its
        softmax probability rounded to 4 decimals, inserted from most to
        least likely.
    """
    # Use a context manager so the file handle is released; convert() yields
    # a new image holding the pixel data, so it stays usable after closing.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")

    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Class probabilities for the single image in the batch.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

    # torch.topk raises when k exceeds the number of elements, so clamp k.
    k = max(0, min(top_k, probabilities.numel()))
    top_probs, top_indices = torch.topk(probabilities, k)

    return {
        model.config.id2label[idx.item()]: round(prob.item(), 4)
        for idx, prob in zip(top_indices, top_probs)
    }
def classify_image(image):
    """Run both classifiers on one uploaded image and return their results.

    Args:
        image: File path of the uploaded image as delivered by
            ``gr.Image(type="filepath")``; ``None`` if the user submits
            without uploading an image.

    Returns:
        dict with two entries, each a ``{label: score}`` dict:
        "CIFAR-10 ViT Klassifikation" holds the fine-tuned ViT's top
        predictions, and "CLIP Zero-Shot Klassifikation" holds CLIP's scores
        over ``labels_cifar10``. All scores are rounded to 4 decimals.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable
            message instead of a PIL traceback.
    """
    if image is None:
        raise gr.Error("Bitte lade zuerst ein Bild hoch.")

    # Classification with the fine-tuned ViT model.
    cifar10_output = predict_cifar10(image)

    # Zero-shot classification with CLIP over the same CIFAR-10 label set.
    clip_results = clip_detector(image, candidate_labels=labels_cifar10)
    # Round like the ViT output so both result sets are directly comparable.
    clip_output = {
        result["label"]: round(result["score"], 4) for result in clip_results
    }

    return {
        "CIFAR-10 ViT Klassifikation": cifar10_output,
        "CLIP Zero-Shot Klassifikation": clip_output,
    }
# Example inputs shown under the Gradio form. Each inner list holds the value
# for the interface's single image input. The paths are relative; adjust them
# to wherever the sample images actually live.
example_images = [
    [path]
    for path in (
        "examples/airplane.jpg",
        "examples/car.jpg",
        "examples/dog.jpg",
        "examples/cat.jpg",
    )
]
# Gradio UI: one image upload (handed to classify_image as a file path) and a
# JSON panel that shows both classifiers' results side by side.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),
    outputs=gr.JSON(),
    examples=example_images,
    title="CIFAR-10 Klassifikation",
    description=(
        "Lade ein Bild hoch und vergleiche die Ergebnisse zwischen deinem "
        "trainierten ViT Modell und CLIP."
    ),
)

# Start the web server (blocks until the app is stopped).
iface.launch()