# NOTE: the following header is residue from the web file viewer this
# listing was copied from; it is not part of the program.
#   Space status: Sleeping
#   File size: 4,588 Bytes
#   Revisions: e786f60 0f00d59 e786f60 99808f1 e786f60 0f00d59 e786f60 45ee7ed e786f60 0f00d59 45ee7ed 0f00d59 45ee7ed 0f00d59 e786f60
#   (line-number gutter 1-154 omitted)
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import ResNet34_Weights
from PIL import Image
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
import os
import random
import glob
# Import LoRA code
from model import LoRALayer, apply_lora_to_model
# Build the network: a ResNet34 backbone with a 2-output classification
# head, wrapped by apply_lora_to_model (rank 8), then restore the trained
# weights from the checkpoint file stored next to this script.
print("Loading model...")
model = models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
_head_inputs = model.fc.in_features
model.fc = nn.Linear(_head_inputs, 2)
model = apply_lora_to_model(model, rank=8)

# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# installed torch version; only load checkpoint files you trust.
_checkpoint = torch.load('best_model.pth', map_location='cpu')
model.load_state_dict(_checkpoint)
model.eval()  # switch the module to evaluation mode before serving
print("Model loaded successfully!")
# Preprocessing pipeline applied to every input image: resize to 224x224,
# convert to a tensor, then normalize the three channels.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics, matching the pretrained backbone - confirm against training.
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        _normalize,
    ]
)

# Class names, in the order of the model's output indices
class_names = ['Non-Smoker', 'Smoker']
def predict(image):
    """
    Predict if person in image is smoking.

    Args:
        image: PIL Image, or None when no image has been provided.

    Returns:
        dict: Prediction probabilities for each class (class name ->
        float), or None when ``image`` is None.
    """
    if image is None:
        return None

    # The normalization step is defined for exactly three channels, so
    # grayscale, palette and RGBA images are converted to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess and add a leading batch dimension of size 1
    img_tensor = transform(image).unsqueeze(0)

    # Predict without tracking gradients
    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.softmax(outputs, dim=1)[0]

    # Format results: pair each class name with its probability
    return {
        name: float(prob)
        for name, prob in zip(class_names, probabilities)
    }
# Get all example images. glob returns matches in arbitrary order, so the
# paths are sorted to keep the example gallery stable between runs, and
# sub-directories are skipped because they cannot be opened as images.
example_images = sorted(
    path for path in glob.glob("All/*") if os.path.isfile(path)
)
# At most the first 12 images, one single-input row per example
examples = [[img] for img in example_images[:12]]
# Function to get random sample
def get_random_sample():
    """Load a random example image.

    Returns:
        PIL Image opened from a randomly chosen path in ``example_images``,
        or None when there are no example images.
    """
    # random.choice raises IndexError on an empty sequence; return None
    # instead so a click does not fail when no examples are available.
    if not example_images:
        return None
    random_image_path = random.choice(example_images)
    return Image.open(random_image_path)
# Create Gradio interface using the built-in Soft theme
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header and short description
    gr.Markdown(
        """
        # 🚬 Smoker Detection

        Upload an image or try a random sample to detect if a person is smoking.

        This model uses **ResNet34 with LoRA fine-tuning** (only 2.14% of parameters trained)
        and achieves **89.73% test accuracy**.

        **Model:** [notrito/smoker-detection](https://huggingface.co/notrito/smoker-detection)
        """
    )

    # Left column: image input and action buttons; right column: prediction
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            with gr.Row():
                predict_btn = gr.Button("🔍 Predict", variant="primary")
                random_btn = gr.Button("🎲 Random Sample", variant="secondary")
        with gr.Column():
            output_label = gr.Label(num_top_classes=2, label="Prediction")

    # Clickable example gallery wired to the same prediction function
    gr.Markdown("### 📸 Try these examples:")
    gr.Examples(
        examples=examples,
        inputs=input_image,
        outputs=output_label,
        fn=predict,
        cache_examples=True,
    )

    # Footer with model details
    gr.Markdown(
        """
        ===================================================================================================

        ### About this model

        - **Architecture:** ResNet34 + LoRA adapters (rank=8)
        - **Training:** Fine-tuned on 1,120 images
        - **Performance:** 89.73% test accuracy, 89.96% F1-score
        - **Efficiency:** Only 465K trainable parameters (2.14% of model)

        ### How it works

        LoRA (Low-Rank Adaptation) freezes the pretrained ImageNet weights and adds small trainable
        matrices to specific layers. This prevents overfitting on small datasets while maintaining
        the model's powerful feature extraction capabilities.

        ### Limitations

        - Trained on limited dataset (1,120 images)
        - Best for frontal/profile views with visible cigarettes
        - May not generalize to all smoking scenarios

        ### Links

        - [Model Card](https://huggingface.co/notrito/smoker-detection)
        - [Training Notebook](https://www.kaggle.com/code/notrito/smoker-detection-with-lora)

        **Author:** Noel Triguero
        """
    )

    # Connect buttons (event listeners are registered inside the Blocks context)
    predict_btn.click(fn=predict, inputs=input_image, outputs=output_label)
    random_btn.click(fn=get_random_sample, inputs=None, outputs=input_image)
# Start the Gradio server only when this file is run as a script
if __name__ == "__main__":
    demo.launch()