'''
This is a combined script that implements DETR object detection with interpretability methods
using Grad-CAM, Grad-CAM++, Integrated Gradients, and Monte Carlo Dropout for uncertainty estimation.
It provides a Gradio-based web interface for users to upload images, select detected objects,
and visualize explanations and uncertainty maps.

How to run it:
```bash
python detr_and_interp.py
```
'''
import torch, requests, numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as patches | |
from PIL import Image, ImageFilter | |
import gradio as gr | |
from transformers import DetrImageProcessor, DetrForObjectDetection | |
from torchvision.transforms.functional import resize | |
from captum.attr import IntegratedGradients | |
import torch.nn.functional as F | |
import logging | |
import os | |
from datetime import datetime | |
# ---------- Logging Setup ----------
# Every run writes to its own timestamped log file inside a "logs" folder next to
# this script; the same records are mirrored to the console.
log_dir = os.path.join(os.path.dirname(__file__), "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"detr_interp_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
logger.info("Starting DETR Interpretability Dashboard")

# ---------- Model loading ----------
# Prefer the GPU when one is visible to PyTorch.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
logger.info(f"Using device: {device}")

model_name = "facebook/detr-resnet-50"
logger.info(f"Loading model: {model_name}")
# Inference by default: move weights to the chosen device and switch to eval mode.
# (mc_dropout_uncertainty temporarily flips the model to train mode.)
model = DetrForObjectDetection.from_pretrained(model_name).to(device).eval()
extractor = DetrImageProcessor.from_pretrained(model_name)
logger.info("Model loaded and set to evaluation mode")
# ---------- Grad-CAM / Grad-CAM++ ----------
def compute_cam(acts, grads, use_pp=False, eps=1e-8):
    """
    Combine feature maps into a normalized class-activation map.

    Args:
        acts: activations of the hooked layer, shape (C, H, W).
        grads: gradients of the target score w.r.t. ``acts``, same shape.
        use_pp: if True, weight channels with Grad-CAM++ coefficients; otherwise use
            vanilla Grad-CAM weights (spatial mean of the gradients).
        eps: small constant guarding the divisions.

    Returns:
        A (H, W) tensor, ReLU-clipped and divided by its maximum (+eps), so values lie
        in [0, 1]; all zeros when no location contributes positively.
    """
    if use_pp:
        # Grad-CAM++ closed-form coefficients, per channel k and location (i, j):
        #   alpha_kij = g_kij^2 / (2 * g_kij^2 + sum_ab(A_kab) * g_kij^3)
        # Channel weight = sum_ij alpha_kij * relu(g_kij): only positive gradients count.
        grads_sq = grads ** 2
        grads_cube = grads_sq * grads
        act_sums = acts.sum(dim=(1, 2), keepdim=True)
        alpha = grads_sq / (2 * grads_sq + act_sums * grads_cube + eps)
        # Where the gradient is exactly zero the coefficient is undefined; zero it out.
        alpha = torch.where(grads != 0, alpha, torch.zeros_like(alpha))
        weights = (alpha * torch.relu(grads)).sum(dim=(1, 2))
    else:
        weights = grads.mean(dim=(1, 2))
    cam = torch.relu((weights[:, None, None] * acts).sum(0))
    return cam / (cam.max() + eps)


def gradcam(img, det_idx, keep, pixel_values, use_pp=False):
    """
    Compute a Grad-CAM (or Grad-CAM++) heatmap for a selected detection.

    What it computes:
    - Captures feature-map activations from a late backbone conv block, plus the gradients
      of the detection's class logit w.r.t. those activations. Channel weights derived
      from the gradients combine the feature maps into a spatial heatmap.
    Why this matters:
    - Highlights which spatial regions the model used to make the prediction. Useful to
      check whether the detector is attending to the object vs irrelevant background.
    How to interpret results:
    - High values indicate regions that contributed positively to the detection score.
      Grad-CAM++ (use_pp=True) uses per-location gradient coefficients, which often
      yield sharper, better-localized maps when several instances overlap.
    Caveats & tips:
    - A layer chosen too early gives fine-grained but semantically weak maps; one chosen
      too late gives coarse maps. The last block of the backbone's layer4 is the default.
    - DETR's last logit is the "no-object" class. It is excluded when choosing the target
      class, so the explained class matches the label shown in the detection list.
    - Hooks are removed in a ``finally`` clause so a failed call cannot leave them attached.

    Args:
        img: PIL image being explained; its size sets the output resolution.
        det_idx: index of the detection among the queries selected by ``keep``.
        keep: boolean mask over DETR queries (detections above the confidence threshold).
        pixel_values: preprocessed input tensor for ``model``.
        use_pp: use Grad-CAM++ weighting instead of vanilla Grad-CAM.

    Returns:
        numpy array of shape (img.height, img.width) with values in [0, 1].

    Raises:
        Any error from the forward/backward pass is logged and re-raised; RuntimeError
        if the hooks captured nothing.

    References:
    - Selvaraju et al., Grad-CAM (2017): https://arxiv.org/abs/1610.02391
    - Chattopadhay et al., Grad-CAM++: https://arxiv.org/abs/1710.11063
    """
    logger.info(f"Running {'Grad-CAM++' if use_pp else 'Grad-CAM'} for detection {det_idx}")
    activations, gradients = {}, {}

    def save_activation(module, inputs, output):
        activations["v"] = output.detach()

    def save_gradient(module, grad_inputs, grad_outputs):
        gradients["v"] = grad_outputs[0].detach()

    handles = []
    try:
        # Late conv block that still retains spatial information. Assumes the timm ResNet
        # backbone of facebook/detr-resnet-50; other backbones may name it differently.
        conv_layer = model.model.backbone.conv_encoder.model.layer4[-1]
        handles.append(conv_layer.register_forward_hook(save_activation))
        if hasattr(conv_layer, "register_full_backward_hook"):
            handles.append(conv_layer.register_full_backward_hook(save_gradient))
        else:  # older PyTorch without full backward hooks
            handles.append(conv_layer.register_backward_hook(save_gradient))
        logger.debug("Hooks registered for Grad-CAM")

        query_idx = keep.nonzero()[det_idx].item()
        with torch.enable_grad():
            logits = model(pixel_values).logits
            label_id = logits[0, query_idx, :-1].argmax().item()
            score = logits[0, query_idx, label_id]
            logger.debug(f"Target label_id: {label_id}, score: {score.item():.4f}")
            model.zero_grad()
            score.backward()

        if "v" not in activations or "v" not in gradients:
            raise RuntimeError("Grad-CAM hooks did not capture activations/gradients")
        acts = activations["v"].squeeze(0)
        grads = gradients["v"].squeeze(0)
        logger.debug(f"Activations shape: {acts.shape}, Gradients shape: {grads.shape}")
        cam = compute_cam(acts, grads, use_pp=use_pp)
        cam_resized = resize(cam.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0].cpu().numpy()
        logger.info(f"{'Grad-CAM++' if use_pp else 'Grad-CAM'} completed successfully")
        return cam_resized
    except Exception as e:
        logger.error(f"Error in gradcam: {str(e)}", exc_info=True)
        raise
    finally:
        for handle in handles:
            handle.remove()
# ---------- Integrated Gradients ----------
def integrated_grad(img, det_idx, keep, outputs_for_attr, pixel_values, baseline="black",
                    n_steps=25, internal_batch_size=None):
    """
    Compute an Integrated Gradients attribution map for a detection's logit.

    What it computes:
    - Integrates gradients of the target class logit along a straight-line path from a
      baseline input to the real input in (preprocessed) pixel space, producing
      per-pixel, per-channel attributions. Channels are averaged into a single map,
      which is resized to the original image size so it can be overlaid on ``img``.
    Why baseline choice matters:
    - The baseline defines what the model should consider as 'no signal'. Common choices:
      black (zeros), a blurred version of the image, or a neutral/mean image. Different
      baselines highlight different aspects of the input.
    How to read the output:
    - Raw attributions > 0 increase the detection logit relative to the baseline; < 0
      decrease it. The returned map is min-max normalized to [0, 1] for visualization,
      so the sign is NOT preserved: 0 marks the most negative attribution, not "no effect".
    Tips:
    - Increase n_steps for smoother attributions (costlier). The convergence delta
      (IG's completeness check) is logged; large values suggest more steps are needed.
    - Captum evaluates the interpolated inputs as one batch by default. Set
      internal_batch_size to bound memory.

    Args:
        img: PIL image; used for the blurred baseline and the output resolution.
        det_idx: index of the detection among the queries selected by ``keep``.
        keep: boolean mask over DETR queries (detections above the confidence threshold).
        outputs_for_attr: model outputs whose logits define the target class.
        pixel_values: preprocessed input tensor for ``model``.
        baseline: "black" (zeros) or "blur" (Gaussian-blurred image); any other value
            falls back to black.
        n_steps: number of integration steps.
        internal_batch_size: forwarded to Captum to split the interpolated batch.

    Returns:
        numpy array of shape (img.height, img.width) with values in [0, 1].

    Raises:
        Any error is logged and re-raised.

    References:
    - Distill article on baselines: https://distill.pub/2020/attribution-baselines
    - Captum IntegratedGradients docs: https://captum.ai/api/integrated_gradients.html
    """
    logger.info(f"Running Integrated Gradients with {baseline} baseline for detection {det_idx}")
    try:
        query_idx = keep.nonzero()[det_idx].item()
        # DETR's last logit is the "no-object" class; exclude it so the target is the
        # detected class rather than "background" for low-confidence queries.
        label_id = outputs_for_attr.logits[0, query_idx, :-1].argmax().item()
        logger.debug(f"IG target label_id: {label_id}")
        # Baselines
        if baseline == "black":
            base = torch.zeros_like(pixel_values)
            logger.debug("Using black baseline")
        elif baseline == "blur":
            blur = img.filter(ImageFilter.GaussianBlur(radius=15))
            base = extractor(images=blur, return_tensors="pt")["pixel_values"].to(device)
            logger.debug("Using blurred baseline")
        else:
            base = torch.zeros_like(pixel_values)
            logger.debug("Defaulting to black baseline")

        def forward_func(pix):
            # One scalar per row of the interpolated batch Captum builds.
            return model(pix).logits[:, query_idx, label_id]

        ig = IntegratedGradients(forward_func)
        with torch.enable_grad():
            attr, delta = ig.attribute(
                pixel_values,
                baselines=base,
                n_steps=n_steps,
                internal_batch_size=internal_batch_size,
                return_convergence_delta=True,
            )
        logger.info(f"IG convergence delta: {delta.abs().max().item():.4e}")
        # Channel average -> (1, 1, H', W') at the processor's resolution, then resize to
        # the original image size (matching gradcam's output).
        attr_map = attr.detach().mean(dim=1, keepdim=True)
        arr = resize(attr_map, img.size[::-1])[0, 0].cpu().numpy()
        logger.info(f"Integrated Gradients with {baseline} baseline completed")
        return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)
    except Exception as e:
        logger.error(f"Error in integrated_grad: {str(e)}", exc_info=True)
        raise
