import math

import torch
import torch.nn as nn
from torch.nn import GroupNorm, LayerNorm
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import timm


class ViTWrapper(nn.Module):
    """Wrapper to make a timm ViT compatible with feature extraction for ImageTagger.

    Re-implements the token pipeline of the wrapped ViT so that both the spatial
    patch map and the final CLS token can be returned.

    Args:
        vit_model: a timm Vision Transformer. ``forward`` reads its
            ``patch_embed``, ``cls_token``, ``pos_embed``, ``pos_drop``,
            ``blocks`` and ``norm`` attributes (plus ``no_embed_class``,
            ``norm_pre`` and ``grad_checkpointing`` when present).

    Raises:
        ValueError: if ``vit_model`` has no CLS token (the wrapper always
            returns one).
    """

    def __init__(self, vit_model):
        super().__init__()
        if getattr(vit_model, 'cls_token', None) is None:
            raise ValueError(
                "ViTWrapper requires a ViT with a CLS token (cls_token is None)"
            )
        self.vit = vit_model
        self.out_indices = (-1,)  # mimic timm.features_only

        # Get patch size and embedding dim from the model
        self.patch_size = vit_model.patch_embed.patch_size[0]
        self.embed_dim = vit_model.embed_dim

    def forward(self, x):
        """Run the ViT and return ``(patch_map, cls_token)``.

        Args:
            x: image batch; presumably (B, 3, H, W) at the model's configured
                image size — TODO confirm against callers.

        Returns:
            ``(patch_features, cls_final)`` with shapes (B, C, h, w) and (B, C).

        Raises:
            ValueError: if the number of patch tokens is not a perfect square.
        """
        vit = self.vit
        B = x.size(0)

        # 1. Patch tokens.
        x = vit.patch_embed(x)  # (B, N, C)

        pos_embed = getattr(vit, 'pos_embed', None)
        no_embed_class = getattr(vit, 'no_embed_class', False)

        # In timm, no_embed_class=True means the position table covers only the
        # patches, so it must be added before the CLS token is prepended.
        if pos_embed is not None and no_embed_class:
            x = x + pos_embed[:, : x.size(1), :]

        # 2. Prepend CLS.
        cls_tok = vit.cls_token.expand(B, -1, -1)  # (B, 1, C)
        x = torch.cat((cls_tok, x), dim=1)  # (B, 1+N, C)

        # 3. Otherwise the position table includes the CLS slot.
        if pos_embed is not None and not no_embed_class:
            x = x + pos_embed[:, : x.size(1), :]
        x = vit.pos_drop(x)

        # timm applies norm_pre before the blocks (Identity unless pre_norm=True).
        norm_pre = getattr(vit, 'norm_pre', None)
        if norm_pre is not None:
            x = norm_pre(x)

        # Honor timm's grad_checkpointing flag (set via set_grad_checkpointing);
        # iterating the blocks by hand would otherwise silently bypass it.
        use_block_ckpt = (
            getattr(vit, 'grad_checkpointing', False)
            and self.training
            and torch.is_grad_enabled()
        )
        for blk in vit.blocks:
            if use_block_ckpt:
                x = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        x = vit.norm(x)  # (B, 1+N, C)

        # 4. Split CLS from patch tokens.
        cls_final = x[:, 0]  # (B, C)
        patch_tokens = x[:, 1:]  # (B, N, C)

        # 5. Reshape patches to (B, C, h, w); the grid is assumed square.
        B, N, C = patch_tokens.shape
        side = math.isqrt(N)
        if side * side != N:
            raise ValueError(
                f"ViTWrapper expects a square patch grid, got {N} patch tokens"
            )
        patch_features = patch_tokens.permute(0, 2, 1).reshape(B, C, side, side)

        # Return **both**: (patch map, CLS)
        return patch_features, cls_final

    def set_grad_checkpointing(self, enable=True):
        """Toggle the wrapped ViT's native gradient checkpointing.

        Returns:
            True if the wrapped model exposes ``set_grad_checkpointing`` (the
            flag is then honored per block in ``forward``), False otherwise.
        """
        if hasattr(self.vit, 'set_grad_checkpointing'):
            self.vit.set_grad_checkpointing(enable)
            return True
        return False


