# app.py
import torch, torch.nn as nn, torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import gradio as gr
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    hf_hub_download = None
# Your classes
from categories import CATEGORIES
# ----- CONFIG -----
MODEL_REPO = "MichaelMwb/butterfly-vgg16" # HF model repo (weights live here)
MODEL_FILE = "model.pth"
BACKBONE = "vgg16" # change to "resnet18" if needed
DISPLAY_TEMPERATURE = 4.0 # try 4–10; higher = less “peaky” probs
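# Note on temperature scaling (illustrative): predict() divides the logits by
# DISPLAY_TEMPERATURE before softmax, i.e. softmax(z / T). With T > 1 the
# distribution is flattened so confidences don't display as 100%/0%;
# T = 1 recovers the raw softmax. The 4-10 range above is purely a display choice.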
# Preprocess
T = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
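# The mean/std above are the standard ImageNet statistics that torchvision's
# pretrained VGG16/ResNet18 weights expect; keep them in sync with the backbone.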
# ----- Model loading helpers -----
def build_backbone(num_classes: int):
    if BACKBONE == "vgg16":
        m = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        m.classifier[6] = nn.Linear(4096, num_classes)
        return m
    elif BACKBONE == "resnet18":
        m = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    raise ValueError(f"Unsupported backbone: {BACKBONE}")
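# Adding another backbone would just extend the chain above; e.g. a hypothetical
# "resnet50" branch (not wired up here) would mirror the resnet18 case:
#     m = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
#     m.fc = nn.Linear(m.fc.in_features, num_classes)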
def _torch_load(path):
    # torch>=2.6 defaults to weights_only=True, so pass False explicitly;
    # the except branch covers older torch versions that lack the kwarg.
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        return torch.load(path, map_location="cpu")
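# Both loaders below accept either a pickled nn.Module or a plain state_dict
# (optionally wrapped as {"state_dict": ...}); in the latter case the weights
# are loaded into a freshly built backbone with strict=False.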
def load_model_local():
    obj = _torch_load(MODEL_FILE)
    if isinstance(obj, nn.Module):
        return obj.eval()
    m = build_backbone(len(CATEGORIES))
    state = obj.get("state_dict", obj)
    m.load_state_dict(state, strict=False)
    return m.eval()
def load_model_from_hub():
    if hf_hub_download is None:
        raise RuntimeError("Install huggingface_hub or set MODEL_REPO='' to load local file.")
    ckpt_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
    obj = _torch_load(ckpt_path)
    if isinstance(obj, nn.Module):
        return obj.eval()
    m = build_backbone(len(CATEGORIES))
    state = obj.get("state_dict", obj)
    m.load_state_dict(state, strict=False)
    return m.eval()
# Global model + device
_model = load_model_from_hub() if MODEL_REPO else load_model_local()
# Start in AUTO mode: GPU if available, else CPU
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_model.to(_device)
def set_device(choice: str):
    """Auto = cuda if available else cpu; CPU = cpu."""
    global _device, _model
    if choice == "Auto":
        new_dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        new_dev = torch.device("cpu")
    _device = new_dev
    _model.to(_device)
    suffix = " (auto-selected)" if choice == "Auto" else ""
    return f"Using device: **{_device.type}**{suffix}"
def predict(img: Image.Image, top_k: int = 5):
    if img is None:
        return {}
    x = T(img.convert("RGB")).unsqueeze(0).to(_device)
    with torch.no_grad():
        logits = _model(x)
    # soften logits so softmax isn’t 100/0
    probs = torch.softmax(logits / DISPLAY_TEMPERATURE, dim=1)[0].cpu()
    k = min(int(top_k), len(CATEGORIES))  # slider values may arrive as floats
    confs, idxs = torch.topk(probs, k=k)
    return {CATEGORIES[i]: float(confs[j]) for j, i in enumerate(idxs)}
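# The returned dict is the label->confidence mapping gr.Label expects,
# e.g. (hypothetical class names): {"Monarch": 0.41, "Viceroy": 0.22, ...}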
# ---------- UI ----------
initial_status = f"Using device: **{_device.type}** (auto-selected)"
BEST_RESULTS_MD = """
**Best results**
- Single butterfly, centered, sharp focus
- Plain/clean background; avoid heavy clutter
- Show full wings; crop out large borders
- Good lighting; avoid harsh shadows or filters
- JPG/PNG around **800–1200 px** on the short side
"""
DISCLAIMER_MD = """
> **Disclaimer:** This demo uses a research model trained on a fixed dataset.
> Predictions may be wrong—treat results as suggestions, not definitive IDs.
> **Privacy:** Images are processed in memory and **not stored**.
"""
HOW_IT_WORKS_MD = """
### How it works
1. Your image is resized and normalized (ImageNet stats).
2. A pretrained **{bb}** backbone with a small classifier head runs inference.
3. We show the **Top-k** species with confidence scores.
### Model & data
- Backbone: **{bb}** (transfer learning)
- Classes: **{n}** species
- Inference: CPU by default; GPU if available (Auto)
""".format(bb=BACKBONE.upper(), n=len(CATEGORIES))
with gr.Blocks(title="🦋 Butterfly Classifier", theme="gradio/soft") as demo:
    gr.Markdown("# 🦋 Butterfly Classifier\nUpload a photo to see the top-k predictions with confidences.")
    with gr.Row():
        with gr.Column(scale=1):
            # Device: only Auto and CPU
            device_choice = gr.Radio(["Auto", "CPU"], value="Auto", label="Device")
            device_status = gr.Markdown(value=initial_status)
            # Short instructions block
            gr.Markdown(BEST_RESULTS_MD)
            # Uploader (upload only — webcam removed)
            img = gr.Image(type="pil", sources=["upload"], label="Upload image")
            # Disclaimer directly under uploader
            gr.Markdown(DISCLAIMER_MD)
            # Controls
            k = gr.Slider(1, 5, value=5, step=1, label="Top-k")
            btn = gr.Button("Predict", variant="primary")
        with gr.Column(scale=1):
            out = gr.Label(num_top_classes=5, label="Predictions")
            gr.Markdown(HOW_IT_WORKS_MD)

    device_choice.change(fn=set_device, inputs=device_choice, outputs=device_status)
    btn.click(predict, [img, k], out)
if __name__ == "__main__":
    demo.launch()