File size: 4,476 Bytes
077fb0c
 
 
 
 
 
 
 
 
 
fdcadea
077fb0c
 
 
 
 
 
 
 
fdcadea
bb98138
077fb0c
f8ecba6
077fb0c
b67331d
ebbea61
 
 
b67331d
 
aa63283
 
 
b67331d
 
 
 
ebbea61
 
f8ecba6
ebbea61
 
 
 
 
 
 
 
f8ecba6
aa63283
ebbea61
aa63283
 
 
ebbea61
 
 
 
 
aa63283
 
 
ebbea61
aa63283
 
 
 
 
 
 
 
f8ecba6
aa63283
 
 
ebbea61
f8ecba6
ebbea61
 
 
 
 
 
 
 
 
f8ecba6
ebbea61
 
f8ecba6
aa63283
 
f8ecba6
aa63283
 
ebbea61
 
 
aa63283
ebbea61
 
 
 
aa63283
 
ebbea61
 
f8ecba6
ebbea61
 
 
 
 
 
 
b67331d
aa63283
b67331d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python
"""
Inference script for ResNet50 trained on ImageNet-1K.
"""
# Standard Library Imports
import numpy as np
import torch
from collections import OrderedDict

# Third Party Imports
import spaces
from torchvision import transforms
from torch.nn import functional as F
from torchvision.models import resnet50
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget


@spaces.GPU
def inference(image, alpha, top_k, target_layer, model=None, classes=None):
    """
    Run inference with GradCAM visualization.

    Classifies ``image`` with ``model``, then overlays a GradCAM heatmap for
    the highest-scoring class on top of the input image.

    Args:
        image: Input image. It is fed to ``transforms.ToTensor`` and divided
            by 255 for the overlay, so it is presumably an HxWx3 uint8 RGB
            array -- TODO confirm against the caller.
        alpha: Forwarded as ``image_weight`` to ``show_cam_on_image``.
        top_k: Number of (class name, probability) entries to return.
            Coerced to ``int``; values below 0 are treated as 0.
        target_layer: Layer selector, coerced to ``int`` and clamped to 1..6
            (1 = ``conv1``, 2..5 = last block of ``layer1``..``layer4``,
            6 = same as 5).
        model: Network exposing ``conv1`` and ``layer1``..``layer4``
            attributes (e.g. a torchvision ResNet). Required.
        classes: Sequence of class names; position ``i`` names output ``i``
            of the model. Required and must be non-empty.

    Returns:
        tuple: ``(show_confidences, visualization)`` where
        ``show_confidences`` is an ``OrderedDict`` mapping class name to
        probability in descending order (at most ``top_k`` entries), and
        ``visualization`` is the image returned by ``show_cam_on_image``.

    Raises:
        ValueError: If ``model`` is ``None`` or ``classes`` is missing/empty.
        Exception: Any error raised during inference is logged and re-raised.
    """
    if model is None:
        raise ValueError("inference() requires a model")
    if classes is None or len(classes) == 0:
        raise ValueError("inference() requires a non-empty list of classes")

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    try:
        if use_cuda:
            torch.cuda.empty_cache()

        # Debug: Print model mode
        print(f"Model mode: {model.training}")

        # Ensure model is on correct device and in eval mode
        model = model.to(device)
        model.eval()

        # UI widgets may deliver floats; slicing and layer lookup need ints.
        top_k = max(int(top_k), 0)
        target_layer = min(max(int(target_layer), 1), 6)

        # show_cam_on_image expects a float32 image scaled to [0, 1].
        base_img = np.asarray(image, dtype=np.float32) / 255.0

        # Mixed precision only when running on CUDA; on CPU the context is a
        # no-op (avoids the warning emitted by torch.cuda.amp.autocast()).
        with torch.autocast(device_type=device.type, enabled=use_cuda):
            # Convert img to tensor and normalize it
            _transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])

            # Debug: Print image tensor stats
            input_tensor = _transform(image).to(device)
            print(f"Input tensor shape: {input_tensor.shape}")
            print(f"Input tensor range: [{input_tensor.min():.2f}, {input_tensor.max():.2f}]")

            input_tensor = input_tensor.unsqueeze(0)
            input_tensor.requires_grad = True

            # Get Model Predictions
            outputs = model(input_tensor)
            print(f"Raw output shape: {outputs.shape}")
            print(f"Raw output range: [{outputs.min():.2f}, {outputs.max():.2f}]")

            probabilities = torch.softmax(outputs, dim=1)[0]
            print(f"Probabilities sum: {probabilities.sum():.2f}")  # Should be close to 1.0

            # Single device->host transfer instead of one sync per class.
            scores = probabilities.detach().float().cpu().tolist()

            # Rank class *indices* (not names): label lists may contain
            # duplicate names, so a name is not a reliable key for an index.
            # sorted() is stable, so ties keep their original class order.
            ranked = sorted(range(len(classes)), key=scores.__getitem__, reverse=True)

            # Get top 5 predictions for debugging
            print("\nTop 5 predictions:")
            for rank, idx in enumerate(ranked[:5], start=1):
                print(f"{rank}. {classes[idx]}: {scores[idx]:.4f}")

            # Create confidence dictionary (highest score wins for a
            # duplicated class name).
            show_confidences = OrderedDict()
            for idx in ranked:
                if len(show_confidences) >= top_k:
                    break
                show_confidences.setdefault(classes[idx], scores[idx])

            # Map layer numbers to meaningful parts of the ResNet architecture
            # NOTE(review): 5 and 6 both select layer4[-1]; kept as in the
            # original -- confirm this is intended.
            _layers = {
                1: model.conv1,
                2: model.layer1[-1],
                3: model.layer2[-1],
                4: model.layer3[-1],
                5: model.layer4[-1],
                6: model.layer4[-1]
            }
            target_layers = [_layers[target_layer]]

            # Debug: Print selected layer
            print(f"\nUsing target layer: {target_layers[0]}")

            # Get the most probable class index
            class_idx = ranked[0]
            top_class = classes[class_idx]
            print(f"\nSelected class for GradCAM: {top_class} (index: {class_idx})")

            # The context manager removes the hooks GradCAM registers on the
            # model deterministically, instead of relying on __del__.
            grayscale_cam = None
            with GradCAM(model=model, target_layers=target_layers) as cam:
                grayscale_cam = cam(
                    input_tensor=input_tensor,
                    targets=[ClassifierOutputTarget(class_idx)],
                    aug_smooth=False,
                    eigen_smooth=False
                )
            if grayscale_cam is None:
                # GradCAM's __exit__ can suppress an IndexError raised inside
                # the with-block; surface that as an explicit failure.
                raise RuntimeError("GradCAM did not produce a heatmap")
            grayscale_cam = grayscale_cam[0, :]

            visualization = show_cam_on_image(
                base_img, grayscale_cam, use_rgb=True, image_weight=alpha
            )

            return show_confidences, visualization

    except Exception as e:
        print(f"Error in inference: {str(e)}")
        raise  # bare raise keeps the original traceback intact
    finally:
        # Free cached GPU memory on both the success and the failure path.
        if use_cuda:
            torch.cuda.empty_cache()