# ---------- Monte Carlo Dropout Uncertainty ----------
def override_dropout_rates(module, p):
    """
    Set every dropout rate inside ``module`` to ``p`` and return what was replaced.

    Two kinds of dropout sites are handled:
    - ``torch.nn.Dropout`` submodules (their ``p`` attribute);
    - submodules that store the rate as a plain float attribute named ``dropout``. This is
      the pattern of layers that call ``torch.nn.functional.dropout(..., p=self.dropout)``,
      and presumably how the HF DETR encoder/decoder layers are written; verify against
      the installed transformers version.

    Returns:
        A list of ``(submodule, attribute_name, old_value)`` tuples for
        ``restore_dropout_rates``.
    """
    saved = []
    for sub in module.modules():
        if isinstance(sub, torch.nn.Dropout):
            saved.append((sub, "p", sub.p))
            sub.p = p
        elif isinstance(getattr(sub, "dropout", None), float):
            saved.append((sub, "dropout", sub.dropout))
            sub.dropout = p
    return saved


def restore_dropout_rates(saved):
    """Undo ``override_dropout_rates`` by writing back the recorded original values."""
    for sub, attr_name, old_value in reversed(saved):
        setattr(sub, attr_name, old_value)


def mc_dropout_uncertainty(img, det_idx, keep, pixel_values, n_samples=20, dropout_p=0.1):
    """
    Estimate explanation uncertainty from multiple dropout-on forward passes.

    What it computes:
    - Runs the model ``n_samples`` times with dropout active at rate ``dropout_p`` and
      computes a vanilla Grad-CAM map for the selected detection on each run.
    - Returns the per-pixel mean and standard deviation across those maps, each min-max
      normalized to [0, 1]. High std indicates the model's focus is unstable across
      stochastic perturbations.
    Why this helps:
    - If heatmaps vary a lot, the interpretability output is less reliable. Use this to flag
      detections where explanations may not be trustworthy.
    Practical tips:
    - Increasing n_samples reduces variance in the estimate but increases runtime.
    - The target class is fixed once, from a deterministic eval-mode pass, so every sample
      explains the same class even if dropout would flip the argmax.
    - Dropout rates are overridden and the model is put in train mode only for the duration
      of the call. Rates, hooks and eval mode are restored in ``finally``. Train mode affects
      every submodule, not only dropout; this is presumably harmless for DETR (frozen backbone
      BatchNorm). Verify if the model is swapped.

    Args:
        img: PIL image being explained; its size sets the output resolution.
        det_idx: index of the detection among the queries selected by ``keep``.
        keep: boolean mask over DETR queries (detections above the confidence threshold).
        pixel_values: preprocessed input tensor for ``model``.
        n_samples: number of stochastic forward/backward passes.
        dropout_p: dropout rate applied to every dropout site during sampling, in [0, 1].

    Returns:
        (mean_map, std_map): numpy arrays of shape (img.height, img.width).

    Raises:
        ValueError if ``dropout_p`` is outside [0, 1]; other errors are logged and re-raised.
    """
    logger.info(f"Running MC Dropout uncertainty: samples={n_samples}, p={dropout_p}, detection={det_idx}")
    out_shape = (img.size[1], img.size[0])
    activations, gradients = {}, {}

    def save_activation(module, inputs, output):
        activations["v"] = output.detach()

    def save_gradient(module, grad_inputs, grad_outputs):
        gradients["v"] = grad_outputs[0].detach()

    handles, saved_rates = [], []
    try:
        if not 0.0 <= dropout_p <= 1.0:
            raise ValueError(f"dropout_p must be in [0, 1], got {dropout_p}")
        query_idx = keep.nonzero()[det_idx].item()
        # Deterministic pass to pick the target class; DETR's trailing "no-object" logit is
        # excluded so the class matches the one shown in the detection list.
        model.eval()
        with torch.no_grad():
            label_id = model(pixel_values).logits[0, query_idx, :-1].argmax().item()

        saved_rates = override_dropout_rates(model, float(dropout_p))
        model.train()  # activates dropout in every submodule
        conv_layer = model.model.backbone.conv_encoder.model.layer4[-1]
        # Hooks must be registered BEFORE the forward pass; otherwise neither fires.
        handles.append(conv_layer.register_forward_hook(save_activation))
        if hasattr(conv_layer, "register_full_backward_hook"):
            handles.append(conv_layer.register_full_backward_hook(save_gradient))
        else:  # older PyTorch without full backward hooks
            handles.append(conv_layer.register_backward_hook(save_gradient))

        cams = []
        for sample in range(n_samples):
            # Clear so a missing capture is detected instead of reusing the last sample.
            activations.clear()
            gradients.clear()
            with torch.enable_grad():
                score = model(pixel_values).logits[0, query_idx, label_id]
                model.zero_grad()
                score.backward()
            if "v" not in activations or "v" not in gradients:
                logger.warning(f"No activations captured in sample {sample}, using fallback zero map")
                cams.append(np.zeros(out_shape))
                continue
            act = activations["v"].squeeze(0)
            grad = gradients["v"].squeeze(0)
            weights = grad.mean(dim=(1, 2))
            cam = torch.relu((weights[:, None, None] * act).sum(0))
            cam = cam / (cam.max() + 1e-8)
            cams.append(resize(cam.unsqueeze(0).unsqueeze(0), img.size[::-1])[0, 0].cpu().numpy())

        if not cams:  # only when n_samples < 1
            logger.error("No valid CAM maps generated")
            return np.zeros(out_shape), np.zeros(out_shape)
        cams_arr = np.stack(cams, axis=0)
        mean_map = cams_arr.mean(0)
        std_map = cams_arr.std(0)
        mean_map = (mean_map - mean_map.min()) / (mean_map.max() - mean_map.min() + 1e-8)
        std_map = (std_map - std_map.min()) / (std_map.max() - std_map.min() + 1e-8)
        logger.info("MC Dropout uncertainty completed")
        return mean_map, std_map
    except Exception as e:
        logger.error(f"Error in mc_dropout_uncertainty: {str(e)}", exc_info=True)
        raise
    finally:
        for handle in handles:
            handle.remove()
        restore_dropout_rates(saved_rates)
        model.eval()
