import torch
import torch.nn as nn
from torch.nn import GroupNorm, LayerNorm
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import timm
class ViTWrapper(nn.Module):
"""Wrapper to make ViT compatible with feature extraction for ImageTagger"""
def __init__(self, vit_model):
super().__init__()
self.vit = vit_model
self.out_indices = (-1,) # mimic timm.features_only
# Get patch size and embedding dim from the model
self.patch_size = vit_model.patch_embed.patch_size[0]
self.embed_dim = vit_model.embed_dim
def forward(self, x):
B = x.size(0)
        # → patch tokens
x = self.vit.patch_embed(x) # (B, N, C)
        # → prepend CLS token
cls_tok = self.vit.cls_token.expand(B, -1, -1) # (B, 1, C)
x = torch.cat((cls_tok, x), dim=1) # (B, 1+N, C)
        # → add positional encodings (full, incl. CLS)
if self.vit.pos_embed is not None:
x = x + self.vit.pos_embed[:, : x.size(1), :]
x = self.vit.pos_drop(x)
for blk in self.vit.blocks:
x = blk(x)
x = self.vit.norm(x) # (B, 1+N, C)
        # → split back out
cls_final = x[:, 0] # (B, C)
patch_tokens = x[:, 1:] # (B, N, C)
        # → reshape patches to (B, C, H, W)
B, N, C = patch_tokens.shape
h = w = int(N ** 0.5) # square assumption
patch_features = patch_tokens.permute(0, 2, 1).reshape(B, C, h, w)
# Return **both**: (patch map, CLS)
return patch_features, cls_final
def set_grad_checkpointing(self, enable=True):
"""Enable gradient checkpointing if supported"""
if hasattr(self.vit, 'set_grad_checkpointing'):
self.vit.set_grad_checkpointing(enable)
return True
return False
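# Usage sketch for ViTWrapper (illustrative addition, not from the original module);
# shapes assume vit_base_patch16_224: patch size 16, embed dim 768, 224x224 input,
# hence a 14x14 patch grid.
#
#   vit = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=0)
#   wrapper = ViTWrapper(vit)
#   patch_map, cls_tok = wrapper(torch.randn(2, 3, 224, 224))
#   # patch_map: (2, 768, 14, 14), cls_tok: (2, 768)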
class ImageTagger(nn.Module):
"""
ImageTagger with Vision Transformer backbone
"""
def __init__(self, total_tags, dataset, model_name='vit_base_patch16_224',
num_heads=16, dropout=0.1, pretrained=True, tag_context_size=256,
use_gradient_checkpointing=False, img_size=224):
super().__init__()
# Store checkpointing config
self.use_gradient_checkpointing = use_gradient_checkpointing
self.model_name = model_name
        self.img_size = img_size
        self.pretrained = pretrained
# Debug and stats flags
self._flags = {
'debug': False,
'model_stats': True
}
# Core model config
self.dataset = dataset
self.tag_context_size = tag_context_size
self.total_tags = total_tags
print(f"ποΈ Building ImageTagger with ViT backbone and {total_tags} tags")
print(f" Backbone: {model_name}")
print(f" Image size: {img_size}x{img_size}")
print(f" Tag context size: {tag_context_size}")
print(f" Gradient checkpointing: {use_gradient_checkpointing}")
print(f" π― Custom embeddings, PyTorch native attention, no ground truth inclusion")
# 1. Vision Transformer Backbone
print("π¦ Loading Vision Transformer backbone...")
self._load_vit_backbone()
# Get backbone dimensions by running a test forward pass
self._determine_backbone_dimensions()
self.embedding_dim = self.backbone.embed_dim
# 2. Custom Tag Embeddings (no CLIP)
print("π― Using custom tag embeddings (no CLIP)")
self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)
# 3. Shared weights approach - tag bias for initial predictions
print("π Using shared weights between initial head and tag embeddings")
self.tag_bias = nn.Parameter(torch.zeros(total_tags))
# 4. Image token extraction (for attention AND global pooling)
self.image_token_proj = nn.Identity()
# 5. Tags-as-queries cross-attention (using PyTorch's optimized implementation)
self.cross_attention = nn.MultiheadAttention(
embed_dim=self.embedding_dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True # Use (batch, seq, feature) format
)
self.cross_norm = nn.LayerNorm(self.embedding_dim)
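        # At call time (forward step 4) the queries are candidate tag embeddings [B, K, C]
        # and the keys/values are patch tokens [B, N, C], yielding one attended feature per tag.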
# Initialize weights
self._init_weights()
# Enable gradient checkpointing
if self.use_gradient_checkpointing:
self._enable_gradient_checkpointing()
print(f"β
ImageTagger with ViT initialized!")
self._print_parameter_count()
def _load_vit_backbone(self):
"""Load Vision Transformer model from timm"""
print(f" Loading from timm: {self.model_name}")
# Load the ViT model (not features_only, we want the full model for token extraction)
vit_model = timm.create_model(
self.model_name,
            pretrained=self.pretrained,
img_size=self.img_size,
num_classes=0 # Remove classification head
)
# Wrap it in our compatibility layer
self.backbone = ViTWrapper(vit_model)
print(f" β
ViT loaded successfully")
print(f" Patch size: {self.backbone.patch_size}x{self.backbone.patch_size}")
print(f" Embed dim: {self.backbone.embed_dim}")
def _determine_backbone_dimensions(self):
"""Determine backbone output dimensions"""
print(" π Determining backbone dimensions...")
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
# Create a dummy input
dummy_input = torch.randn(1, 3, self.img_size, self.img_size)
# Get features
backbone_features, cls_dummy = self.backbone(dummy_input)
feature_tensor = backbone_features
self.backbone_dim = feature_tensor.shape[1]
self.feature_map_size = feature_tensor.shape[2]
print(f" Backbone output: {self.backbone_dim}D, {self.feature_map_size}x{self.feature_map_size} spatial")
print(f" Total patch tokens: {self.feature_map_size * self.feature_map_size}")
def _enable_gradient_checkpointing(self):
"""Enable gradient checkpointing for memory efficiency"""
print("π Enabling gradient checkpointing...")
# Enable checkpointing for ViT backbone
if self.backbone.set_grad_checkpointing(True):
print(" β
ViT backbone checkpointing enabled")
else:
print(" β οΈ ViT backbone doesn't support built-in checkpointing, will checkpoint manually")
def _checkpoint_backbone(self, x):
"""Wrapper for backbone with gradient checkpointing"""
if self.use_gradient_checkpointing and self.training:
return checkpoint.checkpoint(self.backbone, x, use_reentrant=False)
else:
return self.backbone(x)
def _checkpoint_image_proj(self, x):
"""Wrapper for image projection with gradient checkpointing"""
if self.use_gradient_checkpointing and self.training:
return checkpoint.checkpoint(self.image_token_proj, x, use_reentrant=False)
else:
return self.image_token_proj(x)
def _checkpoint_cross_attention(self, query, key, value):
"""Wrapper for cross attention with gradient checkpointing"""
def _attention_forward(q, k, v):
attended_features, _ = self.cross_attention(query=q, key=k, value=v)
return self.cross_norm(attended_features)
if self.use_gradient_checkpointing and self.training:
return checkpoint.checkpoint(_attention_forward, query, key, value, use_reentrant=False)
else:
return _attention_forward(query, key, value)
def _checkpoint_candidate_selection(self, initial_logits):
"""Wrapper for candidate selection with gradient checkpointing"""
def _candidate_forward(logits):
return self._get_candidate_tags(logits)
if self.use_gradient_checkpointing and self.training:
return checkpoint.checkpoint(_candidate_forward, initial_logits, use_reentrant=False)
else:
return _candidate_forward(initial_logits)
def _checkpoint_final_scoring(self, attended_features, candidate_indices):
"""Wrapper for final scoring with gradient checkpointing"""
def _scoring_forward(features, indices):
emb = self.tag_embedding(indices)
            # Dot product between each attended feature and its candidate tag embedding,
            # computed in the incoming dtype (BF16 under autocast, FP32 otherwise)
            return (features * emb).sum(dim=-1)
if self.use_gradient_checkpointing and self.training:
return checkpoint.checkpoint(_scoring_forward, attended_features, candidate_indices, use_reentrant=False)
else:
return _scoring_forward(attended_features, candidate_indices)
def _init_weights(self):
"""Initialize weights for new modules"""
def _init_layer(layer):
if isinstance(layer, nn.Linear):
nn.init.xavier_uniform_(layer.weight)
if layer.bias is not None:
nn.init.zeros_(layer.bias)
elif isinstance(layer, nn.Conv2d):
nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
if layer.bias is not None:
nn.init.zeros_(layer.bias)
elif isinstance(layer, nn.Embedding):
nn.init.normal_(layer.weight, mean=0, std=0.02)
# Initialize new components
self.image_token_proj.apply(_init_layer)
# Initialize tag embeddings with normal distribution
nn.init.normal_(self.tag_embedding.weight, mean=0, std=0.02)
# Initialize tag bias
nn.init.zeros_(self.tag_bias)
def _print_parameter_count(self):
"""Print parameter statistics"""
total_params = sum(p.numel() for p in self.parameters())
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
backbone_params = sum(p.numel() for p in self.backbone.parameters())
print(f"π Parameter Statistics:")
print(f" Total parameters: {total_params/1e6:.1f}M")
print(f" Trainable parameters: {trainable_params/1e6:.1f}M")
print(f" Frozen parameters: {(total_params-trainable_params)/1e6:.1f}M")
print(f" Backbone parameters: {backbone_params/1e6:.1f}M")
if self.use_gradient_checkpointing:
print(f" π Gradient checkpointing enabled for memory efficiency")
@property
def debug(self):
return self._flags['debug']
@property
def model_stats(self):
return self._flags['model_stats']
def _get_candidate_tags(self, initial_logits, target_tags=None, hard_negatives=None):
"""Select candidate tags - no ground truth inclusion"""
batch_size = initial_logits.size(0)
# Simply select top K candidates based on initial predictions
top_probs, top_indices = torch.topk(
torch.sigmoid(initial_logits),
k=min(self.tag_context_size, self.total_tags),
dim=1, largest=True, sorted=True
)
return top_indices
def _analyze_predictions(self, predictions, tag_indices):
"""Analyze prediction patterns"""
if not self.model_stats:
return {}
if torch._dynamo.is_compiling():
return {}
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
probs = torch.sigmoid(predictions)
relevant_probs = torch.gather(probs, 1, tag_indices)
return {
'prediction_confidence': relevant_probs.mean().item(),
'prediction_entropy': -(relevant_probs * torch.log(relevant_probs + 1e-9)).mean().item(),
'high_confidence_ratio': (relevant_probs > 0.7).float().mean().item(),
'above_threshold_ratio': (relevant_probs > 0.5).float().mean().item(),
}
def forward(self, x, targets=None, hard_negatives=None):
"""
Forward pass with ViT backbone, CLS token support and gradient-checkpointing.
        All arithmetic tensors stay in the backbone's dtype (BF16 under autocast,
FP32 otherwise). Anything that must mix dtypes is cast to match.
"""
batch_size = x.size(0)
        model_stats = {}
# ------------------------------------------------------------------
        # 1. Backbone → patch map + CLS token
# ------------------------------------------------------------------
patch_map, cls_token = self._checkpoint_backbone(x) # patch_map: [B, C, H, W]
# cls_token: [B, C]
# ------------------------------------------------------------------
        # 2. Tokens → global image vector
# ------------------------------------------------------------------
image_tokens_4d = self._checkpoint_image_proj(patch_map) # [B, C, H, W]
image_tokens = image_tokens_4d.flatten(2).transpose(1, 2) # [B, N, C]
        # "Dual-pool": mean-pooled patch tokens + CLS token
global_features = 0.5 * (image_tokens.mean(dim=1, dtype=image_tokens.dtype) + cls_token) # [B, C]
compute_dtype = global_features.dtype # BF16 or FP32
# ------------------------------------------------------------------
# 3. Initial logits (shared weights)
# ------------------------------------------------------------------
tag_weights = self.tag_embedding.weight.to(compute_dtype) # [T, C]
tag_bias = self.tag_bias.to(compute_dtype) # [T]
initial_logits = global_features @ tag_weights.t() + tag_bias # [B, T]
initial_logits = initial_logits.to(compute_dtype) # keep dtype uniform
initial_preds = initial_logits # alias
# ------------------------------------------------------------------
# 4. Candidate set
# ------------------------------------------------------------------
candidate_indices = self._checkpoint_candidate_selection(initial_logits) # [B, K]
tag_embeddings = self.tag_embedding(candidate_indices).to(compute_dtype) # [B, K, C]
attended_features = self._checkpoint_cross_attention( # [B, K, C]
tag_embeddings, image_tokens, image_tokens
)
# ------------------------------------------------------------------
# 5. Score candidates & scatter back
# ------------------------------------------------------------------
candidate_logits = self._checkpoint_final_scoring(attended_features, candidate_indices) # [B, K]
# --- align dtypes so scatter never throws ---
if candidate_logits.dtype != initial_logits.dtype:
candidate_logits = candidate_logits.to(initial_logits.dtype)
refined_logits = initial_logits.clone()
refined_logits.scatter_(1, candidate_indices, candidate_logits)
refined_preds = refined_logits
# ------------------------------------------------------------------
# 6. Optional stats
# ------------------------------------------------------------------
if self.model_stats and targets is not None and not torch._dynamo.is_compiling():
model_stats['initial_prediction_stats'] = self._analyze_predictions(initial_preds,
candidate_indices)
model_stats['refined_prediction_stats'] = self._analyze_predictions(refined_preds,
candidate_indices)
return {
'initial_predictions': initial_preds,
'refined_predictions': refined_preds,
'selected_candidates': candidate_indices,
'model_stats': model_stats
        }
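

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original training pipeline).
    # The values below are arbitrary; `dataset` is only stored by the model, so None is
    # sufficient here, and pretrained=False avoids downloading backbone weights.
    model = ImageTagger(
        total_tags=1000,
        dataset=None,
        model_name='vit_base_patch16_224',
        pretrained=False,
        tag_context_size=128,
        img_size=224,
    )
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(2, 3, 224, 224))
    print(out['initial_predictions'].shape)   # torch.Size([2, 1000])
    print(out['refined_predictions'].shape)   # torch.Size([2, 1000])
    print(out['selected_candidates'].shape)   # torch.Size([2, 128])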