from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class TextProjector(nn.Module):
    """Projects text embeddings to shared space."""

    def __init__(
        self,
        input_dim: int = 3072,
        output_dim: int = 1024,
        hidden_dim: Optional[int] = None,
        dropout: float = 0.1,
        num_layers: int = 4,
    ):
| """ | |
| Initialize text projector. | |
| Args: | |
| input_dim: Dimension of text embeddings (3072) | |
| output_dim: Dimension of shared embedding space | |
| hidden_dim: Hidden layer dimension (default: 1024) | |
| dropout: Dropout rate | |
| num_layers: Number of residual layers (default: 2) | |
| """ | |
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 1024
        self.fc_in = nn.Linear(input_dim, hidden_dim)
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, hidden_dim),
            ) for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize projection weights."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Project text embeddings to shared space.

        Args:
            x: Text embeddings (batch_size, input_dim)

        Returns:
            Projected embeddings (batch_size, output_dim)
        """
        h = self.fc_in(x)
        for blk in self.blocks:  # residual MLP stack
            h = h + blk(h)
        h = self.fc_out(h)
        return F.normalize(h, dim=-1, eps=1e-3)
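

# Quick smoke test for TextProjector (an illustrative sketch, not part of the
# original module; the batch size of 4 is an arbitrary assumption):
def _demo_text_projector() -> None:
    proj = TextProjector().eval()             # disable dropout for a deterministic pass
    out = proj(torch.randn(4, 3072))          # (batch_size, input_dim)
    assert out.shape == (4, 1024)             # projected to the shared space
    # forward() ends with F.normalize, so outputs sit on the unit sphere.
    assert torch.allclose(out.norm(dim=-1), torch.ones(4), atol=1e-3)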


class CrossAttentionImageProjector(nn.Module):
    """Simplified projector with self-attention + cross-attention."""

    def __init__(
        self,
        input_dim: int = 768,
        output_dim: int = 1024,
        hidden_dim: Optional[int] = None,
        dropout: float = 0.1,
        num_layers: int = 2,  # Kept for compatibility, not used
        num_heads: int = 4,   # Reduced from 8
    ):
        """
        Initialize simplified cross-attention image projector.

        Args:
            input_dim: Dimension of AION embeddings (768)
            output_dim: Dimension of shared embedding space (default: 1024)
            hidden_dim: Hidden dimension for attention (default: output_dim)
            dropout: Dropout rate
            num_layers: Kept for compatibility but not used
            num_heads: Number of attention heads (reduced to 4)
        """
        super().__init__()
        if hidden_dim is None:
            hidden_dim = output_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        # Project input to hidden dim
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        # Token pooling to reduce sequence length:
        # 576 tokens -> 64 tokens (9x reduction)
        self.token_pool = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=9, stride=9, padding=0)
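        # Note: with kernel_size=9, stride=9, padding=0, Conv1d emits
        # floor(n_tokens / 9) pooled tokens and silently drops the trailing
        # n_tokens % 9 tokens, so inputs are expected to arrive with a token
        # count divisible by 9 (e.g. the 576 AION tokens above).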
        # Single self-attention layer
        self.self_attn_norm = nn.LayerNorm(hidden_dim)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        # MLP after self-attention
        self.mlp1_norm = nn.LayerNorm(hidden_dim)
        self.mlp1 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),  # Reduced from 4x
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Dropout(dropout),
        )
        # Learned query vector
        self.query = nn.Parameter(torch.randn(1, 1, hidden_dim))
        # Single cross-attention layer
        self.cross_attn_norm = nn.LayerNorm(hidden_dim)
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        # Final MLP
        self.final_norm = nn.LayerNorm(hidden_dim)
        self.final_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),  # Reduced from 4x
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, output_dim),
        )
        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights."""
        # Initialize query vector
        nn.init.normal_(self.query, std=0.02)
        # Initialize other weights
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Project image embeddings to shared space using self-attention + cross-attention.

        Args:
            x: Image embeddings (batch_size, n_tokens, input_dim)

        Returns:
            Projected embeddings (batch_size, output_dim)
        """
        batch_size = x.shape[0]
        x = F.normalize(x, dim=-1, eps=1e-6)  # Normalize AION embeddings input (handles [B, N, D])
        # Project input
        x = self.input_proj(x)  # (B, N, hidden_dim)
        # Pool tokens to reduce sequence length
        x = x.transpose(1, 2)   # (B, hidden_dim, N)
        x = self.token_pool(x)  # (B, hidden_dim, N//9)
        x = x.transpose(1, 2)   # (B, N//9, hidden_dim)
        # Self-attention with residual on pooled tokens
        x_norm = self.self_attn_norm(x)
        x_attn, _ = self.self_attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + x_attn
        # MLP with residual
        x = x + self.mlp1(self.mlp1_norm(x))
        # Cross-attention with learned query
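        # (Only the query passes through cross_attn_norm; keys/values are the
        # pooled token stack as-is, so the pre-norm applies to the query path alone.)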
        query = self.query.expand(batch_size, -1, -1)  # (B, 1, hidden_dim)
        q_norm = self.cross_attn_norm(query)
        attended, _ = self.cross_attn(q_norm, x, x, need_weights=False)
        query = query + attended
        # Final processing
        output = self.final_norm(query).squeeze(1)  # (B, hidden_dim)
        output = self.final_mlp(output)             # (B, output_dim)
        return F.normalize(output, dim=-1, eps=1e-3)
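

# Quick smoke test for CrossAttentionImageProjector (an illustrative sketch,
# not part of the original module; 576 tokens matches the 9x pooling note above):
def _demo_cross_attention_projector() -> None:
    proj = CrossAttentionImageProjector().eval()  # disable dropout
    out = proj(torch.randn(2, 576, 768))          # (batch, n_tokens, input_dim)
    assert out.shape == (2, 1024)                 # one shared-space vector per image
    assert torch.allclose(out.norm(dim=-1), torch.ones(2), atol=1e-3)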


class SimpleImageProjector(nn.Module):
    """Simple projector for mean AION embeddings."""

    def __init__(
        self,
        input_dim: int = 768,
        output_dim: int = 1024,
        hidden_dim: Optional[int] = None,
        dropout: float = 0.1,
        num_layers: int = 4,
    ):
        """
        Initialize simple image projector.

        Args:
            input_dim: Dimension of AION embeddings (768)
            output_dim: Dimension of shared embedding space
            hidden_dim: Hidden layer dimension (default: 1024)
            dropout: Dropout rate
            num_layers: Number of residual layers (default: 4)
        """
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 1024
        self.fc_in = nn.Linear(input_dim, hidden_dim)
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, hidden_dim),
            ) for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize projection weights."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Project image embeddings to shared space.

        Args:
            x: Image embeddings (batch_size, input_dim)

        Returns:
            Projected embeddings (batch_size, output_dim)
        """
        x = F.normalize(x, dim=-1, eps=1e-6)  # Normalize AION embeddings input
        h = self.fc_in(x)
        for blk in self.blocks:  # residual MLP stack
            h = h + blk(h)
        h = self.fc_out(h)
        return F.normalize(h, dim=-1, eps=1e-3)
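

# Quick smoke test for SimpleImageProjector, plus a hypothetical entry point
# that runs all three sketches (none of this is part of the original module):
def _demo_simple_projector() -> None:
    proj = SimpleImageProjector().eval()      # disable dropout
    out = proj(torch.randn(4, 768))           # mean-pooled AION embeddings
    assert out.shape == (4, 1024)
    assert torch.allclose(out.norm(dim=-1), torch.ones(4), atol=1e-3)


if __name__ == "__main__":
    _demo_text_projector()
    _demo_cross_attention_projector()
    _demo_simple_projector()
    print("All projector smoke tests passed.")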