# ---------- Full pipeline ----------
def parse_det_index(det_choice):
    """
    Extract the detection index from a dropdown entry formatted as "index: label (score)".

    Returns 0 when ``det_choice`` is None or its prefix before ":" is not an integer.
    """
    if det_choice is None:
        return 0
    try:
        return int(str(det_choice).split(":")[0])
    except ValueError:
        return 0


def interpret(img, det_choice, conf_thresh, cam_variant, mc_samples, dropout_p):
    """
    Detect objects in ``img`` and explain the selected detection.

    Args:
        img: PIL image from the upload widget (None when nothing is uploaded).
        det_choice: dropdown entry "index: label (score)"; the index counts only
            detections whose score exceeds ``conf_thresh``.
        conf_thresh: score threshold used to filter detections.
        cam_variant: "Grad-CAM" or "Grad-CAM++".
        mc_samples: number of MC Dropout samples.
        dropout_p: dropout rate passed to mc_dropout_uncertainty.

    Returns:
        (cam_figure, cam_text, uncertainty_figure, uncertainty_text). When nothing can be
        explained, both figures are None and ``cam_text`` carries the reason.
    """
    logger.info(f"Starting interpretation - detection: {det_choice}, threshold: {conf_thresh}, cam: {cam_variant}, mc_samples: {mc_samples}, dropout_p: {dropout_p}")
    if img is None:
        logger.warning("Explain requested without an image")
        return None, "Please upload an image first.", None, ""
    try:
        inputs = extractor(images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Separate leaf tensor the gradient-based explainers can back-propagate into.
        pixel_values_attr = inputs["pixel_values"].clone().requires_grad_(True)
        target_sizes = [img.size[::-1]]
        # threshold=0.0 keeps results aligned with the model's queries (assumes every score
        # is > 0); the explainers index model logits through `keep`.
        results = extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)[0]
        keep = results["scores"] > conf_thresh
        labels = results["labels"][keep]
        logger.info(f"Found {len(labels)} detections above threshold {conf_thresh}")
        if len(labels) == 0:
            logger.warning("No detections found above threshold")
            return None, "No detections above threshold", None, ""

        det_idx = parse_det_index(det_choice)
        if not 0 <= det_idx < len(labels):
            # The dropdown can be stale, e.g. populated for a different threshold.
            logger.warning(f"Detection {det_idx} out of range for {len(labels)} detections")
            return (None,
                    f"Detection {det_idx} is not available at threshold {conf_thresh}; re-run detection and pick again.",
                    None, "")
        label_id = labels[det_idx].item()
        label = model.config.id2label.get(label_id, f"Class {label_id}")
        logger.info(f"Selected detection {det_idx}: {label}")

        # Grad-CAM / Grad-CAM++ (single deterministic pass)
        cam = gradcam(img, det_idx, keep, pixel_values_attr, use_pp=(cam_variant == "Grad-CAM++"))
        fig1, ax1 = plt.subplots()
        ax1.imshow(img)
        ax1.imshow(cam, cmap="jet", alpha=0.5)
        ax1.axis("off")
        ax1.set_title(f"{cam_variant}: {label}")
        plt.close(fig1)
        logger.debug(f"{cam_variant} visualization created")

        # MC Dropout Uncertainty analysis
        mean_map, std_map = mc_dropout_uncertainty(
            img, det_idx, keep, pixel_values_attr,
            n_samples=int(mc_samples), dropout_p=float(dropout_p),
        )
        # Composite figure: mean map and std map side by side
        fig2, axes = plt.subplots(1, 2, figsize=(8, 4))
        axes[0].imshow(img)
        axes[0].imshow(mean_map, cmap='hot', alpha=0.5)
        axes[0].axis('off')
        axes[0].set_title('Predictive Mean')
        axes[1].imshow(img)
        axes[1].imshow(std_map, cmap='viridis', alpha=0.5)
        axes[1].axis('off')
        axes[1].set_title('Predictive Std (Uncertainty)')
        plt.close(fig2)
        logger.debug("MC Dropout uncertainty visualization created")

        exp1 = f"{cam_variant}:\nGradient-weighted feature maps — highlights where DETR focused."
        exp2 = f"MC Dropout Uncertainty:\nSamples={mc_samples}, dropout={dropout_p}. Shows predictive mean and per-pixel std as uncertainty."
        logger.info("Interpretation completed successfully")
        return fig1, exp1, fig2, exp2
    except Exception as e:
        logger.error(f"Error in interpret function: {str(e)}", exc_info=True)
        return None, f"Error: {str(e)}", None, ""
