camie-tagger-v2-app / utils /model_loader.py
Camais03's picture
Upload 6 files
e7d3e33 verified
import torch
import torch.nn as nn
from torch.nn import GroupNorm, LayerNorm
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import timm
class ViTWrapper(nn.Module):
    """Wrapper to make ViT compatible with feature extraction for ImageTagger.

    The wrapped model must expose ``patch_embed`` (with a ``patch_size``
    sequence), ``embed_dim``, ``cls_token``, ``pos_embed``, ``pos_drop``,
    ``blocks`` and ``norm``; these are the only attributes this class reads.

    NOTE(review): ``forward`` walks ``self.vit.blocks`` itself instead of
    calling the wrapped model's own forward, so whatever
    ``vit.set_grad_checkpointing`` toggles inside the wrapped model is
    presumably not consulted here -- confirm against the timm version in use.
    """

    def __init__(self, vit_model):
        super().__init__()
        self.vit = vit_model
        self.out_indices = (-1,)  # mimic timm.features_only
        # Get patch size and embedding dim from the model
        self.patch_size = vit_model.patch_embed.patch_size[0]
        self.embed_dim = vit_model.embed_dim

    def forward(self, x):
        """Run the ViT and return ``(patch_features, cls_final)``.

        Args:
            x: image batch; the last two dimensions are read as height and
                width (assumes a ``(B, 3, H, W)`` layout -- TODO confirm).

        Returns:
            patch_features: patch tokens reshaped to ``(B, C, grid_h, grid_w)``.
            cls_final: the first (CLS) token after the final norm, ``(B, C)``.

        Raises:
            RuntimeError: if the number of patch tokens cannot be arranged
                into a grid (neither input-size/patch-size nor square).
        """
        batch = x.size(0)
        in_h, in_w = x.shape[-2], x.shape[-1]

        # 1. patch tokens
        x = self.vit.patch_embed(x)  # (B, N, C)

        # 2. prepend CLS
        cls_tok = self.vit.cls_token.expand(batch, -1, -1)  # (B, 1, C)
        x = torch.cat((cls_tok, x), dim=1)  # (B, 1+N, C)

        # 3. add positional encodings (full, incl. CLS)
        if self.vit.pos_embed is not None:
            x = x + self.vit.pos_embed[:, : x.size(1), :]
        x = self.vit.pos_drop(x)

        for blk in self.vit.blocks:
            x = blk(x)
        x = self.vit.norm(x)  # (B, 1+N, C)

        # 4. split back out
        cls_final = x[:, 0]  # (B, C)
        patch_tokens = x[:, 1:]  # (B, N, C)

        # 5. reshape patches to (B, C, grid_h, grid_w)
        _, num_patches, channels = patch_tokens.shape
        grid_h, grid_w = self._patch_grid(in_h, in_w, num_patches)
        patch_features = patch_tokens.permute(0, 2, 1).reshape(
            batch, channels, grid_h, grid_w
        )

        # Return **both**: (patch map, CLS)
        return patch_features, cls_final

    def _patch_grid(self, in_h, in_w, num_patches):
        """Return the ``(rows, cols)`` layout of ``num_patches`` tokens."""
        patch_size = self.vit.patch_embed.patch_size
        grid_h = in_h // patch_size[0]
        grid_w = in_w // patch_size[-1]
        # Preferred: the grid implied by input size / patch size, which also
        # handles non-square inputs.
        if grid_h * grid_w == num_patches:
            return grid_h, grid_w
        # Fallback: the original square assumption.
        side = round(num_patches ** 0.5)
        if side * side == num_patches:
            return side, side
        raise RuntimeError(
            f"Cannot arrange {num_patches} patch tokens into a grid for a "
            f"{in_h}x{in_w} input with patch size {tuple(patch_size)}"
        )

    def set_grad_checkpointing(self, enable=True):
        """Enable gradient checkpointing if supported.

        Returns True when the wrapped model exposes
        ``set_grad_checkpointing`` (which is then called with ``enable``),
        False otherwise.
        """
        if hasattr(self.vit, 'set_grad_checkpointing'):
            self.vit.set_grad_checkpointing(enable)
            return True
        return False