class ImageTagger(nn.Module):
    """Image tagger with a Vision Transformer backbone.

    Pipeline (see ``forward``):
      1. The ViT backbone produces a patch map and a CLS token.
      2. The global image vector is the average of the mean-pooled patch tokens
         and the CLS token.
      3. Initial logits for every tag come from the tag-embedding matrix, which
         doubles as the linear head, plus a per-tag bias.
      4. The top-K tags become candidates.
      5. Candidate tag embeddings cross-attend over the patch tokens.
      6. The refined candidate logits are scattered back over the initial ones.

    Args:
        total_tags: number of tags (rows of the tag-embedding table).
        dataset: stored on ``self.dataset``; not otherwise used in this class.
        model_name: timm model name for the backbone.
        num_heads: number of heads in the cross-attention layer.
        dropout: dropout of the cross-attention layer.
        pretrained: whether timm loads pretrained backbone weights.
        tag_context_size: K, the number of candidate tags refined per image.
        use_gradient_checkpointing: trade compute for memory during training.
        img_size: square input resolution passed to timm.
    """

    def __init__(self, total_tags, dataset, model_name='vit_base_patch16_224',
                 num_heads=16, dropout=0.1, pretrained=True, tag_context_size=256,
                 use_gradient_checkpointing=False, img_size=224):
        super().__init__()

        # Store checkpointing config
        self.use_gradient_checkpointing = use_gradient_checkpointing
        # Set by _enable_gradient_checkpointing when the backbone checkpoints
        # its own blocks; _checkpoint_backbone then skips the outer wrapper.
        self._backbone_native_checkpointing = False
        self.model_name = model_name
        self.img_size = img_size
        self.pretrained = pretrained

        # Debug and stats flags
        self._flags = {
            'debug': False,
            'model_stats': True
        }

        # Core model config
        self.dataset = dataset
        self.tag_context_size = tag_context_size
        self.total_tags = total_tags

        print(f"🏗️ Building ImageTagger with ViT backbone and {total_tags} tags")
        print(f"   Backbone: {model_name}")
        print(f"   Image size: {img_size}x{img_size}")
        print(f"   Tag context size: {tag_context_size}")
        print(f"   Gradient checkpointing: {use_gradient_checkpointing}")
        print(f"   🎯 Custom embeddings, PyTorch native attention, no ground truth inclusion")

        # 1. Vision Transformer Backbone
        print("📦 Loading Vision Transformer backbone...")
        self._load_vit_backbone()

        # Get backbone dimensions by running a test forward pass
        self._determine_backbone_dimensions()

        self.embedding_dim = self.backbone.embed_dim

        # 2. Custom Tag Embeddings (no CLIP)
        print("🎯 Using custom tag embeddings (no CLIP)")
        self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)

        # 3. Shared weights approach - tag bias for initial predictions
        print("🔗 Using shared weights between initial head and tag embeddings")
        self.tag_bias = nn.Parameter(torch.zeros(total_tags))

        # 4. Image token extraction (for attention AND global pooling)
        self.image_token_proj = nn.Identity()

        # 5. Tags-as-queries cross-attention (PyTorch's MultiheadAttention)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=self.embedding_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True  # Use (batch, seq, feature) format
        )
        self.cross_norm = nn.LayerNorm(self.embedding_dim)

        # Initialize weights
        self._init_weights()

        # Enable gradient checkpointing
        if self.use_gradient_checkpointing:
            self._enable_gradient_checkpointing()

        print(f"✅ ImageTagger with ViT initialized!")
        self._print_parameter_count()

    def _load_vit_backbone(self):
        """Create the timm ViT (classification head removed) and wrap it."""
        print(f"   Loading from timm: {self.model_name}")

        # Full model (not features_only): the wrapper needs the token pipeline.
        vit_model = timm.create_model(
            self.model_name,
            pretrained=getattr(self, 'pretrained', True),
            img_size=self.img_size,
            num_classes=0  # Remove classification head
        )

        # Wrap it in our compatibility layer
        self.backbone = ViTWrapper(vit_model)

        print(f"   ✅ ViT loaded successfully")
        print(f"   Patch size: {self.backbone.patch_size}x{self.backbone.patch_size}")
        print(f"   Embed dim: {self.backbone.embed_dim}")

    def _determine_backbone_dimensions(self):
        """Set ``backbone_dim`` and ``feature_map_size`` from a dummy forward pass."""
        print("  🔍 Determining backbone dimensions...")

        ref_param = next(self.backbone.parameters())
        device = ref_param.device
        # Autocast only where it can apply; on CPU this is a plain forward pass
        # and avoids the "CUDA is not available" autocast warning.
        with torch.no_grad(), torch.autocast(
            'cuda', dtype=torch.bfloat16, enabled=device.type == 'cuda'
        ):
            dummy_input = torch.randn(
                1, 3, self.img_size, self.img_size,
                device=device, dtype=ref_param.dtype,
            )
            feature_tensor, _ = self.backbone(dummy_input)

        self.backbone_dim = feature_tensor.shape[1]
        self.feature_map_size = feature_tensor.shape[2]

        print(f"   Backbone output: {self.backbone_dim}D, {self.feature_map_size}x{self.feature_map_size} spatial")
        print(f"   Total patch tokens: {self.feature_map_size * self.feature_map_size}")

    def _enable_gradient_checkpointing(self):
        """Enable gradient checkpointing for memory efficiency."""
        print("🔄 Enabling gradient checkpointing...")

        # Per-block checkpointing inside the backbone is preferred; when it is
        # unavailable, _checkpoint_backbone checkpoints the whole backbone.
        self._backbone_native_checkpointing = self.backbone.set_grad_checkpointing(True)
        if self._backbone_native_checkpointing:
            print("   ✅ ViT backbone checkpointing enabled")
        else:
            print("   ⚠️ ViT backbone doesn't support built-in checkpointing, will checkpoint manually")

    def _checkpoint_backbone(self, x):
        """Run the backbone, checkpointing it as a whole only when needed.

        When the backbone already checkpoints per block, wrapping it again would
        recompute every block a second time for no memory benefit.
        """
        if (self.use_gradient_checkpointing and self.training
                and not getattr(self, '_backbone_native_checkpointing', False)):
            return checkpoint.checkpoint(self.backbone, x, use_reentrant=False)
        return self.backbone(x)

    def _checkpoint_image_proj(self, x):
        """Apply ``image_token_proj``, checkpointed during training if enabled."""
        if self.use_gradient_checkpointing and self.training:
            return checkpoint.checkpoint(self.image_token_proj, x, use_reentrant=False)
        return self.image_token_proj(x)

    def _checkpoint_cross_attention(self, query, key, value):
        """Cross-attention followed by LayerNorm, checkpointed if enabled.

        ``need_weights=False`` because the attention weights are discarded; it
        lets MultiheadAttention use its fused scaled-dot-product path.
        """
        def _attention_forward(q, k, v):
            attended_features, _ = self.cross_attention(
                query=q, key=k, value=v, need_weights=False
            )
            return self.cross_norm(attended_features)

        if self.use_gradient_checkpointing and self.training:
            return checkpoint.checkpoint(_attention_forward, query, key, value, use_reentrant=False)
        return _attention_forward(query, key, value)

    def _checkpoint_candidate_selection(self, initial_logits):
        """Select candidate tag indices.

        Not checkpointed: top-K only yields integer indices, so there are no
        activations to save or recompute.
        """
        return self._get_candidate_tags(initial_logits)

    def _checkpoint_final_scoring(self, attended_features, candidate_indices):
        """Dot each attended feature with its candidate's tag embedding."""
        def _scoring_forward(features, indices):
            emb = self.tag_embedding(indices)
            return (features * emb).sum(dim=-1)

        if self.use_gradient_checkpointing and self.training:
            return checkpoint.checkpoint(_scoring_forward, attended_features, candidate_indices, use_reentrant=False)
        return _scoring_forward(attended_features, candidate_indices)

    def _init_weights(self):
        """Initialize weights for the newly added (non-backbone) modules."""
        def _init_layer(layer):
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Embedding):
                nn.init.normal_(layer.weight, mean=0, std=0.02)

        # Initialize new components (image_token_proj is currently Identity)
        self.image_token_proj.apply(_init_layer)

        # Initialize tag embeddings with normal distribution
        nn.init.normal_(self.tag_embedding.weight, mean=0, std=0.02)

        # Initialize tag bias
        nn.init.zeros_(self.tag_bias)

    def _print_parameter_count(self):
        """Print parameter statistics."""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        backbone_params = sum(p.numel() for p in self.backbone.parameters())

        print(f"📊 Parameter Statistics:")
        print(f"   Total parameters: {total_params/1e6:.1f}M")
        print(f"   Trainable parameters: {trainable_params/1e6:.1f}M")
        print(f"   Frozen parameters: {(total_params-trainable_params)/1e6:.1f}M")
        print(f"   Backbone parameters: {backbone_params/1e6:.1f}M")

        if self.use_gradient_checkpointing:
            print(f"   🔄 Gradient checkpointing enabled for memory efficiency")

    @property
    def debug(self):
        return self._flags['debug']

    @property
    def model_stats(self):
        return self._flags['model_stats']

    def _get_candidate_tags(self, initial_logits, target_tags=None, hard_negatives=None):
        """Select the top-K tags per image by initial logit (no ground truth).

        Ranking is done on the raw logits: sigmoid is monotonic, so it cannot
        change the order, but in low precision it saturates and would turn
        distinct large logits into ties. ``target_tags`` and ``hard_negatives``
        are accepted for interface compatibility and ignored.

        Returns:
            Long tensor of shape (B, K), K = min(tag_context_size, total_tags),
            sorted by descending logit.
        """
        k = min(self.tag_context_size, self.total_tags)
        _, top_indices = torch.topk(initial_logits, k=k, dim=1, largest=True, sorted=True)
        return top_indices

    def _analyze_predictions(self, predictions, tag_indices):
        """Summary statistics of sigmoid probabilities at ``tag_indices``.

        Returns ``{}`` when stats are disabled or under torch.compile tracing.
        Statistics are computed in float32 for accuracy.
        """
        if not self.model_stats:
            return {}

        if torch._dynamo.is_compiling():
            return {}

        with torch.no_grad():
            probs = torch.sigmoid(predictions.float())
            relevant_probs = torch.gather(probs, 1, tag_indices)

            return {
                'prediction_confidence': relevant_probs.mean().item(),
                # NOTE: only the -p*log(p) term, not the full binary entropy.
                'prediction_entropy': -(relevant_probs * torch.log(relevant_probs + 1e-9)).mean().item(),
                'high_confidence_ratio': (relevant_probs > 0.7).float().mean().item(),
                'above_threshold_ratio': (relevant_probs > 0.5).float().mean().item(),
            }

    def forward(self, x, targets=None, hard_negatives=None):
        """Predict tag logits for a batch of images.

        Arithmetic stays in the dtype of ``global_features`` (BF16 or FP32
        depending on autocast); mixed-dtype operands are cast to match.

        Args:
            x: image batch passed to the backbone.
            targets: only used to decide whether to compute ``model_stats``.
            hard_negatives: accepted for interface compatibility; unused.

        Returns:
            dict with ``initial_predictions`` (B, T), ``refined_predictions``
            (B, T), ``selected_candidates`` (B, K) and ``model_stats``.
        """
        model_stats = {}

        # 1. Backbone -> patch map [B, C, H, W] + CLS token [B, C]
        patch_map, cls_token = self._checkpoint_backbone(x)

        # 2. Tokens -> global image vector
        image_tokens_4d = self._checkpoint_image_proj(patch_map)  # [B, C, H, W]
        image_tokens = image_tokens_4d.flatten(2).transpose(1, 2)  # [B, N, C]

        # "Dual-pool": average of mean-pooled patches and the CLS token
        global_features = 0.5 * (
            image_tokens.mean(dim=1, dtype=image_tokens.dtype) + cls_token
        )  # [B, C]
        compute_dtype = global_features.dtype  # BF16 or FP32

        # 3. Initial logits: the tag-embedding table doubles as the linear head
        tag_weights = self.tag_embedding.weight.to(compute_dtype)  # [T, C]
        tag_bias = self.tag_bias.to(compute_dtype)  # [T]
        # The bias add may promote dtype under autocast; cast back.
        initial_logits = (global_features @ tag_weights.t() + tag_bias).to(compute_dtype)  # [B, T]
        initial_preds = initial_logits

        # 4. Candidate set + tags-as-queries cross-attention over patch tokens
        candidate_indices = self._checkpoint_candidate_selection(initial_logits)  # [B, K]
        tag_embeddings = self.tag_embedding(candidate_indices).to(compute_dtype)  # [B, K, C]
        attended_features = self._checkpoint_cross_attention(
            tag_embeddings, image_tokens, image_tokens
        )  # [B, K, C]

        # 5. Score candidates & scatter back over the initial logits
        candidate_logits = self._checkpoint_final_scoring(attended_features, candidate_indices)  # [B, K]
        # scatter_ requires matching dtypes
        if candidate_logits.dtype != initial_logits.dtype:
            candidate_logits = candidate_logits.to(initial_logits.dtype)

        refined_logits = initial_logits.clone()
        refined_logits.scatter_(1, candidate_indices, candidate_logits)
        refined_preds = refined_logits

        # 6. Optional stats
        if self.model_stats and targets is not None and not torch._dynamo.is_compiling():
            model_stats['initial_prediction_stats'] = self._analyze_predictions(initial_preds, candidate_indices)
            model_stats['refined_prediction_stats'] = self._analyze_predictions(refined_preds, candidate_indices)

        return {
            'initial_predictions': initial_preds,
            'refined_predictions': refined_preds,
            'selected_candidates': candidate_indices,
            'model_stats': model_stats
        }