File size: 5,599 Bytes
d503e88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dff227d
d503e88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# app.py
import torch, torch.nn as nn, torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import gradio as gr

try:
    from huggingface_hub import hf_hub_download
except ImportError:
    hf_hub_download = None

# Your classes
from categories import CATEGORIES

# ----- CONFIG -----
# Hugging Face model repo that hosts the weights. When this is an empty
# string the app loads MODEL_FILE from a local path instead of the Hub.
MODEL_REPO = "MichaelMwb/butterfly-vgg16"
# Checkpoint filename (inside the repo, or on local disk).
MODEL_FILE = "model.pth"
# Architecture rebuilt around a state-dict checkpoint: "vgg16" or "resnet18".
BACKBONE = "vgg16"
# Logits are divided by this value before softmax; try 4-10, a higher value
# gives less "peaky" probabilities.
DISPLAY_TEMPERATURE = 4.0


# Preprocess: resize, center-crop to 224, convert to a tensor, then normalize
# each channel with the ImageNet mean / std listed below.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

T = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# ----- Model loading helpers -----
def build_backbone(num_classes: int):
    """Build the configured torchvision backbone with a ``num_classes``-way head.

    The architecture is chosen by the module-level ``BACKBONE`` setting. The
    network is created with torchvision's default pretrained weights and its
    final linear layer is swapped for a fresh ``nn.Linear`` that produces
    ``num_classes`` outputs.

    Args:
        num_classes: Number of outputs of the replacement classifier layer.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``BACKBONE`` is neither ``"vgg16"`` nor ``"resnet18"``.
    """
    if BACKBONE == "vgg16":
        m = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        # Take the input width from the layer being replaced instead of
        # hard-coding 4096, the same way the resnet18 branch does.
        m.classifier[6] = nn.Linear(m.classifier[6].in_features, num_classes)
        return m
    elif BACKBONE == "resnet18":
        m = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    raise ValueError(f"Unsupported backbone: {BACKBONE}")

def _torch_load(path):
    """Deserialize a checkpoint onto the CPU with ``torch.load``.

    ``weights_only=False`` is requested explicitly so that a fully pickled
    object can be loaded as well as a plain state dict. If ``torch.load``
    rejects the call with ``TypeError`` (e.g. a torch build that does not
    accept that keyword), the call is repeated without it.

    NOTE(review): ``weights_only=False`` unpickles arbitrary objects, so this
    should only be pointed at checkpoints from a trusted source.

    Args:
        path: Checkpoint location, forwarded to ``torch.load`` unchanged.

    Returns:
        Whatever object the checkpoint contains.
    """
    cpu_kwargs = {"map_location": "cpu"}
    try:
        loaded = torch.load(path, weights_only=False, **cpu_kwargs)
    except TypeError:
        loaded = torch.load(path, **cpu_kwargs)
    return loaded

def _model_from_checkpoint(obj):
    """Turn a deserialized checkpoint into a model in eval mode.

    Args:
        obj: Either a complete ``nn.Module`` or a mapping of weights. A
            mapping may wrap the weights under a ``"state_dict"`` key.

    Returns:
        The model, switched to eval mode.
    """
    if isinstance(obj, nn.Module):
        return obj.eval()

    import warnings  # local import keeps this block self-contained

    m = build_backbone(len(CATEGORIES))
    state = obj.get("state_dict", obj)
    # strict=False tolerates key mismatches; report them instead of silently
    # serving a model whose unmatched layers never received checkpoint weights.
    result = m.load_state_dict(state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"Checkpoint does not fully match the {BACKBONE} backbone: "
            f"{len(result.missing_keys)} missing key(s), "
            f"{len(result.unexpected_keys)} unexpected key(s). "
            "Layers without a matching key keep their initial weights.",
            RuntimeWarning,
            stacklevel=2,
        )
    return m.eval()

def load_model_local():
    """Load the model from the local ``MODEL_FILE`` checkpoint.

    Returns:
        The model in eval mode.
    """
    return _model_from_checkpoint(_torch_load(MODEL_FILE))

def load_model_from_hub():
    """Download ``MODEL_FILE`` from ``MODEL_REPO`` on the Hub and load it.

    Returns:
        The model in eval mode.

    Raises:
        RuntimeError: If ``huggingface_hub`` could not be imported.
    """
    if hf_hub_download is None:
        raise RuntimeError("Install huggingface_hub or set MODEL_REPO='' to load local file.")
    ckpt_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
    return _model_from_checkpoint(_torch_load(ckpt_path))

# Global model + device
# A non-empty MODEL_REPO selects the Hub download; otherwise the local
# checkpoint file is used.
if MODEL_REPO:
    _model = load_model_from_hub()
else:
    _model = load_model_local()

# Start in AUTO mode: GPU if available, else CPU
if torch.cuda.is_available():
    _device = torch.device("cuda")
else:
    _device = torch.device("cpu")
