import cv2
import numpy as np
from torch.nn import functional as F
import torch


class ActivationsAndGradients:
    """Extract activations and register gradients from targeted layers.

    A forward hook and a backward hook are attached to every module in
    ``target_layers``:

    - the forward hook stores each layer output (moved to CPU and detached)
      in ``self.activations``;
    - the backward hook stores the gradient w.r.t. each layer output
      (moved to CPU and detached) in ``self.gradients``.

    Call ``release()`` to remove the hooks again.
    """

    def __init__(self, model, target_layers, reshape_transform):
        """
        :param model: The network; it is invoked as ``model(x, y)``.
        :param target_layers: Modules to attach the hooks to.
        :param reshape_transform: Optional callable applied to both the
            captured activations and gradients before storing them, or None.
        """
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        # Hook handles, kept so that release() can detach the hooks later.
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(self.save_activation))
            # Backward compatibility with older pytorch versions, which do
            # not provide register_full_backward_hook:
            if hasattr(target_layer, 'register_full_backward_hook'):
                self.handles.append(
                    target_layer.register_full_backward_hook(
                        self.save_gradient))
            else:
                self.handles.append(
                    target_layer.register_backward_hook(self.save_gradient))

    def save_activation(self, module, input, output):
        """Forward hook: store the (optionally reshaped) layer output."""
        activation = output
        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu().detach())

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: store the gradient w.r.t. the layer output."""
        # Gradients are computed in reverse order, so prepend to keep
        # self.gradients in the same order as self.activations.
        grad = grad_output[0]
        if self.reshape_transform is not None:
            grad = self.reshape_transform(grad)
        self.gradients = [grad.cpu().detach()] + self.gradients

    def __call__(self, x, y):
        """Clear previously captured tensors and run ``model(x, y)``."""
        self.gradients = []
        self.activations = []
        return self.model(x, y)

    def release(self):
        """Remove every hook registered in ``__init__``."""
        for handle in self.handles:
            handle.remove()


class GradCAM:
    """Grad-CAM saliency maps for a model that takes two inputs ``(x, y)``.

    When ``cfg.net == 'cdmask'`` the model's second return value must be a
    dict with ``"pred_logits"`` and ``"pred_masks"``. These are combined
    into a per-class score map, which is then back-propagated. For any
    other ``cfg.net`` the model output is used directly.
    """

    def __init__(self,
                 cfg,
                 model,
                 target_layers,
                 reshape_transform=None,
                 use_cuda=False):
        """
        :param cfg: Config object; this class only reads ``cfg.net``.
        :param model: The network; it is switched to eval mode.
        :param target_layers: Layers whose activations/gradients form the CAM.
        :param reshape_transform: Optional callable applied to the captured
            activations and gradients, or None.
        :param use_cuda: If True, move the model (and the inputs in
            ``__call__``) to CUDA.
        """
        self.cfg = cfg
        self.model = model.eval()
        self.target_layers = target_layers
        self.reshape_transform = reshape_transform
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()
        self.activations_and_grads = ActivationsAndGradients(
            self.model, target_layers, reshape_transform)

    @staticmethod
    def get_cam_weights(grads):
        """Get a vector of weights for every channel in the target layer.

        Methods that return weights channels will typically need to only
        implement this function. Grad-CAM averages the gradients over
        axes 2 and 3 (presumably the spatial H, W axes).

        :param grads: Gradient array with at least 4 dimensions.
        :returns: Array with the same leading axes and axes 2 and 3
            reduced to size 1.
        """
        return np.mean(grads, axis=(2, 3), keepdims=True)

    @staticmethod
    def get_loss(output, target_category):
        """Sum ``output[i]`` over the first ``len(target_category)`` samples.

        NOTE(review): the category values are not used to index ``output``;
        every entry of ``output[i]`` contributes to the loss, and
        ``__call__`` reduces the result with ``.sum()``. Confirm this is the
        intended objective for the model's output.
        """
        loss = 0
        for i in range(len(target_category)):
            loss = loss + output[i]
        return loss

    def get_cam_image(self, activations, grads):
        """Weight every activation channel by its CAM weight, sum channels."""
        weights = self.get_cam_weights(grads)
        weighted_activations = weights * activations
        cam = weighted_activations.sum(axis=1)
        return cam

    @staticmethod
    def get_target_width_height(input_tensor):
        """Return ``(width, height)`` taken from the last two dims.

        This order matches the ``dsize`` argument of ``cv2.resize``.
        """
        width, height = input_tensor.size(-1), input_tensor.size(-2)
        return width, height

    def compute_cam_per_layer(self, input_tensor):
        """Compute a normalised, resized CAM for every target layer.

        :param input_tensor: Model input; only its last two dims are used,
            as the resize target.
        :returns: List with one float32 array per layer, each with a
            singleton axis inserted at position 1 for later concatenation.
        """
        activations_list = [a.cpu().data.numpy()
                            for a in self.activations_and_grads.activations]
        grads_list = [g.cpu().data.numpy()
                      for g in self.activations_and_grads.gradients]
        target_size = self.get_target_width_height(input_tensor)

        cam_per_target_layer = []
        # Loop over the saliency image from every layer
        for layer_activations, layer_grads in zip(activations_list,
                                                  grads_list):
            cam = self.get_cam_image(layer_activations, layer_grads)
            # Clamp negatives to 0. Whenever a map had any value <= 0 its
            # minimum becomes 0, so the min-subtraction in scale_cam_image
            # is then a no-op for that map.
            cam[cam < 0] = 0
            scaled = self.scale_cam_image(cam, target_size)
            cam_per_target_layer.append(scaled[:, None, :])

        return cam_per_target_layer

    def aggregate_multi_layers(self, cam_per_target_layer):
        """Average the per-layer CAMs and min-max normalise the result."""
        cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
        cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
        result = np.mean(cam_per_target_layer, axis=1)
        return self.scale_cam_image(result)

    @staticmethod
    def scale_cam_image(cam, target_size=None):
        """Min-max normalise each map in ``cam`` and optionally resize it.

        :param cam: Iterable of per-sample maps.
        :param target_size: Optional ``(width, height)`` for ``cv2.resize``.
        :returns: float32 array stacking the processed maps.
        """
        result = []
        for img in cam:
            img = img - np.min(img)
            # The epsilon avoids division by zero for constant maps.
            img = img / (1e-7 + np.max(img))
            if target_size is not None:
                img = cv2.resize(img, target_size)
            result.append(img)
        result = np.float32(result)

        return result

    def __call__(self, input_tensor, target_category=None):
        """Compute Grad-CAM maps for a pair of inputs.

        :param input_tensor: Tuple ``(x, y)`` of tensors fed to the model.
        :param target_category: None, a single integer (broadcast to every
            sample) or a sequence with one entry per sample. If None, it is
            derived with ``argmax`` over the last axis of the model output
            and printed.
        :returns: float32 array of normalised CAMs, one per sample.
        """
        x, y = input_tensor

        if self.cuda:
            x = x.cuda()
            y = y.cuda()

        # Forward pass to obtain the network output.
        if self.cfg.net == 'cdmask':
            # Combine per-query class probabilities (first class column
            # dropped) with per-query mask probabilities into a "bchw"
            # score map.
            _, outputs = self.activations_and_grads(x, y)
            mask_cls_results = outputs["pred_logits"]
            mask_pred_results = outputs["pred_masks"]
            mask_pred_results = F.interpolate(
                mask_pred_results,
                scale_factor=(4, 4),
                mode="bilinear",
                align_corners=False,
            )
            mask_cls = F.softmax(mask_cls_results, dim=-1)[..., 1:]
            mask_pred = mask_pred_results.sigmoid()
            output = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
        else:
            # Raw network output (presumably logits, not passed through
            # softmax).
            output = self.activations_and_grads(x, y)

        # Accept NumPy integer scalars as well as Python ints.
        if isinstance(target_category, (int, np.integer)):
            target_category = [target_category] * x.size(0)

        if target_category is None:
            target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
            print(f"category id: {target_category}")
        else:
            assert (len(target_category) == x.size(0))

        self.model.zero_grad()
        loss = self.get_loss(output, target_category).sum()
        loss.backward(retain_graph=True)

        # In most of the saliency attribution papers, the saliency is
        # computed with a single target layer.
        # Commonly it is the last convolutional layer.
        # Here we support passing a list with multiple target layers.
        # It will compute the saliency image for every image,
        # and then aggregate them (with a default mean aggregation).
        # This gives you more flexibility in case you just want to
        # use all conv layers for example, all Batchnorm layers,
        # or something else.
        cam_per_layer = self.compute_cam_per_layer(x)
        return self.aggregate_multi_layers(cam_per_layer)

    def __del__(self):
        # __init__ may have raised before the hook helper was created;
        # don't turn that into a second (AttributeError) failure here.
        activations_and_grads = getattr(self, "activations_and_grads", None)
        if activations_and_grads is not None:
            activations_and_grads.release()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Release the hooks; print and suppress an IndexError only."""
        self.activations_and_grads.release()
        if isinstance(exc_value, IndexError):
            # Handle IndexError here...
            print(
                f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
            return True


def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """ This function overlays the cam mask on the image as an heatmap.
    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format, with values in [0, 1].
    :param mask: The cam mask; it is multiplied by 255 and converted to
        uint8, so it is presumably expected to lie in [0, 1].
    :param use_rgb: Whether to use an RGB or BGR heatmap, this should be
        set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: uint8 image of heatmap + img, rescaled so that its maximum
        maps to 255.
    :raises ValueError: If ``img`` contains values greater than 1.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        # ValueError subclasses Exception, so existing handlers still match.
        raise ValueError(
            "The input image should be np.float32 in the range [0, 1]")

    cam = heatmap + img
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)


def center_crop_img(img: np.ndarray, size: int):
    """Resize so the shorter side equals ``size``, then centre-crop.

    The result is ``size`` x ``size`` (plus any channel axis). Both HxW
    (grayscale) and HxWxC arrays are accepted.

    :param img: Image array with height and width as the first two axes.
    :param size: Target side length in pixels.
    :returns: The cropped image (``img`` itself if already size x size).
    """
    h, w = img.shape[:2]

    if w == h == size:
        return img

    # Exact integer arithmetic: the float form int(h * (size / w)) can
    # truncate an exact integer to one less, e.g. 49 * (1 / 49) < 1, which
    # would yield a wrongly shaped crop.
    if w < h:
        new_w = size
        new_h = h * size // w
    else:
        new_h = size
        new_w = w * size // h

    img = cv2.resize(img, dsize=(new_w, new_h))

    if new_w == size:
        h = (new_h - size) // 2
        img = img[h: h+size]
    else:
        w = (new_w - size) // 2
        img = img[:, w: w+size]

    return img