Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from torch.nn import GroupNorm, LayerNorm | |
import torch.nn.functional as F | |
import torch.utils.checkpoint as checkpoint | |
import timm | |
class ViTWrapper(nn.Module):
    """Wrapper to make ViT compatible with feature extraction for ImageTagger.

    ``forward`` returns ``(patch_features, cls_token)``:
      * ``patch_features``: patch tokens laid back out as ``(B, C, side, side)``,
      * ``cls_token``: the class token, ``(B, C)``,
    both taken after the model's final ``norm``.

    NOTE(review): the token pipeline is re-implemented here rather than calling
    the timm model's own ``forward_features``. Consequences:
      * ``forward`` runs ``vit.blocks`` in a plain loop and never reads a
        gradient-checkpointing flag, so ``set_grad_checkpointing`` does not
        change this forward; checkpoint the wrapper as a whole instead.
      * patch dropout is deliberately not applied, since dropping patches would
        make the patch grid non-square.
      * the positional embedding is assumed to cover the class token and no
        register tokens are prepended — TODO confirm for non-standard ViTs.
    """

    def __init__(self, vit_model):
        super().__init__()
        self.vit = vit_model
        self.out_indices = (-1,)  # mimic timm.features_only
        # Patch size (first spatial dim) and embedding width from the model.
        self.patch_size = vit_model.patch_embed.patch_size[0]
        self.embed_dim = vit_model.embed_dim

    def forward(self, x):
        """
        Run the ViT token pipeline and split CLS from the patch tokens.

        Args:
            x: image batch forwarded to ``vit.patch_embed``.

        Returns:
            Tuple ``(patch_features, cls_final)`` with shapes
            ``(B, C, side, side)`` and ``(B, C)``.

        Raises:
            ValueError: if the model has no class token, or if the number of
                patch tokens is not a perfect square.
        """
        if getattr(self.vit, 'cls_token', None) is None:
            raise ValueError(
                "ViTWrapper requires a ViT with a class token (cls_token is None)"
            )

        B = x.size(0)
        # Patch tokens: (B, N, C)
        x = self.vit.patch_embed(x)
        # Prepend the learned CLS token -> (B, 1+N, C)
        cls_tok = self.vit.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tok, x), dim=1)
        # Positional encodings for the full sequence (CLS included)
        if self.vit.pos_embed is not None:
            x = x + self.vit.pos_embed[:, : x.size(1), :]
        x = self.vit.pos_drop(x)
        # timm applies `norm_pre` between the embeddings and the blocks;
        # skipping it changes the features of pre-norm models.
        norm_pre = getattr(self.vit, 'norm_pre', None)
        if norm_pre is not None:
            x = norm_pre(x)
        for blk in self.vit.blocks:
            x = blk(x)
        x = self.vit.norm(x)  # (B, 1+N, C)

        # Split CLS from the patch tokens.
        cls_final = x[:, 0]  # (B, C)
        patch_tokens = x[:, 1:]  # (B, N, C)

        # Lay the patches back out on a square grid -> (B, C, side, side)
        _, N, C = patch_tokens.shape
        side = int(round(N ** 0.5))
        if side * side != N:
            raise ValueError(
                f"ViTWrapper expects a square patch grid, got {N} patch tokens"
            )
        patch_features = patch_tokens.transpose(1, 2).reshape(B, C, side, side)
        return patch_features, cls_final

    def set_grad_checkpointing(self, enable=True):
        """
        Forward a gradient-checkpointing request to the wrapped model.

        Returns True when the wrapped model exposes ``set_grad_checkpointing``
        (and the call was forwarded), False otherwise. See the class docstring:
        this wrapper's own ``forward`` does not consult that flag.
        """
        if hasattr(self.vit, 'set_grad_checkpointing'):
            self.vit.set_grad_checkpointing(enable)
            return True
        return False
