import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTForImageClassification, pipeline

# CIFAR-10 class labels
labels_cifar10 = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# Load the fine-tuned model and its image processor separately
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTForImageClassification.from_pretrained("Fadri/results")

# CLIP pipeline for zero-shot classification, as before
clip_detector = pipeline(model="openai/clip-vit-large-patch14", task="zero-shot-image-classification")


def predict_cifar10(image_path):
    # Load and preprocess the image
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    # Model prediction
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # Top-3 results with probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    top3_probs, top3_indices = torch.topk(probabilities, 3)

    results = {}
    for idx, prob in zip(top3_indices, top3_probs):
        label = model.config.id2label[idx.item()]
        results[label] = round(prob.item(), 4)

    return results


def classify_image(image):
    # Classification with the fine-tuned ViT model
    cifar10_output = predict_cifar10(image)

    # Zero-shot classification with CLIP
    clip_results = clip_detector(image, candidate_labels=labels_cifar10)
    clip_output = {result['label']: result['score'] for result in clip_results}

    return {
        "CIFAR-10 ViT classification": cifar10_output,
        "CLIP zero-shot classification": clip_output
    }


# Example images (adjust paths as needed)
example_images = [
    ["examples/airplane.jpg"],
    ["examples/car.jpg"],
    ["examples/dog.jpg"],
    ["examples/cat.jpg"]
]

# Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),
    outputs=gr.JSON(),
    title="CIFAR-10 Classification",
    description="Upload an image and compare the results of your fine-tuned ViT model with CLIP zero-shot classification.",
    examples=example_images
)

iface.launch()