# app.py
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import gradio as gr

try:
    from huggingface_hub import hf_hub_download
except ImportError:
    hf_hub_download = None

# Your classes
from categories import CATEGORIES

# ----- CONFIG -----
MODEL_REPO = "MichaelMwb/butterfly-vgg16"  # HF model repo (weights live here); "" = load MODEL_FILE from disk
MODEL_FILE = "model.pth"
BACKBONE = "vgg16"  # change to "resnet18" if needed
DISPLAY_TEMPERATURE = 4.0  # try 4–10; higher = less “peaky” probs

# Preprocess: resize the short side to 256, center-crop 224x224, convert to a
# tensor and normalize with ImageNet mean/std.
T = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# ----- Model loading helpers -----
def build_backbone(num_classes: int):
    """Return an ImageNet-pretrained BACKBONE whose last layer emits ``num_classes`` logits.

    Raises:
        ValueError: if BACKBONE is neither "vgg16" nor "resnet18".
    """
    if BACKBONE == "vgg16":
        m = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        # classifier[6] is VGG16's final Linear(4096 -> 1000) layer.
        m.classifier[6] = nn.Linear(4096, num_classes)
        return m
    if BACKBONE == "resnet18":
        m = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    raise ValueError(f"Unsupported backbone: {BACKBONE}")


def _torch_load(path):
    """Load a checkpoint onto the CPU across torch versions.

    ``weights_only`` was added in torch 1.13 and defaults to True from 2.6,
    which would reject checkpoints that pickle a whole ``nn.Module``; it is
    therefore passed as False explicitly, with a fallback for older torch.

    SECURITY: ``weights_only=False`` unpickles arbitrary objects, which can
    execute code. Only point MODEL_REPO / MODEL_FILE at checkpoints you trust.
    """
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # torch < 1.13: no ``weights_only`` keyword.
        return torch.load(path, map_location="cpu")


def _extract_state_dict(obj):
    """Return the parameter mapping stored in a loaded checkpoint object.

    Accepts a bare state dict or a training checkpoint that nests it under
    "state_dict" or "model_state_dict". A "module." prefix on every key
    (added when saving from nn.DataParallel / DistributedDataParallel) is
    stripped so the keys match a plain model.
    """
    state = obj
    if isinstance(obj, dict):
        for key in ("state_dict", "model_state_dict"):
            if isinstance(obj.get(key), dict):
                state = obj[key]
                break
    prefix = "module."
    if isinstance(state, dict) and state and all(
        isinstance(k, str) and k.startswith(prefix) for k in state
    ):
        state = {k[len(prefix):]: v for k, v in state.items()}
    return state


def _model_from_checkpoint(obj):
    """Turn a loaded checkpoint (full nn.Module or state dict) into an eval-mode model."""
    if isinstance(obj, nn.Module):
        return obj.eval()
    m = build_backbone(len(CATEGORIES))
    # strict=False tolerates partial checkpoints, but a mismatch is reported:
    # unmatched layers keep their initial weights (e.g. an untrained head).
    incompatible = m.load_state_dict(_extract_state_dict(obj), strict=False)
    missing, unexpected = incompatible.missing_keys, incompatible.unexpected_keys
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint does not fully match the {BACKBONE} model "
            f"({len(missing)} missing, {len(unexpected)} unexpected keys; "
            f"first missing: {missing[:3]}, first unexpected: {unexpected[:3]}). "
            "Unmatched layers keep their initial weights, so predictions may be wrong.",
            stacklevel=2,
        )
    return m.eval()


def load_model_local():
    """Load the classifier from MODEL_FILE on local disk, in eval mode."""
    return _model_from_checkpoint(_torch_load(MODEL_FILE))


def load_model_from_hub():
    """Download MODEL_FILE from the MODEL_REPO Hugging Face repo and load it, in eval mode.

    Raises:
        RuntimeError: if huggingface_hub is not installed.
    """
    if hf_hub_download is None:
        raise RuntimeError("Install huggingface_hub or set MODEL_REPO='' to load local file.")
    ckpt_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
    return _model_from_checkpoint(_torch_load(ckpt_path))
# Global model + device
import threading

_model = load_model_from_hub() if MODEL_REPO else load_model_local()

# Start in AUTO mode: GPU if available, else CPU
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_model.to(_device)

# set_device() and predict() are separate Gradio event handlers and may run on
# different worker threads at once; without this lock a prediction could put
# its input on one device while the shared model is being moved to another.
_model_lock = threading.Lock()

# Largest Top-k offered in the UI (slider maximum and number of label rows).
_MAX_TOP_K = 5


