File size: 4,588 Bytes
e786f60
 
 
 
 
 
 
 
0f00d59
 
 
e786f60
 
 
 
 
 
 
 
 
 
99808f1
 
e786f60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f00d59
 
e786f60
45ee7ed
 
e786f60
0f00d59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45ee7ed
 
 
 
 
 
 
 
0f00d59
 
 
45ee7ed
 
0f00d59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e786f60
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import ResNet34_Weights
from PIL import Image
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
import os
import random
import glob

# Import LoRA code
from model import LoRALayer, apply_lora_to_model

# Model setup
def _load_model():
    """
    Build the classifier and load the fine-tuned checkpoint.

    Creates a torchvision ResNet34 with the IMAGENET1K_V1 weights, replaces
    its final fully connected layer with a 2-output linear layer, passes the
    network through ``apply_lora_to_model`` (rank=8), loads ``best_model.pth``
    onto the CPU and switches the network to eval mode.

    Returns:
        The network returned by ``apply_lora_to_model``, in eval mode.
    """
    print("Loading model...")
    net = models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
    net.fc = nn.Linear(net.fc.in_features, 2)
    net = apply_lora_to_model(net, rank=8)

    # The checkpoint is read from the Space's local files.
    # NOTE(review): torch.load unpickles the file - only load trusted
    # checkpoints (consider weights_only=True where the torch version allows).
    state_dict = torch.load('best_model.pth', map_location='cpu')
    net.load_state_dict(state_dict)
    net.eval()
    print("Model loaded successfully!")
    return net


model = _load_model()

# Input pipeline: resize to 224x224, convert to a tensor, then normalize each
# of the three channels with a fixed mean/std (presumably the ImageNet
# statistics matching the pretrained weights - confirm against training code).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Output labels; position i names the i-th model output.
class_names = ['Non-Smoker', 'Smoker']

def predict(image):
    """
    Predict whether the person in an image is smoking.

    Args:
        image: PIL Image, or None when no image has been provided.

    Returns:
        dict: Maps each entry of ``class_names`` to its softmax probability
        as a float. Returns None when ``image`` is None.
    """
    if image is None:
        return None

    # `transform` normalizes with three-channel statistics, so grayscale,
    # palette or RGBA inputs are converted to RGB before preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess and add the leading batch dimension
    batch = transform(image).unsqueeze(0)

    # Inference only: no gradients needed
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.softmax(logits, dim=1)[0]

    # Pair each class name with its probability
    return {
        name: float(prob)
        for name, prob in zip(class_names, probabilities)
    }


# Get all example images.
# glob's result order depends on the file system, so the paths are sorted to
# show the same examples in the same order on every start. Directories are
# skipped because they cannot be opened as images.
example_images = sorted(
    path for path in glob.glob("All/*") if os.path.isfile(path)
)

# One single-element row per example; at most the first 12 images are used.
examples = [[path] for path in example_images[:12]]


# Function to get random sample
def get_random_sample():
    """
    Load a random example image.

    Returns:
        A PIL Image opened from a randomly chosen path in ``example_images``,
        or None when no example images are available.
    """
    # random.choice raises IndexError on an empty sequence; return None so the
    # button callback does not fail when no example files were found.
    if not example_images:
        return None
    return Image.open(random.choice(example_images))
    
# Create Gradio interface (Soft theme)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header
    gr.Markdown(
        """
        # 🚬 Smoker Detection
        
        Upload an image or try a random sample to detect if a person is smoking.
        
        This model uses **ResNet34 with LoRA fine-tuning** (only 2.14% of parameters trained) 
        and achieves **89.73% test accuracy**.
        
        **Model:** [notrito/smoker-detection](https://huggingface.co/notrito/smoker-detection)
        """
    )

    with gr.Row():
        # Left column: image input and action buttons
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")

            with gr.Row():
                predict_btn = gr.Button("🔍 Predict", variant="primary")
                random_btn = gr.Button("🎲 Random Sample", variant="secondary")

        # Right column: prediction output and clickable examples
        with gr.Column():
            output_label = gr.Label(num_top_classes=2, label="Prediction")

            gr.Markdown("### 📸 Try these examples:")
            gr.Examples(
                examples=examples,
                inputs=input_image,
                outputs=output_label,
                fn=predict,
                cache_examples=True
            )

    # Page footer: model details, limitations and links
    gr.Markdown(
        """
        ===================================================================================================
        
        ### About this model
        
        - **Architecture:** ResNet34 + LoRA adapters (rank=8)
        - **Training:** Fine-tuned on 1,120 images
        - **Performance:** 89.73% test accuracy, 89.96% F1-score
        - **Efficiency:** Only 465K trainable parameters (2.14% of model)
        
        ### How it works
        
        LoRA (Low-Rank Adaptation) freezes the pretrained ImageNet weights and adds small trainable 
        matrices to specific layers. This prevents overfitting on small datasets while maintaining 
        the model's powerful feature extraction capabilities.
        
        ### Limitations
        
        - Trained on limited dataset (1,120 images)
        - Best for frontal/profile views with visible cigarettes
        - May not generalize to all smoking scenarios
        
        ### Links
        
        - [Model Card](https://huggingface.co/notrito/smoker-detection)
        - [Training Notebook](https://www.kaggle.com/code/notrito/smoker-detection-with-lora)
        
        **Author:** Noel Triguero
        """
    )

    # Connect buttons: "Predict" classifies the current image; "Random Sample"
    # places a random example into the image input.
    predict_btn.click(fn=predict, inputs=input_image, outputs=output_label)
    random_btn.click(fn=get_random_sample, inputs=None, outputs=input_image)

if __name__ == "__main__":
    demo.launch()