# ---------- Gradio UI ----------
with gr.Blocks() as demo:
    gr.Markdown("## DETR Interpretability Dashboard with Controls")
    gr.Markdown(
        """
**How to use this dashboard**
- Upload an image using the left panel. The model will run object detection and list detected objects. Try images from [ImageNet](https://www.image-net.org/).
- Use the "Confidence Threshold" slider to filter detections by score. Detections below the threshold are hidden.
- Pick a detection from the dropdown to generate explanations for that object.
- Choose between `Grad-CAM` and `Grad-CAM++` (Grad-CAM++ often gives sharper, more localized maps).
- `MC Dropout Samples` controls how many stochastic forward passes are used to estimate prediction uncertainty. More samples give smoother estimates but take longer.
- `Dropout Probability` sets the dropout rate used during MC Dropout; higher values typically increase predicted uncertainty.

A quick-reference list describing each control is shown below the controls.
"""
    )
    with gr.Row():
        img_in = gr.Image(type="pil", label="Upload an image")
        det_out = gr.Label(label="Detections")
        det_fig = gr.Plot(label="Detections visualization")
    det_choice = gr.Dropdown(label="Pick a detection for explanation")
    with gr.Row():
        conf_thresh = gr.Slider(0, 1, value=0.7, step=0.05, label="Confidence Threshold")
        cam_variant = gr.Radio(["Grad-CAM", "Grad-CAM++"], value="Grad-CAM", label="Grad-CAM Variant")
        mc_samples = gr.Slider(1, 100, value=20, step=1, label="MC Dropout Samples")
        dropout_p = gr.Slider(0.0, 0.9, value=0.1, step=0.05, label="Dropout Probability")
    btn = gr.Button("Explain")
    gc_fig = gr.Plot(label="Grad-CAM / Grad-CAM++")
    gc_txt = gr.Textbox(label="Explanation (Grad-CAM)")
    unc_fig = gr.Plot(label="Uncertainty (MC Dropout)")
    unc_txt = gr.Textbox(label="Explanation (Uncertainty)")
    # Visible quick-reference for the controls
    gr.Markdown(
        """
**Control tooltips (quick reference)**
- Confidence Threshold: Filter out detections with confidence below this value.
- Grad-CAM Variant: Choose the gradient-based visualization method. Grad-CAM++ may highlight smaller regions more precisely.
- MC Dropout Samples: Number of stochastic forward passes for uncertainty estimation. Increase for more stable results.
- Dropout Probability: Dropout rate used during MC Dropout sampling. Higher values typically increase predictive variance.
- Pick a detection: Select which detected object to explain. Format shown as 'index: label (score)'.
"""
    )
    # ---------- Key interpretability choices (Feynman-style) ----------
    gr.Markdown(
        """
**Key interpretability choices & why they matter**
- **Baseline (Integrated Gradients)**: Defines what 'no signal' looks like. Black (zeros) is simple, but blurred or neutral baselines may give more meaningful attributions.
- **Which conv layer for Grad-CAM**: Early layers give fine texture but low semantics; very late layers are coarse. A late backbone conv (default used) is a good compromise.
- **Number of MC Dropout samples**: More samples = smoother, more stable uncertainty estimates, but higher compute cost.
- **Grad-CAM vs Grad-CAM++**: Grad-CAM++ can be sharper and better for overlapping instances; vanilla Grad-CAM is faster and simpler.
"""
    )
    # ---------- Further reading / Feynman-style references ----------
    # Short, clickable references to the original papers and deep-dive articles.
    gr.Markdown(
        """
**Further reading (recommended)**
- [Grad-CAM — Selvaraju et al., 2017 (arXiv)](https://arxiv.org/abs/1610.02391) — the original Grad-CAM paper; explains the core idea of gradient-weighted localization.
- [Grad-CAM++ — Chattopadhay et al.](https://arxiv.org/abs/1710.11063) — an improved variant that often produces sharper maps and handles multiple instances better.
- [Visualizing the Impact of Feature Attribution Baselines (Distill)](https://distill.pub/2020/attribution-baselines) — an accessible deep dive on baseline choices for Integrated Gradients.
- [Captum docs — IntegratedGradients](https://captum.ai/api/integrated_gradients.html) — practical API notes for baseline, n_steps, and convergence delta.
- [Constructing sensible baselines for Integrated Gradients](https://arxiv.org/abs/2004.09627) — discussion and techniques for choosing baselines beyond a black image.
- [A New Baseline Assumption of Integrated Gradients Based on Shapley Values](https://arxiv.org/html/2310.04821v3) — recent research on improved baselines.
"""
    )

    # Helper: safe label getter in case model.config.id2label is missing or not a dict
    def safe_label_lookup(idx):
        try:
            id2label = getattr(model.config, 'id2label', None)
            if id2label is None:
                return f"Class {idx}"
            return id2label.get(int(idx), f"Class {idx}")
        except Exception:
            return f"Class {idx}"

    def run_detect(img, conf_thresh):
        """
        Detect objects in ``img`` and refresh the detection widgets.

        Returns a dict mapping det_out / det_fig / det_choice to: the list of
        "index: label (score)" strings, a figure with boxes drawn, and the dropdown
        choices (first entry preselected). Indices count only detections whose score
        exceeds ``conf_thresh``, the same filtering ``interpret`` applies.
        """
        logger.info(f"Running detection with confidence threshold: {conf_thresh}")
        if img is None:
            # Image cleared (or threshold moved before any upload): reset the widgets.
            return {det_out: None, det_fig: None, det_choice: gr.update(choices=[], value=None)}
        try:
            inputs = extractor(images=img, return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            target_sizes = [img.size[::-1]]
            results = extractor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.0)[0]
            keep = results["scores"] > conf_thresh
            # Convert to Python scalars: matplotlib cannot consume CUDA tensors.
            boxes = results["boxes"][keep].tolist()
            labels = results["labels"][keep].tolist()
            scores = results["scores"][keep].tolist()
            logger.info(f"Detection found {len(labels)} objects above threshold")
            det_list = [f"{i}: {safe_label_lookup(lab)} ({s:.2f})" for i, (lab, s) in enumerate(zip(labels, scores))]
            fig, ax = plt.subplots()
            ax.imshow(img)
            ax.axis("off")
            for (xmin, ymin, xmax, ymax), lab, s in zip(boxes, labels, scores):
                ax.add_patch(patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color="red", lw=2))
                ax.text(xmin, ymin, f"{safe_label_lookup(lab)}:{s:.2f}", color="black",
                        bbox=dict(facecolor="yellow", alpha=0.5))
            plt.close(fig)
            default_val = det_list[0] if det_list else None
            logger.debug("Detection visualization created")
            return {det_out: str(det_list), det_fig: fig, det_choice: gr.update(choices=det_list, value=default_val)}
        except Exception as e:
            logger.error(f"Error in run_detect: {str(e)}", exc_info=True)
            return {det_out: "Error in detection", det_fig: None, det_choice: gr.update(choices=[], value=None)}

    img_in.change(run_detect, inputs=[img_in, conf_thresh], outputs=[det_out, det_fig, det_choice])
    # Re-run detection when the threshold moves, so dropdown indices keep matching the
    # threshold-filtered list that interpret() indexes into.
    conf_thresh.change(run_detect, inputs=[img_in, conf_thresh], outputs=[det_out, det_fig, det_choice])
    btn.click(interpret, inputs=[img_in, det_choice, conf_thresh, cam_variant, mc_samples, dropout_p],
              outputs=[gc_fig, gc_txt, unc_fig, unc_txt])
# Launch only when run as a script, so the module can be imported (e.g. by Gradio's
# reload mode or tests) without starting a server.
if __name__ == "__main__":
    logger.info("Gradio interface configured, launching demo")
    demo.launch()