def _device_status(device: torch.device, auto: bool) -> str:
    """Return the Markdown status line shown under the device selector."""
    suffix = " (auto-selected)" if auto else ""
    return f"Using device: **{device.type}**{suffix}"


def set_device(choice: str):
    """Move the shared model to the device chosen in the UI.

    "Auto" selects CUDA when available and CPU otherwise; any other choice
    (the UI only offers "CPU") selects CPU. The model is a module-level
    global, so this switches the device for every session, not only the caller's.

    Returns:
        Markdown status text for the ``device_status`` component.
    """
    global _device
    auto = choice == "Auto"
    new_dev = torch.device("cuda" if auto and torch.cuda.is_available() else "cpu")
    with _model_lock:
        _model.to(new_dev)
        _device = new_dev
    return _device_status(new_dev, auto)


def predict(img: Image.Image, top_k: int = _MAX_TOP_K):
    """Classify ``img`` and return its ``top_k`` highest-scoring classes.

    Args:
        img: Uploaded PIL image, or None when nothing has been uploaded.
        top_k: Number of classes to return. The Gradio slider may deliver it as
            a float, so it is cast to int and clamped to [1, len(CATEGORIES)].

    Returns:
        ``{class_name: score}`` for the top classes, or ``{}`` when there is no
        image. Scores are a softmax over logits divided by DISPLAY_TEMPERATURE,
        i.e. deliberately flattened for display, not calibrated probabilities.
    """
    if img is None:
        return {}
    k = max(1, min(int(top_k), len(CATEGORIES)))
    x = T(img.convert("RGB")).unsqueeze(0)
    # Hold the lock so the input and the model are guaranteed to share a device.
    with _model_lock, torch.no_grad():
        logits = _model(x.to(_device))
    # soften logits so softmax isn’t 100/0
    probs = torch.softmax(logits / DISPLAY_TEMPERATURE, dim=1)[0].cpu()
    confs, idxs = torch.topk(probs, k=k)
    return {CATEGORIES[i]: conf for conf, i in zip(confs.tolist(), idxs.tolist())}


# ---------- UI ----------
initial_status = _device_status(_device, auto=True)

BEST_RESULTS_MD = """
**Best results**
- Single butterfly, centered, sharp focus
- Plain/clean background; avoid heavy clutter
- Show full wings; crop out large borders
- Good lighting; avoid harsh shadows or filters
- JPG/PNG around **800–1200 px** on the short side
"""

DISCLAIMER_MD = """
> **Disclaimer:** This demo uses a research model trained on a fixed dataset.
> Predictions may be wrong—treat results as suggestions, not definitive IDs.

> **Privacy:** This app does not save your images; uploads may be kept briefly in the Gradio server's temporary file cache.
"""

HOW_IT_WORKS_MD = """
### How it works
1. Your image is resized, center-cropped to 224×224, and normalized (ImageNet stats).
2. A pretrained **{bb}** backbone with a small classifier head runs inference.
3. We show the **Top-k** species with confidence scores (softened for display, so read them as relative, not calibrated).

### Model & data
- Backbone: **{bb}** (transfer learning)
- Classes: **{n}** species
- Inference: GPU if available (Auto, the default); otherwise CPU
""".format(bb=BACKBONE.upper(), n=len(CATEGORIES))

with gr.Blocks(title="🦋 Butterfly Classifier", theme="gradio/soft") as demo:
    gr.Markdown("# 🦋 Butterfly Classifier\nUpload a photo to see the top-k predictions with confidences.")
    with gr.Row():
        with gr.Column(scale=1):
            # Device: only Auto and CPU
            device_choice = gr.Radio(["Auto", "CPU"], value="Auto", label="Device")
            device_status = gr.Markdown(value=initial_status)
            # Short instructions block
            gr.Markdown(BEST_RESULTS_MD)
            # Uploader (upload only — webcam removed)
            img = gr.Image(type="pil", sources=["upload"], label="Upload image")
            # Disclaimer directly under uploader
            gr.Markdown(DISCLAIMER_MD)
            # Controls
            k = gr.Slider(1, _MAX_TOP_K, value=_MAX_TOP_K, step=1, label="Top-k")
            btn = gr.Button("Predict", variant="primary")
        with gr.Column(scale=1):
            out = gr.Label(num_top_classes=_MAX_TOP_K, label="Predictions")
            gr.Markdown(HOW_IT_WORKS_MD)
    device_choice.change(fn=set_device, inputs=device_choice, outputs=device_status)
    btn.click(predict, [img, k], out)

if __name__ == "__main__":
    demo.launch()