import math
import types

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, InterpolationMode

import open_clip
from open_clip.transformer import VisionTransformer
from open_clip.timm_model import TimmModel
from einops import rearrange

from .utils import (
    hooked_attention_timm_forward,
    hooked_resblock_forward,
    hooked_attention_forward,
    hooked_resblock_timm_forward,
    hooked_attentional_pooler_timm_forward,
    vit_dynamic_size_forward,
    min_max,
    hooked_torch_multi_head_attention_forward,
)


class LeWrapper(nn.Module):
    """
    Wrapper around an OpenCLIP model that adds LeGrad while keeping all of the
    original model's functionality.
    """

    def __init__(self, model, layer_index=-2):
        super(LeWrapper, self).__init__()
        # ------------ copy of model's attributes and methods ------------
        for attr in dir(model):
            if not attr.startswith("__"):
                setattr(self, attr, getattr(model, attr))

        # ------------ activate hooks & gradients ------------
        self._activate_hooks(layer_index=layer_index)

    def _activate_hooks(self, layer_index):
        # ------------ identify model's type ------------
        print("Activating necessary hooks and gradients ....")
        if isinstance(self.visual, VisionTransformer):
            # --- Activate dynamic image size ---
            self.visual.forward = types.MethodType(
                vit_dynamic_size_forward, self.visual
            )
            # Get patch size
            self.patch_size = self.visual.patch_size[0]
            # Get starting depth (in case of negative layer_index)
            self.starting_depth = (
                layer_index
                if layer_index >= 0
                else len(self.visual.transformer.resblocks) + layer_index
            )
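            # e.g. with 12 resblocks and layer_index=-2, starting_depth = 12 - 2 = 10,
            # so only the last two blocks are hooked and keep requires_grad=True.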
            if self.visual.attn_pool is None:
                self.model_type = "clip"
                self._activate_self_attention_hooks()
            else:
                self.model_type = "coca"
                self._activate_att_pool_hooks(layer_index=layer_index)
        elif isinstance(self.visual, TimmModel):
            # --- Activate dynamic image size ---
            self.visual.trunk.dynamic_img_size = True
            self.visual.trunk.patch_embed.dynamic_img_size = True
            self.visual.trunk.patch_embed.strict_img_size = False
            self.visual.trunk.patch_embed.flatten = False
            self.visual.trunk.patch_embed.output_fmt = "NHWC"
            self.model_type = "timm_siglip"
            # --- Get patch size ---
            self.patch_size = self.visual.trunk.patch_embed.patch_size[0]
            # --- Get starting depth (in case of negative layer_index) ---
            self.starting_depth = (
                layer_index
                if layer_index >= 0
                else len(self.visual.trunk.blocks) + layer_index
            )
            if (
                hasattr(self.visual.trunk, "attn_pool")
                and self.visual.trunk.attn_pool is not None
            ):
                self._activate_timm_attn_pool_hooks(layer_index=layer_index)
            else:
                self._activate_timm_self_attention_hooks()
        else:
            raise ValueError(
                "Model currently not supported, see legrad.list_pretrained() for a list of available models"
            )
        print("Hooks and gradients activated!")

    def _activate_self_attention_hooks(self):
        # Resolve the block list and the matching parameter-name prefix
        if isinstance(self.visual, VisionTransformer):
            blocks = self.visual.transformer.resblocks
            prefix = "visual.transformer.resblocks"
        elif isinstance(self.visual, TimmModel):
            blocks = self.visual.trunk.blocks
            prefix = "visual.trunk.blocks"
        else:
            raise ValueError("Unsupported model type for self-attention hooks")

        # ---------- Apply Hooks + Activate/Deactivate gradients ----------
        # Freeze everything, then re-enable gradients only for the hooked blocks
        # so the intermediate attention maps can be differentiated.
        for name, param in self.named_parameters():
            param.requires_grad = False
            if name.startswith(prefix):
                depth = int(name.split(prefix + ".")[-1].split(".")[0])
                if depth >= self.starting_depth:
                    param.requires_grad = True

        # --- Activate the hooks for the specific layers ---
        for layer in range(self.starting_depth, len(blocks)):
            blocks[layer].attn.forward = types.MethodType(
                hooked_attention_forward, blocks[layer].attn
            )
            blocks[layer].forward = types.MethodType(
                hooked_resblock_forward, blocks[layer]
            )
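
    # Note: the "hooks" above are plain method monkey-patching rather than
    # nn.Module forward hooks: for a module `m` and a replacement function
    # `new_forward(self, ...)`, `m.forward = types.MethodType(new_forward, m)`
    # rebinds `m.forward` so the replacement runs in place of the original and
    # can cache intermediates (attention maps, post-MLP features) on the module
    # for the compute_legrad_* methods to read later.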

    def _activate_timm_self_attention_hooks(self):
        # Adjusting to use the correct structure
        blocks = self.visual.trunk.blocks
        # ---------- Apply Hooks + Activate/Deactivate gradients ----------
        # Necessary steps to get intermediate representations
        for name, param in self.named_parameters():
            param.requires_grad = False
            if name.startswith("visual.trunk.blocks"):
                depth = int(name.split("visual.trunk.blocks.")[-1].split(".")[0])
                if depth >= self.starting_depth:
                    param.requires_grad = True
        # --- Activate the hooks for the specific layers ---
        for layer in range(self.starting_depth, len(blocks)):
            blocks[layer].attn.forward = types.MethodType(
                hooked_attention_timm_forward, blocks[layer].attn
            )
            blocks[layer].forward = types.MethodType(
                hooked_resblock_timm_forward, blocks[layer]
            )

    def _activate_att_pool_hooks(self, layer_index):
        # ---------- Apply Hooks + Activate/Deactivate gradients ----------
        # Necessary steps to get intermediate representations
        for name, param in self.named_parameters():
            param.requires_grad = False
            if name.startswith("visual.transformer.resblocks"):
                # get the depth
                depth = int(
                    name.split("visual.transformer.resblocks.")[-1].split(".")[0]
                )
                if depth >= self.starting_depth:
                    param.requires_grad = True
        # --- Activate the hooks for the specific layers ---
        for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)):
            self.visual.transformer.resblocks[layer].forward = types.MethodType(
                hooked_resblock_forward, self.visual.transformer.resblocks[layer]
            )
        # --- Apply hook on the attentional pooler ---
        self.visual.attn_pool.attn.forward = types.MethodType(
            hooked_torch_multi_head_attention_forward, self.visual.attn_pool.attn
        )

    def _activate_timm_attn_pool_hooks(self, layer_index):
        # Ensure all components are present before attaching hooks
        if (
            not hasattr(self.visual.trunk, "attn_pool")
            or self.visual.trunk.attn_pool is None
        ):
            raise ValueError("Attentional pooling not found in TimmModel")

        # --- Hook the per-block attention (timm blocks use the timm hook) ---
        for block in self.visual.trunk.blocks:
            if hasattr(block, "attn"):
                block.attn.forward = types.MethodType(
                    hooked_attention_timm_forward, block.attn
                )

        # --- Deactivate gradients for modules that don't need them ---
        for name, param in self.named_parameters():
            param.requires_grad = False
            if name.startswith("visual.trunk.attn_pool"):
                param.requires_grad = True
            if name.startswith("visual.trunk.blocks"):
                # get the depth
                depth = int(name.split("visual.trunk.blocks.")[-1].split(".")[0])
                if depth >= self.starting_depth:
                    param.requires_grad = True

        # --- Activate the hooks for the specific layers by modifying the block's forward ---
        for layer in range(self.starting_depth, len(self.visual.trunk.blocks)):
            self.visual.trunk.blocks[layer].forward = types.MethodType(
                hooked_resblock_timm_forward, self.visual.trunk.blocks[layer]
            )
        # --- Apply hook on the attentional pooler ---
        self.visual.trunk.attn_pool.forward = types.MethodType(
            hooked_attentional_pooler_timm_forward, self.visual.trunk.attn_pool
        )
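
    # The hooked forwards installed above are expected to cache intermediates on
    # their modules (e.g. `feat_post_mlp` on each block and `attn_probs` on the
    # attentional pooler); the compute_legrad_* methods below read exactly those
    # attributes instead of re-running the forward pass.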

    def compute_legrad(self, text_embedding, image=None, apply_correction=True):
        if "clip" in self.model_type:
            return self.compute_legrad_clip(text_embedding, image)
        elif "siglip" in self.model_type:
            return self.compute_legrad_siglip(
                text_embedding, image, apply_correction=apply_correction
            )
        elif "coca" in self.model_type:
            return self.compute_legrad_coca(text_embedding, image)
        else:
            raise ValueError(f"Unknown model_type: {self.model_type}")
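
    # Shape conventions used by the methods below: `text_embedding` is an
    # (L2-normalised) [num_prompts, dim] tensor from encode_text, `image` is a
    # preprocessed [batch, 3, H, W] tensor, and the returned heatmap is a
    # min-max-normalised map upsampled back to the input spatial size H x W.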

    def compute_legrad_clip(self, text_embedding, image=None):
        num_prompts = text_embedding.shape[0]
        if image is not None:
            # Ensure the image is passed through the model to get the intermediate features
            _ = self.encode_image(image)

        blocks_list = list(
            dict(self.visual.transformer.resblocks.named_children()).values()
        )

        image_features_list = []
        for layer in range(self.starting_depth, len(blocks_list)):
            # feat_post_mlp is cached by the hooked resblock forward: [num_patch, batch, dim]
            intermediate_feat = blocks_list[layer].feat_post_mlp
            intermediate_feat = intermediate_feat.permute(1, 0, 2)  # [batch, num_patch, dim]
            # Mean over the patch tokens, then project to the shared embedding space
            intermediate_feat = intermediate_feat.mean(dim=1)
            intermediate_feat = self.visual.ln_post(intermediate_feat) @ self.visual.proj
            intermediate_feat = F.normalize(intermediate_feat, dim=-1)
            image_features_list.append(intermediate_feat)

        num_tokens = blocks_list[-1].feat_post_mlp.shape[0] - 1  # exclude the [CLS] token
        w = h = int(math.sqrt(num_tokens))

        # ----- Get explainability map: for each hooked layer, take the positive
        # gradient of the text-image similarity w.r.t. that layer's attention map,
        # average it over heads and query tokens, and accumulate over layers.
        accum_expl_map = 0
        for blk, img_feat in zip(
            blocks_list[self.starting_depth :], image_features_list
        ):
            self.visual.zero_grad()
            sim = text_embedding @ img_feat.transpose(-1, -2)  # [num_prompts, batch]
            one_hot = (
                F.one_hot(torch.arange(0, num_prompts))
                .float()
                .requires_grad_(True)
                .to(text_embedding.device)
            )
            one_hot = torch.sum(one_hot * sim)
            # attention map cached by the hooked attention forward: [batch, num_heads, N, N]
            attn_map = blk.attn.attention_map

            # -------- Get explainability map --------
            grad = torch.autograd.grad(
                one_hot, [attn_map], retain_graph=True, create_graph=True
            )[0]  # same shape as attn_map
            grad = torch.clamp(grad, min=0.0)

            # average over heads and query tokens, then drop the [CLS] column
            image_relevance = grad.mean(dim=1).mean(dim=1)[:, 1:]
            expl_map = rearrange(image_relevance, "b (w h) -> 1 b w h", w=w, h=h)
            expl_map = F.interpolate(
                expl_map, scale_factor=self.patch_size, mode="bilinear"
            )  # upsample patch-level relevance to pixel resolution
            accum_expl_map += expl_map

        # Min-Max Norm
        accum_expl_map = min_max(accum_expl_map)
        return accum_expl_map

    def compute_legrad_coca(self, text_embedding, image=None):
        if image is not None:
            _ = self.encode_image(image)

        blocks_list = list(
            dict(self.visual.transformer.resblocks.named_children()).values()
        )

        image_features_list = []
        for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)):
            intermediate_feat = self.visual.transformer.resblocks[
                layer
            ].feat_post_mlp  # [num_patch, batch, dim]
            intermediate_feat = intermediate_feat.permute(1, 0, 2)  # [batch, num_patch, dim]
            image_features_list.append(intermediate_feat)

        num_tokens = blocks_list[-1].feat_post_mlp.shape[0] - 1
        w = h = int(math.sqrt(num_tokens))

        # ----- Get explainability map
        accum_expl_map = 0
        for layer, (blk, img_feat) in enumerate(
            zip(blocks_list[self.starting_depth :], image_features_list)
        ):
            self.visual.zero_grad()
            # --- Apply attn_pool; keep only the first pooled token, as it is the
            # only one trained with the contrastive loss ---
            image_embedding = self.visual.attn_pool(img_feat)[:, 0]
            image_embedding = image_embedding @ self.visual.proj
            sim = text_embedding @ image_embedding.transpose(-1, -2)  # [num_prompts, batch]
            one_hot = torch.sum(sim)
            # [num_heads, num_latent, num_patch]
            attn_map = self.visual.attn_pool.attn.attention_maps

            # -------- Get explainability map --------
            grad = torch.autograd.grad(
                one_hot, [attn_map], retain_graph=True, create_graph=True
            )[0]  # [num_heads, num_latent, num_patch]
            grad = torch.clamp(grad, min=0.0)

            # average over heads, take the first latent query's row, drop the [CLS] column
            image_relevance = grad.mean(dim=0)[0, 1:]
            expl_map = rearrange(image_relevance, "(w h) -> 1 1 w h", w=w, h=h)
            expl_map = F.interpolate(
                expl_map, scale_factor=self.patch_size, mode="bilinear"
            )  # [1, 1, H, W]
            accum_expl_map += expl_map

        # Min-Max Norm
        accum_expl_map = (accum_expl_map - accum_expl_map.min()) / (
            accum_expl_map.max() - accum_expl_map.min()
        )
        return accum_expl_map

    def _init_empty_embedding(self):
        if not hasattr(self, "empty_embedding"):
            # Only SigLIP models use this path at the moment, and they all share
            # the same tokenizer.
            _tok = open_clip.get_tokenizer(model_name="ViT-B-16-SigLIP")
            empty_text = _tok(["a photo of a"]).to(self.logit_scale.data.device)
            empty_embedding = self.encode_text(empty_text)
            empty_embedding = F.normalize(empty_embedding, dim=-1)
            self.empty_embedding = empty_embedding.t()
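
    # The cached `empty_embedding` feeds the correction step in compute_legrad_siglip:
    # a heatmap is also computed for the generic "a photo of a" prompt, and patches
    # whose (min-max normalised) empty-prompt relevance exceeds `correction_threshold`
    # are zeroed out in the final map.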

    def compute_legrad_siglip(
        self,
        text_embedding,
        image=None,
        apply_correction=True,
        correction_threshold=0.8,
    ):
        # --- Forward CLIP ---
        blocks_list = list(dict(self.visual.trunk.blocks.named_children()).values())
        if image is not None:
            _ = self.encode_image(image)  # [bs, num_patch, dim], bs = num_masks

        image_features_list = []
        for blk in blocks_list[self.starting_depth :]:
            intermediate_feat = blk.feat_post_mlp
            image_features_list.append(intermediate_feat)

        num_tokens = blocks_list[-1].feat_post_mlp.shape[1]
        w = h = int(math.sqrt(num_tokens))

        if apply_correction:
            self._init_empty_embedding()
            accum_expl_map_empty = 0

        accum_expl_map = 0
        for layer, (blk, img_feat) in enumerate(
            zip(blocks_list[self.starting_depth :], image_features_list)
        ):
            self.zero_grad()
            pooled_feat = self.visual.trunk.attn_pool(img_feat)
            pooled_feat = F.normalize(pooled_feat, dim=-1)

            # -------- Get explainability map --------
            sim = text_embedding @ pooled_feat.transpose(-1, -2)  # [num_mask, num_mask]
            one_hot = torch.sum(sim)
            grad = torch.autograd.grad(
                one_hot,
                [self.visual.trunk.attn_pool.attn_probs],
                retain_graph=True,
                create_graph=True,
            )[0]
            grad = torch.clamp(grad, min=0.0)
            # average over heads, then select the pooling query's row
            image_relevance = grad.mean(dim=1)[:, 0]
            expl_map = rearrange(image_relevance, "b (w h) -> b 1 w h", w=w, h=h)
            accum_expl_map += expl_map

            if apply_correction:
                # -------- Get empty explainability map --------
                sim_empty = pooled_feat @ self.empty_embedding
                one_hot_empty = torch.sum(sim_empty)
                grad_empty = torch.autograd.grad(
                    one_hot_empty,
                    [self.visual.trunk.attn_pool.attn_probs],
                    retain_graph=True,
                    create_graph=True,
                )[0]
                grad_empty = torch.clamp(grad_empty, min=0.0)
                # average over heads, then select the pooling query's row
                image_relevance_empty = grad_empty.mean(dim=1)[:, 0]
                expl_map_empty = rearrange(
                    image_relevance_empty, "b (w h) -> b 1 w h", w=w, h=h
                )
                accum_expl_map_empty += expl_map_empty

        if apply_correction:
            # Zero out patches that respond strongly to the generic "empty" prompt
            heatmap_empty = min_max(accum_expl_map_empty)
            accum_expl_map[heatmap_empty > correction_threshold] = 0

        Res = min_max(accum_expl_map)
        Res = F.interpolate(
            Res, scale_factor=self.patch_size, mode="bilinear"
        )  # [B, 1, H, W]
        return Res


class LePreprocess(nn.Module):
    """
    Modify OpenCLIP preprocessing to accept an arbitrary image size.
    """

    def __init__(self, preprocess, image_size):
        super(LePreprocess, self).__init__()
        self.transform = Compose(
            [
                Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                # reuse the last three transforms of the original OpenCLIP pipeline
                # (typically convert-to-RGB, ToTensor and Normalize)
                preprocess.transforms[-3],
                preprocess.transforms[-2],
                preprocess.transforms[-1],
            ]
        )

    def forward(self, image):
        return self.transform(image)
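

if __name__ == "__main__":
    # Minimal end-to-end smoke test (a sketch, not part of the library API).
    # Assumptions: this file is run as a module (e.g. `python -m <package>.<this_module>`)
    # because of the relative import above, network access is available to download
    # pretrained weights, and "ViT-B-16" / "laion2b_s34b_b88k" are just example names
    # from the open_clip catalogue.
    from PIL import Image

    model_name, pretrained = "ViT-B-16", "laion2b_s34b_b88k"
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_name, pretrained=pretrained
    )
    tokenizer = open_clip.get_tokenizer(model_name)

    model = LeWrapper(model)
    preprocess = LePreprocess(preprocess=preprocess, image_size=448)

    # Dummy black image; replace with Image.open(<path>) for a real query.
    image = preprocess(Image.new("RGB", (448, 448))).unsqueeze(0)
    text = tokenizer(["a photo of a cat"])
    text_embedding = F.normalize(model.encode_text(text), dim=-1)

    heatmap = model.compute_legrad(text_embedding, image=image)
    print(heatmap.shape)  # e.g. torch.Size([1, 1, 448, 448])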