_model.to(_device)

def set_device(choice: str):
    """Move the shared model to the device picked in the UI.

    ``"Auto"`` selects CUDA when ``torch.cuda.is_available()`` and the CPU
    otherwise; any other value selects the CPU.

    Args:
        choice: The selected option label.

    Returns:
        A Markdown status string naming the device now in use.
    """
    global _device
    auto = choice == "Auto"
    # CUDA availability is only queried when Auto was requested.
    wants_cuda = auto and torch.cuda.is_available()
    _device = torch.device("cuda" if wants_cuda else "cpu")
    _model.to(_device)

    status = f"Using device: **{_device.type}**"
    if auto:
        status += " (auto-selected)"
    return status

def predict(img: Image.Image, top_k: int = 5):
    """Classify an image and return its top-k labels with confidences.

    Args:
        img: The uploaded PIL image, or ``None`` when nothing was uploaded.
        top_k: How many classes to report. The value is coerced to ``int``
            (presumably the UI slider can deliver a float - confirm against
            the Gradio version in use) and clamped to at least 1 and at most
            the number of known categories / model outputs.

    Returns:
        A dict mapping category name to its confidence as a Python float,
        ordered from most to least likely; an empty dict when ``img`` is
        ``None``.
    """
    if img is None:
        return {}
    x = T(img.convert("RGB")).unsqueeze(0).to(_device)
    with torch.no_grad():
        logits = _model(x)
        # soften logits so softmax isn't 100/0
        probs = torch.softmax(logits / DISPLAY_TEMPERATURE, dim=1)[0].cpu()

    # torch.topk needs an integer k that does not exceed the number of scores.
    k = min(max(int(top_k), 1), len(CATEGORIES), probs.numel())
    confs, idxs = torch.topk(probs, k=k)
    return {CATEGORIES[i]: c for i, c in zip(idxs.tolist(), confs.tolist())}


# ---------- UI ----------
# Status text shown before the user touches the device selector; the app
# starts on the auto-selected device.
initial_status = f"Using device: **{_device.type}** (auto-selected)"

# Photo tips rendered above the uploader.
BEST_RESULTS_MD = """
**Best results**
- Single butterfly, centered, sharp focus  
- Plain/clean background; avoid heavy clutter  
- Show full wings; crop out large borders  
- Good lighting; avoid harsh shadows or filters  
- JPG/PNG around **800–1200 px** on the short side
"""

# Accuracy / privacy notice rendered under the uploader.
DISCLAIMER_MD = """
> **Disclaimer:** This demo uses a research model trained on a fixed dataset.  
> Predictions may be wrong—treat results as suggestions, not definitive IDs.  
> **Privacy:** Images are processed in memory and **not stored**.
"""

# Explanation panel; {bb} and {n} are filled from BACKBONE and CATEGORIES.
# The inference line states Auto as the default, matching the device the app
# starts on and the initial value of the device selector.
HOW_IT_WORKS_MD = """
### How it works
1. Your image is resized and normalized (ImageNet stats).  
2. A pretrained **{bb}** backbone with a small classifier head runs inference.  
3. We show the **Top-k** species with confidence scores.

### Model & data
- Backbone: **{bb}** (transfer learning)
- Classes: **{n}** species
- Inference: Auto by default (GPU if available, else CPU); CPU can be selected
""".format(bb=BACKBONE.upper(), n=len(CATEGORIES))

with gr.Blocks(title="🦋 Butterfly Classifier", theme="gradio/soft") as demo:
    gr.Markdown(
        value="# 🦋 Butterfly Classifier\nUpload a photo to see the top-k predictions with confidences."
    )

    with gr.Row():
        # Left column: device selector, tips, uploader, disclaimer, controls.
        with gr.Column(scale=1):
            # Only two device options are offered: Auto and CPU.
            device_choice = gr.Radio(
                choices=["Auto", "CPU"],
                value="Auto",
                label="Device",
            )
            device_status = gr.Markdown(value=initial_status)

            # Photo tips
            gr.Markdown(value=BEST_RESULTS_MD)

            # Upload is the only image source (no webcam).
            img = gr.Image(type="pil", sources=["upload"], label="Upload image")

            # Disclaimer sits directly below the uploader.
            gr.Markdown(value=DISCLAIMER_MD)

            # Prediction controls
            k = gr.Slider(minimum=1, maximum=5, value=5, step=1, label="Top-k")
            btn = gr.Button(value="Predict", variant="primary")

        # Right column: prediction output and the explanation panel.
        with gr.Column(scale=1):
            out = gr.Label(num_top_classes=5, label="Predictions")
            gr.Markdown(value=HOW_IT_WORKS_MD)

    # Event wiring: changing the radio switches device, the button runs inference.
    device_choice.change(set_device, device_choice, device_status)
    btn.click(fn=predict, inputs=[img, k], outputs=out)

if __name__ == "__main__":
    demo.launch()