ImageNet / app.py
Shilpaj's picture
Feat: Share app
f72fe80
#!/usr/bin/env python
"""
Application for ResNet50 trained on ImageNet-1K.
"""
# Third Party Imports (Gradio UI)
import gradio as gr
# Third Party Imports
import torch
from torchvision import models
# Local Imports
from inference import inference
def load_model(model_path: str):
    """
    Build a ResNet50 and populate it with custom weights from a checkpoint.

    The model is created without pretrained weights, moved to CUDA when it is
    available (CPU otherwise), loaded from the checkpoint and put in eval mode.

    :param model_path: Path to the checkpoint file. The checkpoint may be a
        dict holding the weights under ``'model_state_dict'`` or the state
        dict itself.
    :return: The ResNet50 model, in evaluation mode, on the selected device.
    :raises Exception: Re-raises whatever ``load_state_dict`` raises when the
        checkpoint keys or shapes do not match the model.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Initialize a fresh model without pretrained weights
    model = models.resnet50(weights=None)
    model = model.to(device)

    # Load custom weights
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    # Accept both a wrapped checkpoint and a bare state dict
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        raw_state_dict = checkpoint['model_state_dict']
    else:
        raw_state_dict = checkpoint

    # Debug: Print original state dict keys
    print("\nOriginal state dict keys:", list(raw_state_dict.keys())[:5])

    # Remove the leading 'model.' prefix from state dict keys. Only the prefix
    # is stripped so that a 'model.' substring elsewhere in a key is preserved.
    prefix = 'model.'
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in raw_state_dict.items()
    }

    # Debug: Print modified state dict keys
    print("Modified state dict keys:", list(new_state_dict.keys())[:5])
    print("Model state dict keys:", list(model.state_dict().keys())[:5])

    # Load the modified state dict
    try:
        model.load_state_dict(new_state_dict)
        print("Successfully loaded model weights")
    except Exception as e:
        print(f"Error loading state dict: {str(e)}")
        # Bare raise keeps the original traceback intact
        raise

    model.eval()
    return model
def load_classes():
    """
    Return the ImageNet category names.

    The names are read from the ``"categories"`` entry of the metadata
    attached to torchvision's ``ResNet50_Weights.IMAGENET1K_V1`` weights, and
    the number of names found is printed.
    """
    categories = models.ResNet50_Weights.IMAGENET1K_V1.meta["categories"]
    print(f"Loaded {len(categories)} classes")
    return categories
def inference_wrapper(image, alpha, top_k, target_layer):
    """
    Run ``inference`` and turn every failure into a displayable result.

    The module-level ``model`` and ``classes`` globals are forwarded to
    ``inference``. On success its return value is passed through unchanged
    (presumably a ``(labels, image)`` pair matching the two UI outputs --
    confirm against ``inference``). On failure a ``({message: 1.0}, None)``
    pair is returned instead of raising.
    """
    # Nothing to classify: report the generic error label right away
    if image is None:
        return {"Error": 1.0}, None

    try:
        results = inference(
            image, alpha, top_k, target_layer, model=model, classes=classes
        )
    except RuntimeError as e:
        error_msg = str(e)
        print(f"Error in inference: {error_msg}")
        # Out-of-memory failures get a friendlier label than other runtime errors
        if "out of memory" in error_msg.lower():
            label = "GPU Memory Error - Please try again"
        else:
            label = "Runtime Error: " + error_msg
        return {label: 1.0}, None
    except Exception as e:
        error_msg = str(e)
        print(f"Error in inference: {error_msg}")
        return {"Error: " + error_msg: 1.0}, None

    # A missing result is reported the same way as a missing image
    return ({"Error": 1.0}, None) if results is None else results
def main():
    """
    Start the Gradio application.

    Loads the model and the class names into the module-level ``model`` and
    ``classes`` globals (read by ``inference_wrapper``), builds the UI and
    launches the server on 0.0.0.0:7860 with ``share=True``. An exception
    raised during startup is printed rather than propagated.
    """
    global model, classes

    intro_md = """
# ResNet50 trained on ImageNet-1K
A large-scale image classification dataset with 1.2 million training images across 1,000 object categories.
"""
    usage_md = """
View model predictions and visualize where the model is looking using GradCAM.
## Steps to use:
1. Upload an image or select one from the examples below
2. Adjust the sliders (optional):
- Activation Map Transparency: Controls the blend between original image and activation map
- Number of Top Predictions: How many top class predictions to show
- Target Layer Number: Which network layer to visualize (deeper layers show higher-level features)
3. Click "Generate GradCAM" to run the model
4. View the results:
- Left: Original uploaded image
- Right: Model predictions and GradCAM visualization showing where the model focused
"""
    example_images = (
        "assets/examples/cat.jpg",
        "assets/examples/frog.jpg",
        "assets/examples/bird.jpg",
        "assets/examples/car.jpg",
        "assets/examples/truck.jpg",
        "assets/examples/horse.jpg",
        "assets/examples/plane.jpg",
        "assets/examples/ship.png",
    )

    try:
        print(f"Gradio version: {gr.__version__}")

        # Load the model at startup
        model = load_model("resnet50_imagenet1k.pth")
        classes = load_classes()

        with gr.Blocks() as demo:
            gr.Markdown(intro_md)

            with gr.Tab("Predictions & GradCAM"):
                gr.Markdown(usage_md)

                # Input image next to the prediction / GradCAM outputs
                with gr.Row():
                    img_input = gr.Image(
                        label="Input Image", type="numpy", height=224, width=224
                    )
                    with gr.Column():
                        label_output = gr.Label(label="Predictions")
                        gradcam_output = gr.Image(
                            label="GradCAM Output", height=224, width=224
                        )

                # Tunable parameters forwarded to the inference call
                with gr.Row():
                    alpha_slider = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.5,
                        step=0.1,
                        label="Activation Map Transparency",
                    )
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        label="Number of Top Predictions",
                    )
                    target_layer_slider = gr.Slider(
                        minimum=1,
                        maximum=6,
                        value=4,
                        step=1,
                        label="Target Layer Number",
                    )

                gradcam_button = gr.Button("Generate GradCAM")

                # The button and the examples share the same inputs and outputs
                controls = [
                    img_input,
                    alpha_slider,
                    top_k_slider,
                    target_layer_slider,
                ]
                displays = [label_output, gradcam_output]

                # Set up the click event
                gradcam_button.click(
                    fn=inference_wrapper, inputs=controls, outputs=displays
                )

                # Examples section for Gradio 5.x: every row uses the slider
                # defaults (alpha 0.5, top-3 predictions, target layer 4)
                examples = [[path, 0.5, 3, 4] for path in example_images]
                gr.Examples(
                    examples=examples,
                    inputs=controls,
                    outputs=displays,
                    fn=inference_wrapper,
                    cache_examples=False,  # Disable caching to prevent memory issues
                    label="Click on any example to run GradCAM",
                )

        # Queue configuration.
        # NOTE(review): max_size appears to cap how many requests may wait in
        # the queue rather than how many run concurrently -- confirm against
        # the Gradio queue() documentation.
        demo.queue(max_size=1)

        # Launch on all interfaces with a public share link
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True
        )
    except Exception as e:
        print(f"Error during startup: {str(e)}")
        # NOTE(review): assumed to belong to the failure path -- confirm
        # against the original indentation.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Run the application only when this file is executed as a script
if __name__ == "__main__":
    main()