# detr_interp_space / obj_detection_DETR.py
# Uploaded by Skier8402 (commit ff7112c).
'''
Interpretability demo for DETR object detection.
Demonstrates Grad-CAM (via forward/backward hooks) and Integrated Gradients (via Captum).
On a sample COCO val2017 image, picks a confident detection and visualizes its attributions.
Aimed at developers and ML practitioners interested in model interpretability.
'''
import torch, requests, numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
from torchvision.transforms.functional import resize
from captum.attr import IntegratedGradients
# ---------------- 1. Load DETR ----------------
model_name = "facebook/detr-resnet-50"
# Image processor: preprocessing for the model input and box post-processing.
feature_extractor = DetrImageProcessor.from_pretrained(model_name)
# .eval() returns the module itself; every forward pass below runs in eval mode.
model = DetrForObjectDetection.from_pretrained(model_name).eval()
# ---------------- 2. Load an image ----------------
# COCO val2017 #39769: two cats lying on a couch, with two remote controls.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Bounded wait and an explicit HTTP status check, so failures surface here rather than as a
# confusing PIL decode error; the context manager closes the connection afterwards.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # .convert() forces a full decode while the stream is still open.
    img = Image.open(response.raw).convert("RGB")
# ---------------- 3. Preprocess & forward ----------------
inputs = feature_extractor(images=img, return_tensors="pt")
pixel_values = inputs["pixel_values"]
# This pass only feeds post-processing; Grad-CAM and IG run their own gradient-enabled
# passes later, so there is no need to build an autograd graph here.
with torch.no_grad():
    outputs = model(pixel_values)
# PIL's .size is (width, height); post-processing expects (height, width).
target_sizes = torch.tensor([img.size[::-1]])
# threshold=0.0 keeps every query (softmax scores are strictly positive) in query order,
# so row i of `results` corresponds to DETR query i -- the Grad-CAM/IG sections rely on this.
results = feature_extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)[0]
# ---------------- 4. Pick detection ----------------
# Confidence cutoff for detections considered real. `keep` is a per-query boolean mask
# (results were post-processed with threshold=0.0, so they cover every query in order).
score_threshold = 0.7
keep = results["scores"] > score_threshold
boxes, labels, scores = results["boxes"][keep], results["labels"][keep], results["scores"][keep]
if labels.numel() == 0:
    raise RuntimeError(
        f"No detection scored above {score_threshold}; lower the threshold or try another image"
    )
# Explain the first confident detection in query order (not necessarily the top-scoring one).
chosen_idx = 0
chosen_label = labels[chosen_idx].item()
chosen_name = model.config.id2label[chosen_label]
score_val = scores[chosen_idx].item()
print(f"Chosen detection: {chosen_name}, score={score_val:.2f}")
# ---------------- 5. Grad-CAM ----------------
def _last_conv2d(root):
    """Return ``(name, module)`` of the last ``Conv2d`` in ``root.named_modules()`` order.

    Returns ``(None, None)`` when *root* contains no ``Conv2d``.
    """
    for name, module in reversed(list(root.named_modules())):
        if isinstance(module, torch.nn.Conv2d):
            return name, module
    return None, None


# Prefer the CNN backbone's final conv layer; fall back to searching the whole model when the
# backbone attribute is missing or contains no Conv2d (robust to implementation details).
backbone = getattr(model.model, "backbone", None)
conv_name, conv_layer = _last_conv2d(backbone) if backbone is not None else (None, None)
if conv_layer is None:
    conv_name, conv_layer = _last_conv2d(model)
if conv_layer is None:
    raise RuntimeError("No Conv2d layer found for Grad-CAM")
activations, gradients = {}, {}


def forward_hook(m, i, o):
    """Forward hook: stash the hooked layer's output (detached) for Grad-CAM."""
    activations["value"] = o.detach()


def _backward_hook(m, grad_input, grad_output):
    """Backward hook: stash the gradient w.r.t. the hooked layer's output."""
    gradients["value"] = grad_output[0].detach()


# Map the chosen kept detection back to its DETR query index: `keep` is a per-query mask
# because the earlier post-processing used threshold=0.0.
keep_idxs = torch.nonzero(keep).flatten()
if chosen_idx >= keep_idxs.numel():
    raise RuntimeError("Chosen detection index is out of range of the kept detections")
chosen_query_idx = int(keep_idxs[chosen_idx].item())

