'''
DETR object detection interpretability with Captum.
Demonstrates Grad-CAM and Integrated Gradients on object detection:
on a sample COCO image, picks a confident detection and visualizes attributions.
Intended for developers and ML practitioners interested in model interpretability.
'''
import requests
import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
from torchvision.transforms.functional import resize
from captum.attr import IntegratedGradients
# ---------------- 1. Load DETR ----------------
model_name = "facebook/detr-resnet-50"
model = DetrForObjectDetection.from_pretrained(model_name)
feature_extractor = DetrImageProcessor.from_pretrained(model_name)
model.eval()
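# Note: this sketch runs on CPU for simplicity; to use a GPU you would move the
# model and every input tensor to one device (an assumption, not shown here),
# e.g. model.to("cuda") plus pixel_values.to("cuda") wherever inputs are built.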
# ---------------- 2. Load an image ----------------
url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # two cats on a couch
img = Image.open(requests.get(url, stream=True).raw).convert("RGB")
# ---------------- 3. Preprocess & forward ----------------
inputs = feature_extractor(images=img, return_tensors="pt")
pixel_values = inputs["pixel_values"]
with torch.no_grad():  # gradients are not needed for the initial detection pass
    outputs = model(pixel_values)
target_sizes = torch.tensor([img.size[::-1]])  # (height, width)
# use the updated post_process_object_detection API; threshold=0.0 keeps every query
results = feature_extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)[0]
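# results is a dict with "scores", "labels", and "boxes"; boxes come back in
# absolute (xmin, ymin, xmax, ymax) pixel coordinates of the original image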
# ---------------- 4. Pick detection ----------------
keep = results["scores"] > 0.7
boxes, labels, scores = results["boxes"][keep], results["labels"][keep], results["scores"][keep]
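# sanity check: fail loudly if nothing clears the confidence threshold
if len(scores) == 0:
    raise RuntimeError("No detections above the 0.7 threshold; try lowering it")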
chosen_idx = 0  # take the first detection above the threshold
chosen_label = labels[chosen_idx].item()
chosen_name = model.config.id2label[chosen_label]
score_val = scores[chosen_idx].item()  # a 0-dim tensor; .item() is sufficient
print(f"Chosen detection: {chosen_name}, score={score_val:.2f}")
# ---------------- 5. Grad-CAM ----------------
# Find the last convolutional layer in the backbone (robust to implementation details)
backbone = getattr(model.model, "backbone", None)
conv_layer = None
if backbone is not None:
    for name, module in reversed(list(backbone.named_modules())):
        if isinstance(module, torch.nn.Conv2d):
            conv_layer = module
            conv_name = name
            break
# fall back to searching the entire model
if conv_layer is None:
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, torch.nn.Conv2d):
            conv_layer = module
            conv_name = name
            break
if conv_layer is None:
    raise RuntimeError("No Conv2d layer found for Grad-CAM")
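# report which layer the heuristic search settled on
print(f"Grad-CAM target layer: {conv_name}")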
activations, gradients = {}, {}
def forward_hook(m, i, o):
    # cache the conv layer's output feature map during the forward pass
    activations["value"] = o.detach()
# register hooks; register_full_backward_hook is preferred over the
# deprecated register_backward_hook where available
fwd_handle = conv_layer.register_forward_hook(forward_hook)
if hasattr(conv_layer, "register_full_backward_hook"):
    bwd_handle = conv_layer.register_full_backward_hook(lambda m, gi, go: gradients.update({"value": go[0].detach()}))
else:
    bwd_handle = conv_layer.register_backward_hook(lambda m, gi, go: gradients.update({"value": go[0].detach()}))
# Hooks only fire during a forward/backward pass, so re-run the model now that
# they are registered, then backprop from the chosen detection's logit.
# Map the chosen kept detection back to its decoder query index; this mapping
# relies on threshold=0.0 above, which preserves the per-query order of results.
keep_idxs = torch.nonzero(keep).squeeze()
if keep_idxs.dim() == 0:  # a single kept detection squeezes to a 0-dim tensor
    chosen_query_idx = int(keep_idxs.item())
else:
    chosen_query_idx = int(keep_idxs[chosen_idx].item())
# prepare pixel_values for gradient computation and re-run forward to trigger the hooks
pixel_values_for_grad = pixel_values.clone().detach().requires_grad_(True)
outputs_for_grad = model(pixel_values_for_grad)
# select the logit for that query & class and backpropagate
score_for_grad = outputs_for_grad.logits[0, chosen_query_idx, chosen_label]
model.zero_grad()
score_for_grad.backward()
# the hooks have now populated activations and gradients
acts = activations["value"].squeeze(0)   # (C, H, W) feature map
grads = gradients["value"].squeeze(0)    # matching gradients, same shape
weights = grads.mean(dim=(1, 2))         # channel weights: global-average-pooled gradients
cam = torch.relu((weights[:, None, None] * acts).sum(0))
cam = cam / (cam.max() + 1e-8)           # normalize; epsilon guards an all-zero map
cam_resized = resize(cam.unsqueeze(0).unsqueeze(0), list(img.size[::-1]))[0, 0].numpy()
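# the hooks are no longer needed once the CAM is computed; removing them
# avoids redundant captures during the Integrated Gradients passes below
fwd_handle.remove()
bwd_handle.remove()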
# ---------------- 6. Integrated Gradients ----------------
# wrap the model in a forward function that returns the chosen query/class
# logit as a 1-D tensor of shape (batch,)
def forward_func(pixel_values):
    out = model(pixel_values=pixel_values)
    return out.logits[:, chosen_query_idx, chosen_label]
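# quick sanity check (a minimal sketch): on the original input, forward_func
# should return the raw class logit of the chosen detection
with torch.no_grad():
    print(f"Logit for '{chosen_name}': {forward_func(pixel_values).item():.3f}")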
ig = IntegratedGradients(forward_func)
# forward_func already returns one scalar per sample, so no `target` is needed
attributions, delta = ig.attribute(pixel_values, n_steps=25, return_convergence_delta=True)
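# the convergence delta measures how far the attributions are from satisfying
# IG's completeness axiom; values near zero suggest n_steps=25 is adequate
print(f"IG convergence delta: {delta.item():.4f}")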
# average over channels, then upsample to the original image size so the overlay
# aligns with the displayed image (pixel_values were resized during preprocessing)
attr = attributions.squeeze(0).mean(0, keepdim=True)
attr = resize(attr.unsqueeze(0), list(img.size[::-1]))[0, 0].detach().cpu().numpy()
attr = (attr - attr.min()) / (attr.max() - attr.min() + 1e-8)
# ---------------- 7. Visualize ----------------
fig, axs = plt.subplots(1, 3, figsize=(16, 6))
axs[0].imshow(img); axs[0].set_title(f"Original: {chosen_name}"); axs[0].axis("off")
axs[1].imshow(img); axs[1].imshow(cam_resized, cmap="jet", alpha=0.5)
axs[1].set_title("Grad-CAM heatmap"); axs[1].axis("off")
axs[2].imshow(img); axs[2].imshow(attr, cmap="hot", alpha=0.5)
axs[2].set_title("Integrated Gradients"); axs[2].axis("off")
plt.show()
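# optionally persist the figure alongside the interactive view
# (hypothetical output path):
# fig.savefig("detr_attributions.png", dpi=150, bbox_inches="tight")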