class ImageTagger(nn.Module):
"""
ImageTagger with Vision Transformer backbone
"""
def __init__(self, total_tags, dataset, model_name='vit_base_patch16_224',
             num_heads=16, dropout=0.1, pretrained=True, tag_context_size=256,
             use_gradient_checkpointing=False, img_size=224):
    """Build the tagger: ViT backbone, tag embeddings and cross-attention.

    Args:
        total_tags: size of the tag vocabulary; sets the number of rows of
            the tag embedding table and the length of the tag bias.
        dataset: stored on the instance as ``self.dataset``; not otherwise
            used in this constructor.
        model_name: timm model name passed to ``timm.create_model``.
        num_heads: number of heads of the cross-attention layer.
        dropout: dropout of the cross-attention layer.
        pretrained: forwarded to ``timm.create_model``; pass False when the
            weights are loaded from a checkpoint afterwards.
        tag_context_size: stored as ``self.tag_context_size`` (number of
            candidate tags kept for refinement).
        use_gradient_checkpointing: enable checkpointing helpers in training.
        img_size: square input resolution passed to timm and used for the
            dimension probe.
    """
    super().__init__()
    # Store checkpointing config
    self.use_gradient_checkpointing = use_gradient_checkpointing
    self.model_name = model_name
    self.img_size = img_size
    # Keep the caller's choice so _load_vit_backbone can honour it
    # (it used to be dropped and pretrained weights were always requested).
    self.pretrained = pretrained
    # Debug and stats flags
    self._flags = {
        'debug': False,
        'model_stats': True
    }
    # Core model config
    self.dataset = dataset
    self.tag_context_size = tag_context_size
    self.total_tags = total_tags
    print(f"πŸ—οΈ Building ImageTagger with ViT backbone and {total_tags} tags")
    print(f" Backbone: {model_name}")
    print(f" Image size: {img_size}x{img_size}")
    print(f" Tag context size: {tag_context_size}")
    print(f" Gradient checkpointing: {use_gradient_checkpointing}")
    print(f" 🎯 Custom embeddings, PyTorch native attention, no ground truth inclusion")
    # 1. Vision Transformer Backbone
    print("πŸ“¦ Loading Vision Transformer backbone...")
    self._load_vit_backbone()
    # Get backbone dimensions by running a test forward pass
    self._determine_backbone_dimensions()
    self.embedding_dim = self.backbone.embed_dim
    # 2. Custom Tag Embeddings (no CLIP)
    print("🎯 Using custom tag embeddings (no CLIP)")
    self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)
    # 3. Shared weights approach - tag bias for initial predictions
    print("πŸ”— Using shared weights between initial head and tag embeddings")
    self.tag_bias = nn.Parameter(torch.zeros(total_tags))
    # 4. Image token extraction (for attention AND global pooling)
    self.image_token_proj = nn.Identity()
    # 5. Tags-as-queries cross-attention (using PyTorch's optimized implementation)
    self.cross_attention = nn.MultiheadAttention(
        embed_dim=self.embedding_dim,
        num_heads=num_heads,
        dropout=dropout,
        batch_first=True  # Use (batch, seq, feature) format
    )
    self.cross_norm = nn.LayerNorm(self.embedding_dim)
    # Initialize weights
    self._init_weights()
    # Enable gradient checkpointing
    if self.use_gradient_checkpointing:
        self._enable_gradient_checkpointing()
    print(f"βœ… ImageTagger with ViT initialized!")
    self._print_parameter_count()

def _load_vit_backbone(self):
    """Load Vision Transformer model from timm and wrap it in ViTWrapper.

    Uses ``self.model_name``, ``self.img_size`` and ``self.pretrained``;
    the result is stored as ``self.backbone``.
    """
    print(f" Loading from timm: {self.model_name}")
    # Load the ViT model (not features_only, we want the full model for token extraction)
    vit_model = timm.create_model(
        self.model_name,
        pretrained=self.pretrained,
        img_size=self.img_size,
        num_classes=0  # Remove classification head
    )
    # Wrap it in our compatibility layer
    self.backbone = ViTWrapper(vit_model)
    print(f" βœ… ViT loaded successfully")
    print(f" Patch size: {self.backbone.patch_size}x{self.backbone.patch_size}")
    print(f" Embed dim: {self.backbone.embed_dim}")
