import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTForImageClassification
from transformers import pipeline

# CIFAR-10 class labels
labels_cifar10 = [
    'airplane', 'automobile', 'bird', 'cat', 'deer', 
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# Load the model and processor separately
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTForImageClassification.from_pretrained("Fadri/results")
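# Assumption: "Fadri/results" was fine-tuned on CIFAR-10, so model.config.id2label
# maps class indices to the names listed above. If the config still held generic
# "LABEL_i" entries, the mapping could be attached manually, e.g.:
# model.config.id2label = {i: name for i, name in enumerate(labels_cifar10)}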

# CLIP for zero-shot classification stays as before
clip_detector = pipeline(model="openai/clip-vit-large-patch14", task="zero-shot-image-classification")
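# The zero-shot pipeline accepts a file path, URL, or PIL image and scores it
# against whatever candidate labels are passed at call time.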

def predict_cifar10(image_path):
    # Load and preprocess the image
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    
    # Model forward pass (no gradients needed for inference)
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    
    # Top-3 results with probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    top3_probs, top3_indices = torch.topk(probabilities, 3)
    
    # Map class indices to label names via the model config
    results = {}
    for idx, prob in zip(top3_indices, top3_probs):
        label = model.config.id2label[idx.item()]
        results[label] = round(prob.item(), 4)
    
    return results
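
# predict_cifar10 returns the three most likely labels mapped to their softmax
# probabilities, e.g. {"dog": 0.91, "cat": 0.05, "deer": 0.02} (values illustrative).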

def classify_image(image):
    # Classification with the fine-tuned ViT model
    cifar10_output = predict_cifar10(image)
    
    # Zero-shot classification with CLIP over the same label set
    clip_results = clip_detector(image, candidate_labels=labels_cifar10)
    clip_output = {result['label']: result['score'] for result in clip_results}
    
    return {
        "CIFAR-10 ViT classification": cifar10_output, 
        "CLIP zero-shot classification": clip_output
    }

# Example images (adjust paths as needed)
example_images = [
    ["examples/airplane.jpg"],
    ["examples/car.jpg"],
    ["examples/dog.jpg"],
    ["examples/cat.jpg"]
]

# Gradio Interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),
    outputs=gr.JSON(),
    title="CIFAR-10 Klassifikation",
    description="Lade ein Bild hoch und vergleiche die Ergebnisse zwischen deinem trainierten ViT Modell und CLIP.",
    examples=example_images
)

iface.launch()
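
# launch() serves the app locally; on Hugging Face Spaces this is sufficient.
# For a temporary public link from a local machine, launch(share=True) works too.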