#!/usr/bin/env python
"""
Inference script for ResNet50 trained on ImageNet-1K.
"""
# Standard Library Imports
from collections import OrderedDict

# Third Party Imports
import numpy as np
import spaces
import torch
from torch.nn import functional as F
from torchvision import transforms
from torchvision.models import resnet50
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
@spaces.GPU
def inference(image, alpha, top_k, target_layer, model=None, classes=None):
"""
    Run inference on a single image and build a GradCAM overlay.

    Returns the top_k class confidences and the CAM heatmap blended with the
    original image (alpha is the weight of the original image); target_layer
    (1-6) selects which part of the ResNet50 the CAM is computed from.
"""
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Debug: confirm whether the model is still in training mode
        print(f"Model in training mode: {model.training}")
# Ensure model is on correct device and in eval mode
model = model.to(device)
model.eval()
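        # Run the forward pass and GradCAM under CUDA automatic mixed precision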
with torch.cuda.amp.autocast():
org_img = image.copy()
            # Convert the image to a tensor and normalize with the standard ImageNet mean/std
_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
# Debug: Print image tensor stats
input_tensor = _transform(image).to(device)
print(f"Input tensor shape: {input_tensor.shape}")
print(f"Input tensor range: [{input_tensor.min():.2f}, {input_tensor.max():.2f}]")
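            # Add a batch dimension; requires_grad lets gradients flow for GradCAM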
input_tensor = input_tensor.unsqueeze(0)
input_tensor.requires_grad = True
# Get Model Predictions
outputs = model(input_tensor)
print(f"Raw output shape: {outputs.shape}")
print(f"Raw output range: [{outputs.min():.2f}, {outputs.max():.2f}]")
probabilities = torch.softmax(outputs, dim=1)[0]
print(f"Probabilities sum: {probabilities.sum():.2f}") # Should be close to 1.0
# Get top 5 predictions for debugging
top_probs, top_indices = torch.topk(probabilities, 5)
print("\nTop 5 predictions:")
for idx, (prob, class_idx) in enumerate(zip(top_probs, top_indices)):
class_name = classes[class_idx]
print(f"{idx+1}. {class_name}: {prob:.4f}")
# Create confidence dictionary
confidences = {classes[i]: float(probabilities[i]) for i in range(len(classes))}
sorted_confidences = sorted(confidences.items(), key=lambda x: x[1], reverse=True)
show_confidences = OrderedDict(sorted_confidences[:top_k])
# Map layer numbers to meaningful parts of the ResNet architecture
            _layers = {
                1: model.conv1,       # stem convolution
                2: model.layer1[-1],  # last bottleneck of layer1
                3: model.layer2[-1],  # last bottleneck of layer2
                4: model.layer3[-1],  # last bottleneck of layer3
                5: model.layer4[-1],  # last bottleneck of layer4
                6: model.layer4[-1]   # 6 also resolves to the final bottleneck
            }
            # Clamp the requested index to the valid 1-6 range before lookup
            target_layer = min(max(target_layer, 1), 6)
            target_layers = [_layers[target_layer]]
# Debug: Print selected layer
print(f"\nUsing target layer: {target_layers[0]}")
cam = GradCAM(model=model, target_layers=target_layers)
# Get the most probable class index
top_class = max(confidences.items(), key=lambda x: x[1])[0]
class_idx = classes.index(top_class)
print(f"\nSelected class for GradCAM: {top_class} (index: {class_idx})")
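            # Compute the GradCAM heatmap for the top predicted class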
grayscale_cam = cam(
input_tensor=input_tensor,
targets=[ClassifierOutputTarget(class_idx)],
aug_smooth=False,
eigen_smooth=False
)
grayscale_cam = grayscale_cam[0, :]
visualization = show_cam_on_image(org_img/255., grayscale_cam, use_rgb=True, image_weight=alpha)
if torch.cuda.is_available():
torch.cuda.empty_cache()
return show_confidences, visualization
except Exception as e:
print(f"Error in inference: {str(e)}")
if torch.cuda.is_available():
torch.cuda.empty_cache()
        raise
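

# --- Usage sketch (assumed, not part of the original Space code) ---
# The Space presumably constructs the model and class list elsewhere (e.g. in
# app.py) and passes them into `inference`. The block below is only a minimal
# example of how that wiring could look; the torchvision weights and the
# "test.jpg" path are assumptions, not the Space's actual setup.
if __name__ == "__main__":
    from PIL import Image
    from torchvision.models import ResNet50_Weights

    # Assumption: standard torchvision ImageNet-1K weights and their bundled
    # class names stand in for the Space's own checkpoint.
    weights = ResNet50_Weights.IMAGENET1K_V2
    demo_model = resnet50(weights=weights)
    demo_classes = list(weights.meta["categories"])

    demo_image = np.array(Image.open("test.jpg").convert("RGB"))
    confidences, cam_overlay = inference(
        image=demo_image,
        alpha=0.5,        # blend weight of the original image in the overlay
        top_k=5,          # number of class confidences to return
        target_layer=5,   # last ResNet block
        model=demo_model,
        classes=demo_classes,
    )
    print(confidences)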