# app.py
import torch, torch.nn as nn, torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import gradio as gr

try:
    from huggingface_hub import hf_hub_download
except ImportError:
    hf_hub_download = None

# Your classes
from categories import CATEGORIES
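# Note: CATEGORIES is assumed to be a list (or tuple) of class-name strings,
# index-aligned with the model's output logits; len(CATEGORIES) also sizes
# the classifier head built in build_backbone() below.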
# ----- CONFIG -----
MODEL_REPO = "MichaelMwb/butterfly-vgg16"  # HF model repo (weights live here)
MODEL_FILE = "model.pth"
BACKBONE = "vgg16"             # change to "resnet18" if needed
DISPLAY_TEMPERATURE = 4.0      # try 4–10; higher = less "peaky" probs
# Preprocess
T = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
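# The 256/224 resize-and-crop and the mean/std values above are the standard
# ImageNet evaluation statistics expected by the pretrained torchvision backbones.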
# ----- Model loading helpers -----
def build_backbone(num_classes: int):
    if BACKBONE == "vgg16":
        m = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        m.classifier[6] = nn.Linear(4096, num_classes)
        return m
    elif BACKBONE == "resnet18":
        m = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    raise ValueError(f"Unsupported backbone: {BACKBONE}")
def _torch_load(path):
    # torch>=2.6 defaults weights_only=True, which breaks full-module checkpoints;
    # pass weights_only=False explicitly, and fall back for older torch versions
    # that don't accept the keyword.
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        return torch.load(path, map_location="cpu")
def load_model_local():
    obj = _torch_load(MODEL_FILE)
    if isinstance(obj, nn.Module):
        return obj.eval()
    m = build_backbone(len(CATEGORIES))
    state = obj.get("state_dict", obj)
    m.load_state_dict(state, strict=False)
    return m.eval()
def load_model_from_hub():
    if hf_hub_download is None:
        raise RuntimeError("Install huggingface_hub or set MODEL_REPO='' to load the local file.")
    ckpt_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
    obj = _torch_load(ckpt_path)
    if isinstance(obj, nn.Module):
        return obj.eval()
    m = build_backbone(len(CATEGORIES))
    state = obj.get("state_dict", obj)
    m.load_state_dict(state, strict=False)
    return m.eval()
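# Both loaders accept either a fully pickled nn.Module or a checkpoint dict
# (a raw state_dict, or a dict holding one under "state_dict"); strict=False
# tolerates missing/unexpected keys instead of raising.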
# Global model + device
_model = load_model_from_hub() if MODEL_REPO else load_model_local()

# Start in AUTO mode: GPU if available, else CPU
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_model.to(_device)
def set_device(choice: str):
    """Auto = cuda if available else cpu; CPU = cpu."""
    global _device, _model
    if choice == "Auto":
        new_dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        new_dev = torch.device("cpu")
    _device = new_dev
    _model.to(_device)
    suffix = " (auto-selected)" if choice == "Auto" else ""
    return f"Using device: **{_device.type}**{suffix}"
def predict(img: Image.Image, top_k: int = 5):
    if img is None:
        return {}
    x = T(img.convert("RGB")).unsqueeze(0).to(_device)
    with torch.no_grad():
        logits = _model(x)
    # soften logits so softmax isn't 100/0
    probs = torch.softmax(logits / DISPLAY_TEMPERATURE, dim=1)[0].cpu()
    k = min(int(top_k), len(CATEGORIES))  # slider values may arrive as floats
    confs, idxs = torch.topk(probs, k=k)
    return {CATEGORIES[int(i)]: float(confs[j]) for j, i in enumerate(idxs)}
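# Quick local sanity check (illustrative only; "test.jpg" is a hypothetical path):
#   preds = predict(Image.open("test.jpg"), top_k=3)
#   print(preds)  # dict mapping species name -> temperature-softened confidence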
# ---------- UI ----------
initial_status = f"Using device: **{_device.type}** (auto-selected)"

BEST_RESULTS_MD = """
**Best results**

- Single butterfly, centered, sharp focus
- Plain/clean background; avoid heavy clutter
- Show full wings; crop out large borders
- Good lighting; avoid harsh shadows or filters
- JPG/PNG around **800–1200 px** on the short side
"""

DISCLAIMER_MD = """
> **Disclaimer:** This demo uses a research model trained on a fixed dataset.
> Predictions may be wrong; treat results as suggestions, not definitive IDs.
> **Privacy:** Images are processed in memory and **not stored**.
"""

HOW_IT_WORKS_MD = """
### How it works
1. Your image is resized and normalized (ImageNet stats).
2. A pretrained **{bb}** backbone with a small classifier head runs inference.
3. We show the **Top-k** species with confidence scores.

### Model & data
- Backbone: **{bb}** (transfer learning)
- Classes: **{n}** species
- Inference: CPU by default; GPU if available (Auto)
""".format(bb=BACKBONE.upper(), n=len(CATEGORIES))
with gr.Blocks(title="🦋 Butterfly Classifier", theme="gradio/soft") as demo:
    gr.Markdown("# 🦋 Butterfly Classifier\nUpload a photo to see the top-k predictions with confidences.")
    with gr.Row():
        with gr.Column(scale=1):
            # Device: only Auto and CPU
            device_choice = gr.Radio(["Auto", "CPU"], value="Auto", label="Device")
            device_status = gr.Markdown(value=initial_status)
            # Short instructions block
            gr.Markdown(BEST_RESULTS_MD)
            # Uploader (upload only; webcam removed)
            img = gr.Image(type="pil", sources=["upload"], label="Upload image")
            # Disclaimer directly under uploader
            gr.Markdown(DISCLAIMER_MD)
            # Controls
            k = gr.Slider(1, 5, value=5, step=1, label="Top-k")
            btn = gr.Button("Predict", variant="primary")
        with gr.Column(scale=1):
            out = gr.Label(num_top_classes=5, label="Predictions")
            gr.Markdown(HOW_IT_WORKS_MD)

    device_choice.change(fn=set_device, inputs=device_choice, outputs=device_status)
    btn.click(predict, [img, k], out)

if __name__ == "__main__":
    demo.launch()