hook_handles = []
try:
    hook_handles.append(conv_layer.register_forward_hook(forward_hook))
    # register_full_backward_hook is preferred; fall back to the deprecated API on old PyTorch.
    register_backward = (
        getattr(conv_layer, "register_full_backward_hook", None) or conv_layer.register_backward_hook
    )
    hook_handles.append(register_backward(_backward_hook))

    # Fresh forward pass with the hooks live so they capture activations and gradients.
    pixel_values_for_grad = pixel_values.clone().detach().requires_grad_(True)
    outputs_for_grad = model(pixel_values_for_grad)
    # Backpropagate the raw (pre-softmax) logit of the chosen query/class.
    score_for_grad = outputs_for_grad.logits[0, chosen_query_idx, chosen_label]
    model.zero_grad()
    score_for_grad.backward()
finally:
    # Detach the hooks so later passes (e.g. Integrated Gradients' batched forward/backward)
    # don't keep overwriting and storing activations/gradients.
    for handle in hook_handles:
        handle.remove()

if "value" not in activations or "value" not in gradients:
    raise RuntimeError("Grad-CAM hooks did not fire on the selected Conv2d layer")
acts = activations["value"].squeeze(0)  # (C, H, W) for the single-image batch
grads = gradients["value"].squeeze(0)
# Grad-CAM: weight each channel by its spatially averaged gradient, sum, keep positives.
weights = grads.mean(dim=(1, 2))
cam = torch.relu((weights[:, None, None] * acts).sum(0))
# Normalize to [0, 1]; the clamp avoids 0/0 -> NaN when ReLU zeroes the whole map.
cam = cam / cam.max().clamp_min(1e-8)
# Upsample the coarse feature-map CAM to the original image's (height, width).
cam_resized = resize(cam.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0].numpy()
# ---------------- 6. Integrated Gradients ----------------
ig_steps = 25
# Captum runs the interpolated inputs through DETR in chunks of this size instead of one
# batch of `ig_steps` full-resolution images, bounding peak memory.
ig_internal_batch_size = 5


def forward_func(pixel_values):
    """Return the chosen query/class logit for every input in the batch, shape (batch,)."""
    out = model(pixel_values=pixel_values)
    return out.logits[:, chosen_query_idx, chosen_label]


ig = IntegratedGradients(forward_func)
# forward_func already yields one scalar per sample, so no `target` is passed.
# Baselines default to zeros (in the processor's normalized pixel space).
attributions, delta = ig.attribute(
    pixel_values,
    n_steps=ig_steps,
    internal_batch_size=ig_internal_batch_size,
    return_convergence_delta=True,
)
print(f"Integrated Gradients convergence delta: {delta.abs().max().item():.4f}")
# Collapse channels (signed mean), then resize from the processor's input resolution to the
# original image's (height, width) so the map lines up with `img` when overlaid.
attr_map = attributions.detach().mean(dim=1, keepdim=True)
attr_map = resize(attr_map, img.size[::-1])[0, 0]
attr = attr_map.cpu().numpy()
attr = (attr - attr.min()) / (attr.max() - attr.min() + 1e-8)
# ---------------- 7. Visualize ----------------
# Pin each heatmap to the image's pixel grid (matplotlib's default extent for `img`), so the
# overlays align with the photo even if a heatmap's array resolution differs from it.
img_extent = (-0.5, img.width - 0.5, img.height - 0.5, -0.5)
# Chosen box in original-image pixel coordinates (x0, y0, x1, y1).
x0, y0, x1, y1 = boxes[chosen_idx].detach().tolist()

fig, axs = plt.subplots(1, 3, figsize=(16, 6))
axs[0].imshow(img)
# Outline the explained detection: the image can contain several objects of the same class.
axs[0].add_patch(
    plt.Rectangle((x0, y0), x1 - x0, y1 - y0, fill=False, edgecolor="lime", linewidth=2)
)
axs[0].set_title(f"Original: {chosen_name}")
axs[1].imshow(img)
axs[1].imshow(cam_resized, cmap="jet", alpha=0.5, extent=img_extent)
axs[1].set_title("Grad-CAM heatmap")
axs[2].imshow(img)
axs[2].imshow(attr, cmap="hot", alpha=0.5, extent=img_extent)
axs[2].set_title("Integrated Gradients")
for ax in axs:
    ax.axis("off")
plt.tight_layout()
plt.show()