'''
DETR with Captum for interpretability.

Demonstrates Grad-CAM and Integrated Gradients on object detection:
on a sample COCO image, picks a high-confidence detection and
visualizes the pixel attributions for it. Intended for developers and
ML practitioners interested in model interpretability.
'''
import requests
import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
from torchvision.transforms.functional import resize
from captum.attr import IntegratedGradients

# ---------------- 1. Load DETR ----------------
model_name = "facebook/detr-resnet-50"
model = DetrForObjectDetection.from_pretrained(model_name)
feature_extractor = DetrImageProcessor.from_pretrained(model_name)
model.eval()

# ---------------- 2. Load an image ----------------
url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # two cats on a couch
img = Image.open(requests.get(url, stream=True).raw).convert("RGB")

# ---------------- 3. Preprocess & forward ----------------
inputs = feature_extractor(images=img, return_tensors="pt")
pixel_values = inputs["pixel_values"]
with torch.no_grad():  # no gradients needed for the initial detection pass
    outputs = model(pixel_values)
target_sizes = torch.tensor([img.size[::-1]])  # (height, width)
# use the updated post_process_object_detection API; threshold=0.0 keeps
# every query so the indices below still map back to decoder queries
results = feature_extractor.post_process_object_detection(
    outputs, target_sizes=target_sizes, threshold=0.0
)[0]

# ---------------- 4. Pick a detection ----------------
keep = results["scores"] > 0.7
if not keep.any():
    raise RuntimeError("No detection above the 0.7 confidence threshold")
boxes = results["boxes"][keep]
labels = results["labels"][keep]
scores = results["scores"][keep]

chosen_idx = 0
chosen_label = labels[chosen_idx].item()
chosen_name = model.config.id2label[chosen_label]
score_val = scores[chosen_idx].item()
print(f"Chosen detection: {chosen_name}, score={score_val:.2f}")

# ---------------- 5. Grad-CAM ----------------
# Find the last convolutional layer in the backbone (robust to implementation details)
backbone = getattr(model.model, "backbone", None)
conv_layer, conv_name = None, None
if backbone is not None:
    for name, module in reversed(list(backbone.named_modules())):
        if isinstance(module, torch.nn.Conv2d):
            conv_layer, conv_name = module, name
            break
# fallback to searching the entire model
if conv_layer is None:
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, torch.nn.Conv2d):
            conv_layer, conv_name = module, name
            break
if conv_layer is None:
    raise RuntimeError("No Conv2d layer found for Grad-CAM")
print(f"Grad-CAM layer: {conv_name}")

activations, gradients = {}, {}

def forward_hook(module, module_inputs, output):
    activations["value"] = output.detach()

def backward_hook(module, grad_input, grad_output):
    gradients["value"] = grad_output[0].detach()

hook_handles = [conv_layer.register_forward_hook(forward_hook)]
# register_full_backward_hook is preferred where available (PyTorch >= 1.8)
if hasattr(conv_layer, "register_full_backward_hook"):
    hook_handles.append(conv_layer.register_full_backward_hook(backward_hook))
else:
    hook_handles.append(conv_layer.register_backward_hook(backward_hook))

# The initial forward pass ran before the hooks were registered, so re-run
# a forward pass with inputs that require gradients, then backprop on the
# chosen detection logit.
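# Grad-CAM recap (standard formulation, Selvaraju et al. 2017), shown as a
# worked equation for the code below:
#   alpha_k = (1 / (H*W)) * sum_{i,j} dY_c / dA_k[i, j]   # channel weights
#   cam     = relu( sum_k alpha_k * A_k )                 # class activation map
# where A_k are the hooked conv activations and Y_c is the chosen detection logit.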
# Determine the decoder query index corresponding to the chosen kept
# detection (threshold=0.0 above kept all queries, so positions in "keep"
# are query indices).
keep_idxs = torch.nonzero(keep).squeeze()
if keep_idxs.dim() == 0:
    chosen_query_idx = int(keep_idxs.item())
else:
    chosen_query_idx = int(keep_idxs[chosen_idx].item())

# prepare pixel_values for gradient computation and re-run forward to trigger hooks
pixel_values_for_grad = pixel_values.clone().detach().requires_grad_(True)
outputs_for_grad = model(pixel_values_for_grad)

# select the logit for that query & class and backpropagate
score_for_grad = outputs_for_grad.logits[0, chosen_query_idx, chosen_label]
model.zero_grad()
score_for_grad.backward()

# the hooks have now populated the activation and gradient caches
acts = activations["value"].squeeze(0)   # (C, H, W)
grads = gradients["value"].squeeze(0)    # (C, H, W)
weights = grads.mean(dim=(1, 2))         # global-average-pooled gradients
cam = torch.relu((weights[:, None, None] * acts).sum(0))
cam = cam / (cam.max() + 1e-8)           # avoid division by zero
cam_resized = resize(cam.unsqueeze(0).unsqueeze(0), list(img.size[::-1]))[0, 0].numpy()

# remove the hooks so they don't fire during the Integrated Gradients passes
for handle in hook_handles:
    handle.remove()

# ---------------- 6. Integrated Gradients ----------------
# forward function that returns the chosen query/class logit, one scalar per sample
def forward_func(pixel_values):
    out = model(pixel_values=pixel_values)
    return out.logits[:, chosen_query_idx, chosen_label]  # shape (batch,)

ig = IntegratedGradients(forward_func)
# forward_func already returns a scalar logit per sample, so no target is passed;
# internal_batch_size bounds memory, since DETR runs once per integration step
attributions, delta = ig.attribute(
    pixel_values, n_steps=25, internal_batch_size=4, return_convergence_delta=True
)
print("IG convergence delta:", delta)

# average over channels, resize to the original image size so the overlay
# aligns with it, and normalize to [0, 1]
attr = attributions.squeeze(0).mean(0)
attr = resize(attr.unsqueeze(0).unsqueeze(0), list(img.size[::-1]))[0, 0]
attr = attr.detach().cpu().numpy()
attr = (attr - attr.min()) / (attr.max() - attr.min() + 1e-8)

# ---------------- 7. Visualize ----------------
fig, axs = plt.subplots(1, 3, figsize=(16, 6))
axs[0].imshow(img); axs[0].set_title(f"Original: {chosen_name}"); axs[0].axis("off")
axs[1].imshow(img); axs[1].imshow(cam_resized, cmap="jet", alpha=0.5)
axs[1].set_title("Grad-CAM heatmap"); axs[1].axis("off")
axs[2].imshow(img); axs[2].imshow(attr, cmap="hot", alpha=0.5)
axs[2].set_title("Integrated Gradients"); axs[2].axis("off")
plt.show()
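# ---------------- Optional: box overlay (sketch) ----------------
# A minimal extension, not part of the original flow: to also draw the
# chosen detection's bounding box on the first panel, add these lines
# *before* plt.show(). post_process_object_detection returns boxes as
# (xmin, ymin, xmax, ymax) in original-image coordinates.
#
#   import matplotlib.patches as patches
#   xmin, ymin, xmax, ymax = boxes[chosen_idx].tolist()
#   axs[0].add_patch(patches.Rectangle(
#       (xmin, ymin), xmax - xmin, ymax - ymin,
#       linewidth=2, edgecolor="lime", facecolor="none"))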