# app.py
import torch, torch.nn as nn, torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import gradio as gr
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    hf_hub_download = None
# Your classes
from categories import CATEGORIES
# ----- CONFIG -----
MODEL_REPO = "MichaelMwb/butterfly-vgg16" # HF model repo (weights live here)
MODEL_FILE = "model.pth"
BACKBONE = "vgg16" # change to "resnet18" if needed
DISPLAY_TEMPERATURE = 4.0 # try 4–10; higher = less “peaky” probs
# Preprocess
T = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
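# Note: Resize(256) + CenterCrop(224) and the mean/std above are the standard
# ImageNet preprocessing expected by torchvision's pretrained weights.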
# ----- Model loading helpers -----
def build_backbone(num_classes: int):
    if BACKBONE == "vgg16":
        m = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        m.classifier[6] = nn.Linear(4096, num_classes)
        return m
    elif BACKBONE == "resnet18":
        m = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    raise ValueError(f"Unsupported backbone: {BACKBONE}")
def _torch_load(path):
    # torch>=2.6 defaults weights_only=True, so pass False explicitly;
    # fall back for older versions that don't accept the keyword.
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        return torch.load(path, map_location="cpu")
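# A checkpoint may be a pickled nn.Module or a dict (optionally wrapped in a
# "state_dict" key); both loaders below handle either format.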
def load_model_local():
    obj = _torch_load(MODEL_FILE)
    if isinstance(obj, nn.Module):
        return obj.eval()
    m = build_backbone(len(CATEGORIES))
    state = obj.get("state_dict", obj)
    m.load_state_dict(state, strict=False)
    return m.eval()
def load_model_from_hub():
    if hf_hub_download is None:
        raise RuntimeError("Install huggingface_hub or set MODEL_REPO='' to load the local file.")
    ckpt_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
    obj = _torch_load(ckpt_path)
    if isinstance(obj, nn.Module):
        return obj.eval()
    m = build_backbone(len(CATEGORIES))
    state = obj.get("state_dict", obj)
    m.load_state_dict(state, strict=False)
    return m.eval()
# Global model + device
_model = load_model_from_hub() if MODEL_REPO else load_model_local()
# Start in AUTO mode: GPU if available, else CPU
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_model.to(_device)
def set_device(choice: str):
    """Auto = cuda if available else cpu; CPU = cpu."""
    global _device, _model
    if choice == "Auto":
        new_dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        new_dev = torch.device("cpu")
    _device = new_dev
    _model.to(_device)
    suffix = " (auto-selected)" if choice == "Auto" else ""
    return f"Using device: **{_device.type}**{suffix}"
def predict(img: Image.Image, top_k: int = 5):
    if img is None:
        return {}
    x = T(img.convert("RGB")).unsqueeze(0).to(_device)
    with torch.no_grad():
        logits = _model(x)
    # Soften logits so the softmax isn't a hard 100/0 split in the UI.
    probs = torch.softmax(logits / DISPLAY_TEMPERATURE, dim=1)[0].cpu()
    k = min(int(top_k), len(CATEGORIES))
    confs, idxs = torch.topk(probs, k=k)
    # gr.Label expects a {class_name: confidence} dict.
    return {CATEGORIES[int(i)]: float(confs[j]) for j, i in enumerate(idxs)}
# ---------- UI ----------
initial_status = f"Using device: **{_device.type}** (auto-selected)"
BEST_RESULTS_MD = """
**Best results**
- Single butterfly, centered, sharp focus
- Plain/clean background; avoid heavy clutter
- Show full wings; crop out large borders
- Good lighting; avoid harsh shadows or filters
- JPG/PNG around **800–1200 px** on the short side
"""
DISCLAIMER_MD = """
> **Disclaimer:** This demo uses a research model trained on a fixed dataset.
> Predictions may be wrong—treat results as suggestions, not definitive IDs.
> **Privacy:** Images are processed in memory and **not stored**.
"""
HOW_IT_WORKS_MD = """
### How it works
1. Your image is resized and normalized (ImageNet stats).
2. A pretrained **{bb}** backbone with a small classifier head runs inference.
3. We show the **Top-k** species with confidence scores.
### Model & data
- Backbone: **{bb}** (transfer learning)
- Classes: **{n}** species
- Inference: CPU by default; GPU if available (Auto)
""".format(bb=BACKBONE.upper(), n=len(CATEGORIES))
with gr.Blocks(title="🦋 Butterfly Classifier", theme="gradio/soft") as demo:
    gr.Markdown("# 🦋 Butterfly Classifier\nUpload a photo to see the top-k predictions with confidences.")
    with gr.Row():
        with gr.Column(scale=1):
            # Device: only Auto and CPU
            device_choice = gr.Radio(["Auto", "CPU"], value="Auto", label="Device")
            device_status = gr.Markdown(value=initial_status)
            # Short instructions block
            gr.Markdown(BEST_RESULTS_MD)
            # Uploader (upload only; webcam removed)
            img = gr.Image(type="pil", sources=["upload"], label="Upload image")
            # Disclaimer directly under uploader
            gr.Markdown(DISCLAIMER_MD)
            # Controls
            k = gr.Slider(1, 5, value=5, step=1, label="Top-k")
            btn = gr.Button("Predict", variant="primary")
        with gr.Column(scale=1):
            out = gr.Label(num_top_classes=5, label="Predictions")
            gr.Markdown(HOW_IT_WORKS_MD)
    device_choice.change(fn=set_device, inputs=device_choice, outputs=device_status)
    btn.click(predict, [img, k], out)
if __name__ == "__main__":
    demo.launch()