Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (C) 2025 NVIDIA Corporation. All rights reserved. | |
# | |
# This work is licensed under the LICENSE file | |
# located at the root directory. | |
import torch | |
from skimage import filters | |
import cv2 | |
import torch.nn.functional as F | |
from skimage.filters import threshold_li, threshold_yen, threshold_multiotsu | |
import numpy as np | |
from visualization_utils import show_tensors | |
import matplotlib.pyplot as plt | |
def text_to_tokens(text, tokenizer): | |
return [tokenizer.decode(x) for x in tokenizer(text, padding="longest", return_tensors="pt").input_ids[0]] | |
def flatten_list(l): | |
return [item for sublist in l for item in sublist] | |
def gaussian_blur(heatmap, kernel_size=7, sigma=0): | |
# Shape of heatmap: (H, W) | |
heatmap = heatmap.cpu().numpy() | |
heatmap = cv2.GaussianBlur(heatmap, (kernel_size, kernel_size), sigma) | |
heatmap = torch.tensor(heatmap) | |
return heatmap | |
def min_max_norm(x): | |
return (x - x.min()) / (x.max() - x.min()) | |
class AttentionStore: | |
def __init__(self, prompts, tokenizer, | |
subject_token=None, record_attention_steps=[], | |
is_cache_attn_ratio=False, attn_ratios_steps=[5]): | |
self.text2image_store = {} | |
self.image2text_store = {} | |
self.count_per_layer = {} | |
self.record_attention_steps = record_attention_steps | |
self.record_attention_layers = ["transformer_blocks.13","transformer_blocks.14", "transformer_blocks.18", "single_transformer_blocks.23", "single_transformer_blocks.33"] | |
self.attention_ratios = {} | |
self._is_cache_attn_ratio = is_cache_attn_ratio | |
self.attn_ratios_steps = attn_ratios_steps | |
self.ratio_source = 'text' | |
self.max_tokens_to_record = 10 | |
if isinstance(prompts, str): | |
prompts = [prompts] | |
batch_size = 1 | |
else: | |
batch_size = len(prompts) | |
tokens_per_prompt = [] | |
for prompt in prompts: | |
tokens = text_to_tokens(prompt, tokenizer) | |
tokens_per_prompt.append(tokens) | |
self.tokens_to_record = [] | |
self.token_idxs_to_record = [] | |
if len(record_attention_steps) > 0: | |
self.subject_tokens = flatten_list([text_to_tokens(x, tokenizer)[:-1] for x in [subject_token]]) | |
self.subject_tokens_idx = [tokens_per_prompt[1].index(x) for x in self.subject_tokens] | |
self.add_token_idx = self.subject_tokens_idx[-1] | |
def is_record_attention(self, layer_name, step_index): | |
is_correct_layer = (self.record_attention_layers is None) or (layer_name in self.record_attention_layers) | |
record_attention = (step_index in self.record_attention_steps) and (is_correct_layer) | |
return record_attention | |
def store_attention(self, attention_probs, layer_name, batch_size, num_heads): | |
text_len = 512 | |
timesteps = len(self.record_attention_steps) | |
# Split batch and heads | |
attention_probs = attention_probs.view(batch_size, num_heads, *attention_probs.shape[1:]) | |
# Mean over the heads | |
attention_probs = attention_probs.mean(dim=1) | |
# Attention: text -> image | |
attention_probs_text2image = attention_probs[:, :text_len, text_len:] | |
attention_probs_text2image = [attention_probs_text2image[0, self.subject_tokens_idx, :]] | |
# Attention: image -> text | |
attention_probs_image2text = attention_probs[:, text_len:, :text_len].transpose(1,2) | |
attention_probs_image2text = [attention_probs_image2text[0, self.subject_tokens_idx, :]] | |
if layer_name not in self.text2image_store: | |
self.text2image_store[layer_name] = [x for x in attention_probs_text2image] | |
self.image2text_store[layer_name] = [x for x in attention_probs_image2text] | |
else: | |
self.text2image_store[layer_name] = [self.text2image_store[layer_name][i] + x for i, x in enumerate(attention_probs_text2image)] | |
self.image2text_store[layer_name] = [self.text2image_store[layer_name][i] + x for i, x in enumerate(attention_probs_image2text)] | |
def is_cache_attn_ratio(self, step_index): | |
return (self._is_cache_attn_ratio) and (step_index in self.attn_ratios_steps) | |
def store_attention_ratios(self, attention_probs, step_index, layer_name): | |
layer_prefix = layer_name.split(".")[0] | |
if self.ratio_source == 'pixels': | |
extended_attention_probs = attention_probs.mean(dim=0)[512:, :] | |
extended_attention_probs_source = extended_attention_probs[:,:4096].sum(dim=1).view(64,64).float().cpu() | |
extended_attention_probs_text = extended_attention_probs[:,4096:4096+512].sum(dim=1).view(64,64).float().cpu() | |
extended_attention_probs_target = extended_attention_probs[:,4096+512:].sum(dim=1).view(64,64).float().cpu() | |
token_attention = extended_attention_probs[:,4096+self.add_token_idx].view(64,64).float().cpu() | |
stacked_attention_ratios = torch.cat([extended_attention_probs_source, extended_attention_probs_text, extended_attention_probs_target, token_attention], dim=1) | |
elif self.ratio_source == 'text': | |
extended_attention_probs = attention_probs.mean(dim=0)[:512, :] | |
extended_attention_probs_source = extended_attention_probs[:,:4096].sum(dim=0).view(64,64).float().cpu() | |
extended_attention_probs_target = extended_attention_probs[:,4096+512:].sum(dim=0).view(64,64).float().cpu() | |
stacked_attention_ratios = torch.cat([extended_attention_probs_source, extended_attention_probs_target], dim=1) | |
if step_index not in self.attention_ratios: | |
self.attention_ratios[step_index] = {} | |
if layer_prefix not in self.attention_ratios[step_index]: | |
self.attention_ratios[step_index][layer_prefix] = [] | |
self.attention_ratios[step_index][layer_prefix].append(stacked_attention_ratios) | |
def get_attention_ratios(self, step_indices=None, display_imgs=False): | |
ratios = [] | |
if step_indices is None: | |
step_indices = list(self.attention_ratios.keys()) | |
if len(step_indices) == 1: | |
steps = f"Step: {step_indices[0]}" | |
else: | |
steps = f"Steps: [{step_indices[0]}-{step_indices[-1]}]" | |
layer_prefixes = list(self.attention_ratios[step_indices[0]].keys()) | |
scores_per_layer = {} | |
for layer_prefix in layer_prefixes: | |
ratios = [] | |
for step_index in step_indices: | |
if layer_prefix in self.attention_ratios[step_index]: | |
step_ratios = self.attention_ratios[step_index][layer_prefix] | |
step_ratios = torch.stack(step_ratios).mean(dim=0) | |
ratios.append(step_ratios) | |
# Mean over the steps | |
ratios = torch.stack(ratios).mean(dim=0) | |
if self.ratio_source == 'pixels': | |
source, text, target, token = torch.split(ratios, 64, dim=1) | |
title = f"{steps}: Source={source.sum().item():.2f}, Text={text.sum().item():.2f}, Target={target.sum().item():.2f}, Token={token.sum().item():.2f}" | |
ratios = min_max_norm(torch.cat([source, text, target], dim=1)) | |
token = min_max_norm(token) | |
ratios = torch.cat([ratios, token], dim=1) | |
elif self.ratio_source == 'text': | |
source, target = torch.split(ratios, 64, dim=1) | |
source_sum = source.sum().item() | |
target_sum = target.sum().item() | |
text_sum = 512 - (source_sum + target_sum) | |
title = f"{steps}: Source={source_sum:.2f}, Target={target_sum:.2f}" | |
ratios = min_max_norm(torch.cat([source, target], dim=1)) | |
if display_imgs: | |
print(f"Layer: {layer_prefix}") | |
show_tensors([ratios], [title]) | |
scores_per_layer[layer_prefix] = (source_sum, text_sum, target_sum) | |
return scores_per_layer | |
def plot_attention_ratios(self, step_indices=None): | |
steps = list(self.attention_ratios.keys()) | |
score_per_layer = { | |
'transformer_blocks': {}, | |
'single_transformer_blocks': {} | |
} | |
for i in steps: | |
scores_per_layer = self.get_attention_ratios(step_indices=[i], display_imgs=False) | |
for layer in self.attention_ratios[i]: | |
source, text, target = scores_per_layer[layer] | |
score_per_layer[layer][i] = (source, text, target) | |
for layer_type in score_per_layer: | |
x = list(score_per_layer[layer_type].keys()) | |
source_sums = [x[0] for x in score_per_layer[layer_type].values()] | |
text_sums = [x[1] for x in score_per_layer[layer_type].values()] | |
target_sums = [x[2] for x in score_per_layer[layer_type].values()] | |
# Calculate the total sums for each stack (source + text + target) | |
total_sums = [source_sums[j] + text_sums[j] + target_sums[j] for j in range(len(source_sums))] | |
# Create stacked bar plots | |
fig, ax = plt.subplots(figsize=(10, 6)) | |
indices = np.arange(len(x)) | |
# Plot source at the bottom | |
ax.bar(indices, source_sums, label='Source', color='#6A2C70') | |
# Plot text stacked on source | |
ax.bar(indices, text_sums, label='Text', color='#B83B5E', bottom=source_sums) | |
# Plot target stacked on text + source | |
target_bottom = [source_sums[j] + text_sums[j] for j in range(len(source_sums))] | |
ax.bar(indices, target_sums, label='Target', color='#F08A5D', bottom=target_bottom) | |
# Annotate bars with percentage values | |
for j, index in enumerate(indices): | |
font_size = 12 | |
# Source percentage | |
source_percentage = 100 * source_sums[j] / total_sums[j] | |
ax.text(index, source_sums[j] / 2, f'{source_percentage:.1f}%', | |
ha='center', va='center', rotation=90, color='white', | |
fontsize=font_size, fontweight='bold') | |
# Text percentage | |
text_percentage = 100 * text_sums[j] / total_sums[j] | |
ax.text(index, source_sums[j] + (text_sums[j] / 2), f'{text_percentage:.1f}%', | |
ha='center', va='center', rotation=90, color='white', | |
fontsize=font_size, fontweight='bold') | |
# Target percentage | |
target_percentage = 100 * target_sums[j] / total_sums[j] | |
ax.text(index, source_sums[j] + text_sums[j] + (target_sums[j] / 2), f'{target_percentage:.1f}%', | |
ha='center', va='center', rotation=90, color='white', | |
fontsize=font_size, fontweight='bold') | |
ax.set_xlabel('Step Index') | |
ax.set_ylabel('Attention Ratio') | |
ax.set_title(f'Attention Ratios for {layer_type}') | |
ax.set_xticks(indices) | |
ax.set_xticklabels(x) | |
plt.legend() | |
plt.show() | |
def aggregate_attention(self, store, target_layers=None, resolution=None, | |
gaussian_kernel=3, thr_type='otsu', thr_number=0.5): | |
if target_layers is None: | |
store_vals = list(store.values()) | |
elif isinstance(target_layers, list): | |
store_vals = [store[x] for x in target_layers] | |
else: | |
raise ValueError("target_layers must be a list of layer names or None.") | |
# store vals = List[layers] of Tensor[batch_size, text_tokens, image_tokens] | |
batch_size = len(store_vals[0]) | |
attention_maps = [] | |
attention_masks = [] | |
for i in range(batch_size): | |
# Average over the layers | |
agg_vals = torch.stack([x[i] for x in store_vals]).mean(dim=0) | |
if resolution is None: | |
size = int(agg_vals.shape[-1] ** 0.5) | |
resolution = (size, size) | |
agg_vals = agg_vals.view(agg_vals.shape[0], *resolution) | |
if gaussian_kernel > 0: | |
agg_vals = torch.stack([gaussian_blur(x.float(), kernel_size=gaussian_kernel) for x in agg_vals]).to(agg_vals.dtype) | |
mask_vals = agg_vals.clone() | |
for j in range(mask_vals.shape[0]): | |
mask_vals[j] = (mask_vals[j] - mask_vals[j].min()) / (mask_vals[j].max() - mask_vals[j].min()) | |
np_vals = mask_vals[j].float().cpu().numpy() | |
otsu_thr = filters.threshold_otsu(np_vals) | |
li_thr = threshold_li(np_vals, initial_guess=otsu_thr) | |
yen_thr = threshold_yen(np_vals) | |
if thr_type == 'otsu': | |
thr = otsu_thr | |
elif thr_type == 'yen': | |
thr = yen_thr | |
elif thr_type == 'li': | |
thr = li_thr | |
elif thr_type == 'number': | |
thr = thr_number | |
elif thr_type == 'multiotsu': | |
thrs = threshold_multiotsu(np_vals, classes=3) | |
if thrs[1] > thrs[0] * 3.5: | |
thr = thrs[1] | |
else: | |
thr = thrs[0] | |
# Take the closest threshold to otsu_thr | |
# thr = thrs[np.argmin(np.abs(thrs - otsu_thr))] | |
# alpha = 0.8 | |
# thr = (alpha * thr + (1-alpha) * mask_vals[j].max()) | |
mask_vals[j] = (mask_vals[j] > thr).to(mask_vals[j].dtype) | |
attention_maps.append(agg_vals) | |
attention_masks.append(mask_vals) | |
return attention_maps, attention_masks, self.tokens_to_record | |