def _determine_backbone_dimensions(self):
"""Determine backbone output dimensions"""
print(" πŸ” Determining backbone dimensions...")
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
# Create a dummy input
dummy_input = torch.randn(1, 3, self.img_size, self.img_size)
# Get features
backbone_features, cls_dummy = self.backbone(dummy_input)
feature_tensor = backbone_features
self.backbone_dim = feature_tensor.shape[1]
self.feature_map_size = feature_tensor.shape[2]
print(f" Backbone output: {self.backbone_dim}D, {self.feature_map_size}x{self.feature_map_size} spatial")
print(f" Total patch tokens: {self.feature_map_size * self.feature_map_size}")
def _enable_gradient_checkpointing(self):
"""Enable gradient checkpointing for memory efficiency"""
print("πŸ”„ Enabling gradient checkpointing...")
# Enable checkpointing for ViT backbone
if self.backbone.set_grad_checkpointing(True):
print(" βœ… ViT backbone checkpointing enabled")
else:
print(" ⚠️ ViT backbone doesn't support built-in checkpointing, will checkpoint manually")
def _checkpoint_backbone(self, x):
"""Wrapper for backbone with gradient checkpointing"""
if self.use_gradient_checkpointing and self.training:
return checkpoint.checkpoint(self.backbone, x, use_reentrant=False)
else:
return self.backbone(x)
def _checkpoint_image_proj(self, x):
"""Wrapper for image projection with gradient checkpointing"""
if self.use_gradient_checkpointing and self.training:
return checkpoint.checkpoint(self.image_token_proj, x, use_reentrant=False)
else:
return self.image_token_proj(x)
def _checkpoint_cross_attention(self, query, key, value):
"""Wrapper for cross attention with gradient checkpointing"""
def _attention_forward(q, k, v):
attended_features, _ = self.cross_attention(query=q, key=k, value=v)
return self.cross_norm(attended_features)
if self.use_gradient_checkpointing and self.training:
return checkpoint.checkpoint(_attention_forward, query, key, value, use_reentrant=False)
else:
return _attention_forward(query, key, value)
def _checkpoint_candidate_selection(self, initial_logits):
"""Wrapper for candidate selection with gradient checkpointing"""
def _candidate_forward(logits):
return self._get_candidate_tags(logits)
if self.use_gradient_checkpointing and self.training:
return checkpoint.checkpoint(_candidate_forward, initial_logits, use_reentrant=False)
else:
return _candidate_forward(initial_logits)
def _checkpoint_final_scoring(self, attended_features, candidate_indices):
"""Wrapper for final scoring with gradient checkpointing"""
def _scoring_forward(features, indices):
emb = self.tag_embedding(indices)
# BF16 in, BF16 out
return (features * emb).sum(dim=-1)
if self.use_gradient_checkpointing and self.training:
return checkpoint.checkpoint(_scoring_forward, attended_features, candidate_indices, use_reentrant=False)
else:
return _scoring_forward(attended_features, candidate_indices)
def _init_weights(self):
"""Initialize weights for new modules"""
def _init_layer(layer):
if isinstance(layer, nn.Linear):
nn.init.xavier_uniform_(layer.weight)
if layer.bias is not None:
nn.init.zeros_(layer.bias)
elif isinstance(layer, nn.Conv2d):
nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
if layer.bias is not None:
nn.init.zeros_(layer.bias)
elif isinstance(layer, nn.Embedding):
nn.init.normal_(layer.weight, mean=0, std=0.02)
# Initialize new components
self.image_token_proj.apply(_init_layer)
# Initialize tag embeddings with normal distribution
nn.init.normal_(self.tag_embedding.weight, mean=0, std=0.02)
# Initialize tag bias
nn.init.zeros_(self.tag_bias)
def _print_parameter_count(self):
"""Print parameter statistics"""
total_params = sum(p.numel() for p in self.parameters())
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
backbone_params = sum(p.numel() for p in self.backbone.parameters())
print(f"πŸ“Š Parameter Statistics:")
print(f" Total parameters: {total_params/1e6:.1f}M")
print(f" Trainable parameters: {trainable_params/1e6:.1f}M")
print(f" Frozen parameters: {(total_params-trainable_params)/1e6:.1f}M")
print(f" Backbone parameters: {backbone_params/1e6:.1f}M")
if self.use_gradient_checkpointing:
print(f" πŸ”„ Gradient checkpointing enabled for memory efficiency")
@property
def debug(self):
    """Value of the ``'debug'`` entry in ``self._flags``."""
    flags = self._flags
    return flags['debug']