class ImageTagger(nn.Module): | |
""" | |
ImageTagger with Vision Transformer backbone | |
""" | |
def __init__(self, total_tags, dataset, model_name='vit_base_patch16_224', | |
num_heads=16, dropout=0.1, pretrained=True, tag_context_size=256, | |
use_gradient_checkpointing=False, img_size=224): | |
super().__init__() | |
# Store checkpointing config | |
self.use_gradient_checkpointing = use_gradient_checkpointing | |
self.model_name = model_name | |
self.img_size = img_size | |
# Debug and stats flags | |
self._flags = { | |
'debug': False, | |
'model_stats': True | |
} | |
# Core model config | |
self.dataset = dataset | |
self.tag_context_size = tag_context_size | |
self.total_tags = total_tags | |
print(f"ποΈ Building ImageTagger with ViT backbone and {total_tags} tags") | |
print(f" Backbone: {model_name}") | |
print(f" Image size: {img_size}x{img_size}") | |
print(f" Tag context size: {tag_context_size}") | |
print(f" Gradient checkpointing: {use_gradient_checkpointing}") | |
print(f" π― Custom embeddings, PyTorch native attention, no ground truth inclusion") | |
# 1. Vision Transformer Backbone | |
print("π¦ Loading Vision Transformer backbone...") | |
self._load_vit_backbone() | |
# Get backbone dimensions by running a test forward pass | |
self._determine_backbone_dimensions() | |
self.embedding_dim = self.backbone.embed_dim | |
# 2. Custom Tag Embeddings (no CLIP) | |
print("π― Using custom tag embeddings (no CLIP)") | |
self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim) | |
# 3. Shared weights approach - tag bias for initial predictions | |
print("π Using shared weights between initial head and tag embeddings") | |
self.tag_bias = nn.Parameter(torch.zeros(total_tags)) | |
# 4. Image token extraction (for attention AND global pooling) | |
self.image_token_proj = nn.Identity() | |
# 5. Tags-as-queries cross-attention (using PyTorch's optimized implementation) | |
self.cross_attention = nn.MultiheadAttention( | |
embed_dim=self.embedding_dim, | |
num_heads=num_heads, | |
dropout=dropout, | |
batch_first=True # Use (batch, seq, feature) format | |
) | |
self.cross_norm = nn.LayerNorm(self.embedding_dim) | |
# Initialize weights | |
self._init_weights() | |
# Enable gradient checkpointing | |
if self.use_gradient_checkpointing: | |
self._enable_gradient_checkpointing() | |
print(f"β ImageTagger with ViT initialized!") | |
self._print_parameter_count() | |
def _load_vit_backbone(self): | |
"""Load Vision Transformer model from timm""" | |
print(f" Loading from timm: {self.model_name}") | |
# Load the ViT model (not features_only, we want the full model for token extraction) | |
vit_model = timm.create_model( | |
self.model_name, | |
pretrained=True, | |
img_size=self.img_size, | |
num_classes=0 # Remove classification head | |
) | |
# Wrap it in our compatibility layer | |
self.backbone = ViTWrapper(vit_model) | |
print(f" β ViT loaded successfully") | |
print(f" Patch size: {self.backbone.patch_size}x{self.backbone.patch_size}") | |
print(f" Embed dim: {self.backbone.embed_dim}") | |
def _determine_backbone_dimensions(self): | |
"""Determine backbone output dimensions""" | |
print(" π Determining backbone dimensions...") | |
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16): | |
# Create a dummy input | |
dummy_input = torch.randn(1, 3, self.img_size, self.img_size) | |
# Get features | |
backbone_features, cls_dummy = self.backbone(dummy_input) | |
feature_tensor = backbone_features | |
self.backbone_dim = feature_tensor.shape[1] | |
self.feature_map_size = feature_tensor.shape[2] | |
print(f" Backbone output: {self.backbone_dim}D, {self.feature_map_size}x{self.feature_map_size} spatial") | |
print(f" Total patch tokens: {self.feature_map_size * self.feature_map_size}") | |
def _enable_gradient_checkpointing(self): | |
"""Enable gradient checkpointing for memory efficiency""" | |
print("π Enabling gradient checkpointing...") | |
# Enable checkpointing for ViT backbone | |
if self.backbone.set_grad_checkpointing(True): | |
print(" β ViT backbone checkpointing enabled") | |
else: | |
print(" β οΈ ViT backbone doesn't support built-in checkpointing, will checkpoint manually") | |
def _checkpoint_backbone(self, x): | |
"""Wrapper for backbone with gradient checkpointing""" | |
if self.use_gradient_checkpointing and self.training: | |
return checkpoint.checkpoint(self.backbone, x, use_reentrant=False) | |
else: | |
return self.backbone(x) | |
def _checkpoint_image_proj(self, x): | |
"""Wrapper for image projection with gradient checkpointing""" | |
if self.use_gradient_checkpointing and self.training: | |
return checkpoint.checkpoint(self.image_token_proj, x, use_reentrant=False) | |
else: | |
return self.image_token_proj(x) | |
def _checkpoint_cross_attention(self, query, key, value): | |
"""Wrapper for cross attention with gradient checkpointing""" | |
def _attention_forward(q, k, v): | |
attended_features, _ = self.cross_attention(query=q, key=k, value=v) | |
return self.cross_norm(attended_features) | |
if self.use_gradient_checkpointing and self.training: | |
return checkpoint.checkpoint(_attention_forward, query, key, value, use_reentrant=False) | |
else: | |
return _attention_forward(query, key, value) | |
def _checkpoint_candidate_selection(self, initial_logits): | |
"""Wrapper for candidate selection with gradient checkpointing""" | |
def _candidate_forward(logits): | |
return self._get_candidate_tags(logits) | |
if self.use_gradient_checkpointing and self.training: | |
return checkpoint.checkpoint(_candidate_forward, initial_logits, use_reentrant=False) | |
else: | |
return _candidate_forward(initial_logits) | |
def _checkpoint_final_scoring(self, attended_features, candidate_indices): | |
"""Wrapper for final scoring with gradient checkpointing""" | |
def _scoring_forward(features, indices): | |
emb = self.tag_embedding(indices) | |
# BF16 in, BF16 out | |
return (features * emb).sum(dim=-1) | |
if self.use_gradient_checkpointing and self.training: | |
return checkpoint.checkpoint(_scoring_forward, attended_features, candidate_indices, use_reentrant=False) | |
else: | |
return _scoring_forward(attended_features, candidate_indices) | |
def _init_weights(self): | |
"""Initialize weights for new modules""" | |
def _init_layer(layer): | |
if isinstance(layer, nn.Linear): | |
nn.init.xavier_uniform_(layer.weight) | |
if layer.bias is not None: | |
nn.init.zeros_(layer.bias) | |
elif isinstance(layer, nn.Conv2d): | |
nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu') | |
if layer.bias is not None: | |
nn.init.zeros_(layer.bias) | |
elif isinstance(layer, nn.Embedding): | |
nn.init.normal_(layer.weight, mean=0, std=0.02) | |
# Initialize new components | |
self.image_token_proj.apply(_init_layer) | |
# Initialize tag embeddings with normal distribution | |
nn.init.normal_(self.tag_embedding.weight, mean=0, std=0.02) | |
# Initialize tag bias | |
nn.init.zeros_(self.tag_bias) | |
def _print_parameter_count(self): | |
"""Print parameter statistics""" | |
total_params = sum(p.numel() for p in self.parameters()) | |
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) | |
backbone_params = sum(p.numel() for p in self.backbone.parameters()) | |
print(f"π Parameter Statistics:") | |
print(f" Total parameters: {total_params/1e6:.1f}M") | |
print(f" Trainable parameters: {trainable_params/1e6:.1f}M") | |
print(f" Frozen parameters: {(total_params-trainable_params)/1e6:.1f}M") | |
print(f" Backbone parameters: {backbone_params/1e6:.1f}M") | |
if self.use_gradient_checkpointing: | |
print(f" π Gradient checkpointing enabled for memory efficiency") | |
def debug(self): | |
return self._flags['debug'] | |
def model_stats(self): | |
return self._flags['model_stats'] | |
def _get_candidate_tags(self, initial_logits, target_tags=None, hard_negatives=None): | |
"""Select candidate tags - no ground truth inclusion""" | |
batch_size = initial_logits.size(0) | |
# Simply select top K candidates based on initial predictions | |
top_probs, top_indices = torch.topk( | |
torch.sigmoid(initial_logits), | |
k=min(self.tag_context_size, self.total_tags), | |
dim=1, largest=True, sorted=True | |
) | |
return top_indices | |
def _analyze_predictions(self, predictions, tag_indices): | |
"""Analyze prediction patterns""" | |
if not self.model_stats: | |
return {} | |
if torch._dynamo.is_compiling(): | |
return {} | |
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16): | |
probs = torch.sigmoid(predictions) | |
relevant_probs = torch.gather(probs, 1, tag_indices) | |
return { | |
'prediction_confidence': relevant_probs.mean().item(), | |
'prediction_entropy': -(relevant_probs * torch.log(relevant_probs + 1e-9)).mean().item(), | |
'high_confidence_ratio': (relevant_probs > 0.7).float().mean().item(), | |
'above_threshold_ratio': (relevant_probs > 0.5).float().mean().item(), | |
} | |
def forward(self, x, targets=None, hard_negatives=None): | |
""" | |
Forward pass with ViT backbone, CLS token support and gradient-checkpointing. | |
All arithmetic tensors stay in the backboneβs dtype (BF16 under autocast, | |
FP32 otherwise). Anything that must mix dtypes is cast to match. | |
""" | |
batch_size = x.size(0) | |
model_stats = {} if self.model_stats else {} | |
# ------------------------------------------------------------------ | |
# 1. Backbone β patch map + CLS token | |
# ------------------------------------------------------------------ | |
patch_map, cls_token = self._checkpoint_backbone(x) # patch_map: [B, C, H, W] | |
# cls_token: [B, C] | |
# ------------------------------------------------------------------ | |
# 2. Tokens β global image vector | |
# ------------------------------------------------------------------ | |
image_tokens_4d = self._checkpoint_image_proj(patch_map) # [B, C, H, W] | |
image_tokens = image_tokens_4d.flatten(2).transpose(1, 2) # [B, N, C] | |
# βDual-poolβ: mean-pool patches β CLS | |
global_features = 0.5 * (image_tokens.mean(dim=1, dtype=image_tokens.dtype) + cls_token) # [B, C] | |
compute_dtype = global_features.dtype # BF16 or FP32 | |
# ------------------------------------------------------------------ | |
# 3. Initial logits (shared weights) | |
# ------------------------------------------------------------------ | |
tag_weights = self.tag_embedding.weight.to(compute_dtype) # [T, C] | |
tag_bias = self.tag_bias.to(compute_dtype) # [T] | |
initial_logits = global_features @ tag_weights.t() + tag_bias # [B, T] | |
initial_logits = initial_logits.to(compute_dtype) # keep dtype uniform | |
initial_preds = initial_logits # alias | |
# ------------------------------------------------------------------ | |
# 4. Candidate set | |
# ------------------------------------------------------------------ | |
candidate_indices = self._checkpoint_candidate_selection(initial_logits) # [B, K] | |
tag_embeddings = self.tag_embedding(candidate_indices).to(compute_dtype) # [B, K, C] | |
attended_features = self._checkpoint_cross_attention( # [B, K, C] | |
tag_embeddings, image_tokens, image_tokens | |
) | |
# ------------------------------------------------------------------ | |
# 5. Score candidates & scatter back | |
# ------------------------------------------------------------------ | |
candidate_logits = self._checkpoint_final_scoring(attended_features, candidate_indices) # [B, K] | |
# --- align dtypes so scatter never throws --- | |
if candidate_logits.dtype != initial_logits.dtype: | |
candidate_logits = candidate_logits.to(initial_logits.dtype) | |
refined_logits = initial_logits.clone() | |
refined_logits.scatter_(1, candidate_indices, candidate_logits) | |
refined_preds = refined_logits | |
# ------------------------------------------------------------------ | |
# 6. Optional stats | |
# ------------------------------------------------------------------ | |
if self.model_stats and targets is not None and not torch._dynamo.is_compiling(): | |
model_stats['initial_prediction_stats'] = self._analyze_predictions(initial_preds, | |
candidate_indices) | |
model_stats['refined_prediction_stats'] = self._analyze_predictions(refined_preds, | |
candidate_indices) | |
return { | |
'initial_predictions': initial_preds, | |
'refined_predictions': refined_preds, | |
'selected_candidates': candidate_indices, | |
'model_stats': model_stats | |
} |