#!/usr/bin/env python
"""
Application for ResNet50 trained on ImageNet-1K.
"""
# Third Party Imports
import gradio as gr
import torch
from torchvision import models

# Local Imports
from inference import inference
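
# Shared state: set by main() at startup, read by inference_wrapper()
model = None
classes = None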


def load_model(model_path: str):
    """
    Load the model.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Initialize a fresh model without pretrained weights
    model = models.resnet50(weights=None)
    model = model.to(device)

    # Load custom weights
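    # map_location remaps GPU-saved tensors so the checkpoint also loads on CPU-only hosts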
    state_dict = torch.load(model_path, map_location=device)
    
    # Debug: Print original state dict keys
    print("\nOriginal state dict keys:", list(state_dict['model_state_dict'].keys())[:5])
    
    # Remove the 'model.' prefix from state dict keys
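    # The prefix is assumed to come from the checkpoint being saved from a wrapper
    # module whose ResNet attribute was named `model` (e.g. a Lightning-style module)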
    new_state_dict = {}
    for key, value in state_dict['model_state_dict'].items():
        new_key = key.removeprefix('model.')  # strip only the leading prefix (Python 3.9+)
        new_state_dict[new_key] = value
    
    # Debug: Print modified state dict keys
    print("Modified state dict keys:", list(new_state_dict.keys())[:5])
    print("Model state dict keys:", list(model.state_dict().keys())[:5])
    
    # Load the modified state dict
    try:
        model.load_state_dict(new_state_dict)
        print("Successfully loaded model weights")
    except Exception as e:
        print(f"Error loading state dict: {str(e)}")
        raise e
    
    model.eval()
    return model


def load_classes():
    """
    Load the ImageNet-1K class names from torchvision's weight metadata.
    """
    weights = models.ResNet50_Weights.IMAGENET1K_V1
    classes = weights.meta["categories"]
    print(f"Loaded {len(classes)} classes")
    return classes


def inference_wrapper(image, alpha, top_k, target_layer):
    """
    Wrapper function for inference with error handling
    """
    try:
        if image is None:
            return {"Error": 1.0}, None
            
        results = inference(
            image, 
            alpha, 
            top_k, 
            target_layer, 
            model=model, 
            classes=classes
        )
        
        if results is None:
            return {"Error": 1.0}, None
            
        return results
        
    except RuntimeError as e:
        error_msg = str(e)
        print(f"Error in inference: {error_msg}")
        
        if "out of memory" in error_msg.lower():
            return {"GPU Memory Error - Please try again": 1.0}, None
        return {"Runtime Error: " + error_msg: 1.0}, None
        
    except Exception as e:
        error_msg = str(e)
        print(f"Error in inference: {error_msg}")
        return {"Error: " + error_msg: 1.0}, None


def main():
    """
    Main function for the application.
    """
    global model, classes
    
    try:
        print(f"Gradio version: {gr.__version__}")
        
        # Load the model at startup
        model = load_model("resnet50_imagenet1k.pth")
        classes = load_classes()

        with gr.Blocks() as demo:
            gr.Markdown(
                """
                # ResNet50 trained on ImageNet-1K
                ImageNet-1K is a large-scale image-classification dataset with about 1.28 million training images across 1,000 object categories.
                """
            )

            with gr.Tab("Predictions & GradCAM"):
                gr.Markdown(
                    """
                    View model predictions and visualize where the model is looking using GradCAM.
                    
                    ## Steps to use:
                    1. Upload an image or select one from the examples below
                    2. Adjust the sliders (optional):
                       - Activation Map Transparency: Controls the blend between original image and activation map
                       - Number of Top Predictions: How many top class predictions to show
                       - Target Layer Number: Which network layer to visualize (deeper layers show higher-level features)
                    3. Click "Generate GradCAM" to run the model
                    4. View the results:
                       - Left: Original uploaded image
                       - Right: Model predictions and GradCAM visualization showing where the model focused
                    """
                )
                
                # Define inputs
                with gr.Row():
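                    # type="numpy" delivers the upload to inference() as an H x W x C numpy array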
                    img_input = gr.Image(
                        label="Input Image",
                        type="numpy",
                        height=224,
                        width=224
                    )
                    with gr.Column():
                        label_output = gr.Label(label="Predictions")
                        gradcam_output = gr.Image(
                            label="GradCAM Output",
                            height=224,
                            width=224
                        )

                with gr.Row():
                    alpha_slider = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.5,
                        step=0.1,
                        label="Activation Map Transparency"
                    )
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        label="Number of Top Predictions"
                    )
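                    # Passed through to inference(), which is assumed to map the index onto a ResNet50 stage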
                    target_layer_slider = gr.Slider(
                        minimum=1,
                        maximum=6,
                        value=4,
                        step=1,
                        label="Target Layer Number"
                    )

                gradcam_button = gr.Button("Generate GradCAM")
                
                # Set up the click event
                gradcam_button.click(
                    fn=inference_wrapper,
                    inputs=[
                        img_input,
                        alpha_slider,
                        top_k_slider,
                        target_layer_slider
                    ],
                    outputs=[
                        label_output,
                        gradcam_output
                    ]
                )

                # Examples (Gradio 5.x): each row is [image path, alpha, top_k, target_layer]
                examples = [
                    ["assets/examples/cat.jpg", 0.5, 3, 4],
                    ["assets/examples/frog.jpg", 0.5, 3, 4],
                    ["assets/examples/bird.jpg", 0.5, 3, 4],
                    ["assets/examples/car.jpg", 0.5, 3, 4],
                    ["assets/examples/truck.jpg", 0.5, 3, 4],
                    ["assets/examples/horse.jpg", 0.5, 3, 4],
                    ["assets/examples/plane.jpg", 0.5, 3, 4],
                    ["assets/examples/ship.png", 0.5, 3, 4],
                ]

                gr.Examples(
                    examples=examples,
                    inputs=[
                        img_input,
                        alpha_slider,
                        top_k_slider,
                        target_layer_slider
                    ],
                    outputs=[
                        label_output,
                        gradcam_output
                    ],
                    fn=inference_wrapper,
                    run_on_click=True,  # run inference when an example is clicked; otherwise the click only fills the inputs
                    cache_examples=False,  # Disable caching to prevent memory issues
                    label="Click on any example to run GradCAM"
                )

        # Queue configuration: only allow one job at a time
        demo.queue(max_size=1)

        # Launch outside the Blocks context so the UI definition is complete;
        # settings chosen for minimal memory usage
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True
        )
            
    except Exception as e:
        print(f"Error during startup: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()