@property
def model_stats(self):
    """Value of the ``'model_stats'`` entry in ``self._flags``."""
    flags = self._flags
    return flags['model_stats']
def _get_candidate_tags(self, initial_logits, target_tags=None, hard_negatives=None):
"""Select candidate tags - no ground truth inclusion"""
batch_size = initial_logits.size(0)
# Simply select top K candidates based on initial predictions
top_probs, top_indices = torch.topk(
torch.sigmoid(initial_logits),
k=min(self.tag_context_size, self.total_tags),
dim=1, largest=True, sorted=True
)
return top_indices
def _analyze_predictions(self, predictions, tag_indices):
"""Analyze prediction patterns"""
if not self.model_stats:
return {}
if torch._dynamo.is_compiling():
return {}
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
probs = torch.sigmoid(predictions)
relevant_probs = torch.gather(probs, 1, tag_indices)
return {
'prediction_confidence': relevant_probs.mean().item(),
'prediction_entropy': -(relevant_probs * torch.log(relevant_probs + 1e-9)).mean().item(),
'high_confidence_ratio': (relevant_probs > 0.7).float().mean().item(),
'above_threshold_ratio': (relevant_probs > 0.5).float().mean().item(),
}
def forward(self, x, targets=None, hard_negatives=None):
    """
    Two-stage tagging pass: global initial logits, then cross-attention
    refinement of the selected candidate tags.

    Every tensor combined here is cast to ``compute_dtype``, the dtype of
    the pooled image vector, so mixed-dtype operations never meet.

    Args:
        x: image batch handed to the backbone.
        targets: only checked for ``None``; prediction statistics are
            collected when it is provided and ``self.model_stats`` is set.
        hard_negatives: accepted but not used.

    Returns:
        dict with ``initial_predictions`` ([B, T] logits),
        ``refined_predictions`` ([B, T] logits, candidate columns replaced
        by their refined scores), ``selected_candidates`` ([B, K] indices)
        and ``model_stats`` (dict, possibly empty).
    """
    stats = {}

    # 1. Backbone -> patch map [B, C, H, W] and CLS vector [B, C]
    patch_map, cls_token = self._checkpoint_backbone(x)

    # 2. Patch tokens [B, N, C] and the pooled image vector [B, C]
    projected_map = self._checkpoint_image_proj(patch_map)
    image_tokens = projected_map.flatten(2).transpose(1, 2)
    # "Dual pool": average of the mean patch token and the CLS token
    pooled_patches = image_tokens.mean(dim=1, dtype=image_tokens.dtype)
    global_features = 0.5 * (pooled_patches + cls_token)
    compute_dtype = global_features.dtype

    # 3. Initial logits: the tag embedding table doubles as the classifier
    tag_weights = self.tag_embedding.weight.to(compute_dtype)  # [T, C]
    tag_bias = self.tag_bias.to(compute_dtype)  # [T]
    initial_logits = (global_features @ tag_weights.t() + tag_bias).to(compute_dtype)

    # 4. Candidate tags act as queries over the image tokens
    candidate_indices = self._checkpoint_candidate_selection(initial_logits)  # [B, K]
    candidate_queries = self.tag_embedding(candidate_indices).to(compute_dtype)  # [B, K, C]
    attended_features = self._checkpoint_cross_attention(
        candidate_queries, image_tokens, image_tokens
    )

    # 5. Score the candidates and write them over a copy of the initial logits
    candidate_logits = self._checkpoint_final_scoring(attended_features, candidate_indices)
    if candidate_logits.dtype != initial_logits.dtype:
        # scatter_ needs source and destination dtypes to agree
        candidate_logits = candidate_logits.to(initial_logits.dtype)
    refined_logits = initial_logits.clone()
    refined_logits.scatter_(1, candidate_indices, candidate_logits)

    # 6. Optional statistics
    collect_stats = (
        self.model_stats
        and targets is not None
        and not torch._dynamo.is_compiling()
    )
    if collect_stats:
        stats['initial_prediction_stats'] = self._analyze_predictions(
            initial_logits, candidate_indices
        )
        stats['refined_prediction_stats'] = self._analyze_predictions(
            refined_logits, candidate_indices
        )

    return {
        'initial_predictions': initial_logits,
        'refined_predictions': refined_logits,
        'selected_candidates': candidate_indices,
        'model_stats': stats,
    }
def predict