# NOTE: "Spaces: Sleeping" was hosting-page status text captured with this file; it is not part of the program.
#!/usr/bin/env python | |
""" | |
Application for ResNet50 trained on ImageNet-1K. | |
""" | |
# Third Party Imports (UI)
import gradio as gr | |
# Third Party Imports | |
import torch | |
from torchvision import models | |
# Local Imports | |
from inference import inference | |
def load_model(model_path: str):
    """
    Build a ResNet50 and load custom weights into it from a checkpoint file.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load``. The
            loaded object must have a ``'model_state_dict'`` entry mapping
            parameter names to values; a leading ``'model.'`` on a name is
            stripped before loading.

    Returns:
        The ResNet50 instance, moved to CUDA when available (CPU otherwise),
        with the checkpoint weights loaded and switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        Exception: Anything raised by ``load_state_dict`` is logged and
            re-raised unchanged.
    """
    prefix = 'model.'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Start from an architecture-only model; no pretrained weights are fetched.
    model = models.resnet50(weights=None)
    model = model.to(device)

    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be passed here; consider weights_only=True if the installed
    # PyTorch version supports it and the checkpoint holds plain tensors.
    checkpoint = torch.load(model_path, map_location=device)
    saved_weights = checkpoint['model_state_dict']

    # Debug: show a sample of the keys as stored in the checkpoint.
    print("\nOriginal state dict keys:", list(saved_weights.keys())[:5])

    # Strip only a *leading* 'model.' from each key. A blanket
    # str.replace would also rewrite keys that merely contain 'model.'
    # somewhere in the middle (e.g. 'submodel.weight').
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in saved_weights.items()
    }

    # Debug: compare the renamed keys with what the model expects.
    print("Modified state dict keys:", list(new_state_dict.keys())[:5])
    print("Model state dict keys:", list(model.state_dict().keys())[:5])

    try:
        model.load_state_dict(new_state_dict)
    except Exception as e:
        print(f"Error loading state dict: {str(e)}")
        # Bare raise keeps the original traceback intact.
        raise
    else:
        print("Successfully loaded model weights")

    model.eval()
    return model
def load_classes():
    """
    Return the category names stored in the metadata of torchvision's
    ``ResNet50_Weights.IMAGENET1K_V1`` weights entry.

    The number of categories is printed before returning.
    """
    categories = models.ResNet50_Weights.IMAGENET1K_V1.meta["categories"]
    print(f"Loaded {len(categories)} classes")
    return categories
def inference_wrapper(image, alpha, top_k, target_layer):
    """
    Run ``inference`` on behalf of the UI and turn failures into label output.

    The module-level ``model`` and ``classes`` are forwarded to ``inference``.
    On success its result is returned as-is. If no image was supplied, or
    ``inference`` returns ``None``, or an exception is raised, the return
    value is a ``(dict, None)`` pair whose dict maps a message to ``1.0``.
    """
    try:
        # Nothing to classify: report a generic error without calling the model.
        if image is None:
            return {"Error": 1.0}, None

        outcome = inference(
            image,
            alpha,
            top_k,
            target_layer,
            model=model,
            classes=classes
        )
        return outcome if outcome is not None else ({"Error": 1.0}, None)
    except Exception as exc:
        detail = str(exc)
        print(f"Error in inference: {detail}")

        # Non-runtime failures get the generic prefix.
        if not isinstance(exc, RuntimeError):
            return {"Error: " + detail: 1.0}, None

        # Runtime errors: give out-of-memory its own message.
        if "out of memory" in detail.lower():
            return {"GPU Memory Error - Please try again": 1.0}, None
        return {"Runtime Error: " + detail: 1.0}, None
def main():
    """
    Build the Gradio interface and serve it.

    Loads the model and class names into the module-level ``model`` and
    ``classes`` globals (which ``inference_wrapper`` reads), lays out the
    "Predictions & GradCAM" tab, and launches the app on port 7860. Any
    exception raised along the way is printed instead of propagated, after
    which the CUDA cache is emptied when CUDA is available.
    """
    global model, classes

    try:
        print(f"Gradio version: {gr.__version__}")

        # Load everything the click handler needs before the UI exists.
        model = load_model("resnet50_imagenet1k.pth")
        classes = load_classes()

        with gr.Blocks() as demo:
            gr.Markdown(
                """
                # ResNet50 trained on ImageNet-1K
                A large-scale image classification dataset with 1.2 million training images across 1,000 object categories.
                """
            )

            with gr.Tab("Predictions & GradCAM"):
                gr.Markdown(
                    """
                    View model predictions and visualize where the model is looking using GradCAM.
                    ## Steps to use:
                    1. Upload an image or select one from the examples below
                    2. Adjust the sliders (optional):
                    - Activation Map Transparency: Controls the blend between original image and activation map
                    - Number of Top Predictions: How many top class predictions to show
                    - Target Layer Number: Which network layer to visualize (deeper layers show higher-level features)
                    3. Click "Generate GradCAM" to run the model
                    4. View the results:
                    - Left: Original uploaded image
                    - Right: Model predictions and GradCAM visualization showing where the model focused
                    """
                )

                # Top row: uploaded image beside the two result components.
                with gr.Row():
                    img_input = gr.Image(
                        label="Input Image",
                        type="numpy",
                        height=224,
                        width=224
                    )
                    with gr.Column():
                        label_output = gr.Label(label="Predictions")
                        gradcam_output = gr.Image(
                            label="GradCAM Output",
                            height=224,
                            width=224
                        )

                # Second row: the three tuning sliders.
                with gr.Row():
                    alpha_slider = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.5,
                        step=0.1,
                        label="Activation Map Transparency"
                    )
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        label="Number of Top Predictions"
                    )
                    target_layer_slider = gr.Slider(
                        minimum=1,
                        maximum=6,
                        value=4,
                        step=1,
                        label="Target Layer Number"
                    )

                # The button and the examples gallery feed the same handler
                # with the same components, so name the lists once.
                handler_inputs = [
                    img_input,
                    alpha_slider,
                    top_k_slider,
                    target_layer_slider
                ]
                handler_outputs = [
                    label_output,
                    gradcam_output
                ]

                gradcam_button = gr.Button("Generate GradCAM")
                gradcam_button.click(
                    fn=inference_wrapper,
                    inputs=handler_inputs,
                    outputs=handler_outputs
                )

                # Every example row pairs an image with the same slider
                # settings: alpha 0.5, top-k 3, target layer 4.
                example_paths = (
                    "assets/examples/cat.jpg",
                    "assets/examples/frog.jpg",
                    "assets/examples/bird.jpg",
                    "assets/examples/car.jpg",
                    "assets/examples/truck.jpg",
                    "assets/examples/horse.jpg",
                    "assets/examples/plane.jpg",
                    "assets/examples/ship.png",
                )
                examples = [[path, 0.5, 3, 4] for path in example_paths]

                gr.Examples(
                    examples=examples,
                    inputs=handler_inputs,
                    outputs=handler_outputs,
                    fn=inference_wrapper,
                    cache_examples=False,  # example outputs are not cached
                    label="Click on any example to run GradCAM"
                )

        # max_size caps the queue length at one event
        # (NOTE(review): this is a queue-length limit; confirm against the
        # Gradio docs whether a separate concurrency setting is also wanted).
        demo.queue(max_size=1)

        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True
        )
    except Exception as e:
        print(f"Error during startup: {str(e)}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()