import gradio as gr
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import ResNet34_Weights
from PIL import Image
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
import os
import random
import glob

# Import LoRA code
from model import LoRALayer, apply_lora_to_model

# Load model
print("Loading model...")
model = models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 2)
model = apply_lora_to_model(model, rank=8)

# Load trained weights (from local Space files)
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()
print("Model loaded successfully!")

# Preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Class names
class_names = ['Non-Smoker', 'Smoker']


def predict(image):
    """
    Predict whether the person in the image is smoking.

    Args:
        image: PIL Image

    Returns:
        dict: Prediction probabilities for each class
    """
    if image is None:
        return None

    # Preprocess
    img_tensor = transform(image).unsqueeze(0)

    # Predict
    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = torch.softmax(outputs, dim=1)[0]

    # Format results
    results = {
        class_names[i]: float(probabilities[i])
        for i in range(len(class_names))
    }
    return results


# Get all example images
example_images = glob.glob("All/*")
examples = [[img] for img in example_images[:12]]  # Take the first 12 images


# Function to get a random sample
def get_random_sample():
    """Load a random example image"""
    random_image_path = random.choice(example_images)
    return Image.open(random_image_path)


# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🚬 Smoker Detection

        Upload an image or try a random sample to detect if a person is smoking.

        This model uses **ResNet34 with LoRA fine-tuning** (only 2.14% of parameters trained)
        and achieves **89.73% test accuracy**.

        **Model:** [notrito/smoker-detection](https://huggingface.co/notrito/smoker-detection)
        """
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            with gr.Row():
                predict_btn = gr.Button("🔍 Predict", variant="primary")
                random_btn = gr.Button("🎲 Random Sample", variant="secondary")

        with gr.Column():
            output_label = gr.Label(num_top_classes=2, label="Prediction")

    gr.Markdown("### 📸 Try these examples:")
    gr.Examples(
        examples=examples,
        inputs=input_image,
        outputs=output_label,
        fn=predict,
        cache_examples=True
    )

    gr.Markdown(
        """
        ---

        ### About this model
        - **Architecture:** ResNet34 + LoRA adapters (rank=8)
        - **Training:** Fine-tuned on 1,120 images
        - **Performance:** 89.73% test accuracy, 89.96% F1-score
        - **Efficiency:** Only 465K trainable parameters (2.14% of model)

        ### How it works
        LoRA (Low-Rank Adaptation) freezes the pretrained ImageNet weights and adds small
        trainable matrices to specific layers. This prevents overfitting on small datasets
        while maintaining the model's powerful feature extraction capabilities.
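
        For illustration, a LoRA adapter around a single linear layer looks roughly like the
        simplified, hypothetical sketch below (`LoRAAdapter` and its defaults are illustrative
        names, not the exact implementation in this Space's `model.py`):

        ```python
        import torch
        import torch.nn as nn

        class LoRAAdapter(nn.Module):
            # Illustrative only: frozen base layer plus a trainable low-rank update.
            def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
                super().__init__()
                self.base = base
                for p in self.base.parameters():
                    p.requires_grad = False  # keep pretrained weights frozen
                # Low-rank factors: only these (in_features*rank + rank*out_features) weights train
                self.A = nn.Parameter(torch.randn(base.in_features, rank) * 0.01)
                self.B = nn.Parameter(torch.zeros(rank, base.out_features))
                self.scale = alpha / rank

            def forward(self, x):
                # Original output plus the scaled low-rank correction
                return self.base(x) + (x @ self.A @ self.B) * self.scale
        ```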

        ### Limitations
        - Trained on a limited dataset (1,120 images)
        - Best for frontal/profile views with visible cigarettes
        - May not generalize to all smoking scenarios

        ### Links
        - [Model Card](https://huggingface.co/notrito/smoker-detection)
        - [Training Notebook](https://www.kaggle.com/code/notrito/smoker-detection-with-lora)

        **Author:** Noel Triguero
        """
    )

    # Connect buttons
    predict_btn.click(fn=predict, inputs=input_image, outputs=output_label)
    random_btn.click(fn=get_random_sample, inputs=None, outputs=input_image)

if __name__ == "__main__":
    demo.launch()
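
# Usage note (assumptions: this file is the Space's app.py, and best_model.pth plus the
# All/ example folder sit next to it):
#   pip install gradio torch torchvision pillow huggingface_hub
#   python app.py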