# AION-Search: clip/models/projections.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
class TextProjector(nn.Module):
"""Projects text embeddings to shared space."""
def __init__(
self,
input_dim: int = 3072,
output_dim: int = 1024,
hidden_dim: Optional[int] = None,
dropout: float = 0.1,
num_layers: int = 4,
):
"""
Initialize text projector.
Args:
input_dim: Dimension of text embeddings (3072)
output_dim: Dimension of shared embedding space
hidden_dim: Hidden layer dimension (default: 1024)
dropout: Dropout rate
            num_layers: Number of residual layers (default: 4)
"""
super().__init__()
if hidden_dim is None:
hidden_dim = 1024
self.fc_in = nn.Linear(input_dim, hidden_dim)
self.blocks = nn.ModuleList([
nn.Sequential(
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
) for _ in range(num_layers)
])
self.fc_out = nn.Linear(hidden_dim, output_dim)
# Initialize weights
self._init_weights()
def _init_weights(self):
"""Initialize projection weights."""
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Project text embeddings to shared space.
Args:
x: Text embeddings (batch_size, input_dim)
Returns:
Projected embeddings (batch_size, output_dim)
"""
h = self.fc_in(x)
for blk in self.blocks: # residual MLP stack
h = h + blk(h)
h = self.fc_out(h)
return F.normalize(h, dim=-1, eps=1e-3)
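# Usage sketch for TextProjector (illustrative only; the batch size and variable
# names below are assumptions, not part of this module):
#
#     projector = TextProjector()            # 3072 -> 1024 with the defaults above
#     text_emb = torch.randn(8, 3072)        # e.g. pooled text-encoder outputs
#     shared = projector(text_emb)           # (8, 1024), rows are L2-normalized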
class CrossAttentionImageProjector(nn.Module):
"""Simplified projector with self-attention + cross-attention."""
def __init__(
self,
input_dim: int = 768,
output_dim: int = 1024,
hidden_dim: Optional[int] = None,
dropout: float = 0.1,
num_layers: int = 2, # Kept for compatibility, not used
num_heads: int = 4, # Reduced from 8
):
"""
Initialize simplified cross-attention image projector.
Args:
input_dim: Dimension of AION embeddings (768)
output_dim: Dimension of shared embedding space (default: 1024)
hidden_dim: Hidden dimension for attention (default: output_dim)
dropout: Dropout rate
num_layers: Kept for compatibility but not used
num_heads: Number of attention heads (reduced to 4)
"""
super().__init__()
if hidden_dim is None:
hidden_dim = output_dim
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.output_dim = output_dim
# Project input to hidden dim
self.input_proj = nn.Linear(input_dim, hidden_dim)
# Token pooling to reduce sequence length
# 576 tokens -> 64 tokens (9x reduction)
self.token_pool = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=9, stride=9, padding=0)
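        # With kernel_size=9 and stride=9 the convolution keeps floor(N / 9) pooled
        # tokens, so 576 input tokens become exactly 64; if N is not a multiple of 9,
        # the trailing tokens are dropped.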
# Single self-attention layer
self.self_attn_norm = nn.LayerNorm(hidden_dim)
self.self_attn = nn.MultiheadAttention(
embed_dim=hidden_dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True
)
# MLP after self-attention
self.mlp1_norm = nn.LayerNorm(hidden_dim)
self.mlp1 = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 2), # Reduced from 4x
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim * 2, hidden_dim),
nn.Dropout(dropout)
)
# Learned query vector
self.query = nn.Parameter(torch.randn(1, 1, hidden_dim))
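        # This single query acts as a learned pooling token: cross-attending it over
        # the image tokens collapses the sequence to one vector per sample.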
# Single cross-attention layer
self.cross_attn_norm = nn.LayerNorm(hidden_dim)
self.cross_attn = nn.MultiheadAttention(
embed_dim=hidden_dim,
num_heads=num_heads,
dropout=dropout,
batch_first=True
)
# Final MLP
self.final_norm = nn.LayerNorm(hidden_dim)
self.final_mlp = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 2), # Reduced from 4x
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim * 2, output_dim)
)
# Initialize weights
self._init_weights()
def _init_weights(self):
"""Initialize weights."""
# Initialize query vector
nn.init.normal_(self.query, std=0.02)
# Initialize other weights
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Project image embeddings to shared space using self-attention + cross-attention.
Args:
x: Image embeddings (batch_size, n_tokens, input_dim)
Returns:
Projected embeddings (batch_size, output_dim)
"""
batch_size = x.shape[0]
x = F.normalize(x, dim=-1, eps=1e-6) # Normalize AION embeddings input (handles [B, N, D])
# Project input
x = self.input_proj(x) # (B, N, hidden_dim)
# Pool tokens to reduce sequence length
x = x.transpose(1, 2) # (B, hidden_dim, N)
x = self.token_pool(x) # (B, hidden_dim, N//9)
x = x.transpose(1, 2) # (B, N//9, hidden_dim)
# Self-attention with residual on pooled tokens
x_norm = self.self_attn_norm(x)
x_attn, _ = self.self_attn(x_norm, x_norm, x_norm, need_weights=False)
x = x + x_attn
# MLP with residual
x = x + self.mlp1(self.mlp1_norm(x))
# Cross-attention with learned query
query = self.query.expand(batch_size, -1, -1) # (B, 1, hidden_dim)
q_norm = self.cross_attn_norm(query)
attended, _ = self.cross_attn(q_norm, x, x, need_weights=False)
query = query + attended
# Final processing
output = self.final_norm(query).squeeze(1) # (B, hidden_dim)
output = self.final_mlp(output) # (B, output_dim)
return F.normalize(output, dim=-1, eps=1e-3)
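# Usage sketch for CrossAttentionImageProjector (illustrative only; the 576-token
# sequence length comes from the pooling comment above and is an assumption about
# the upstream AION encoder, not something this module enforces):
#
#     projector = CrossAttentionImageProjector()       # (B, N, 768) -> (B, 1024)
#     tokens = torch.randn(8, 576, 768)                 # per-token AION embeddings
#     shared = projector(tokens)                        # (8, 1024), L2-normalized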
class SimpleImageProjector(nn.Module):
"""Simple projector for mean AION embeddings."""
def __init__(
self,
input_dim: int = 768,
output_dim: int = 1024,
hidden_dim: Optional[int] = None,
dropout: float = 0.1,
num_layers: int = 4,
):
"""
Initialize simple image projector.
Args:
input_dim: Dimension of AION embeddings (768)
output_dim: Dimension of shared embedding space
hidden_dim: Hidden layer dimension (default: 1024)
dropout: Dropout rate
num_layers: Number of residual layers (default: 4)
"""
super().__init__()
if hidden_dim is None:
hidden_dim = 1024
self.fc_in = nn.Linear(input_dim, hidden_dim)
self.blocks = nn.ModuleList([
nn.Sequential(
nn.LayerNorm(hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim),
) for _ in range(num_layers)
])
self.fc_out = nn.Linear(hidden_dim, output_dim)
# Initialize weights
self._init_weights()
def _init_weights(self):
"""Initialize projection weights."""
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Project image embeddings to shared space.
Args:
x: Image embeddings (batch_size, input_dim)
Returns:
Projected embeddings (batch_size, output_dim)
"""
x = F.normalize(x, dim=-1, eps=1e-6) # Normalize AION embeddings input
h = self.fc_in(x)
for blk in self.blocks: # residual MLP stack
h = h + blk(h)
h = self.fc_out(h)
return F.normalize(h, dim=-1, eps=1e-3)
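if __name__ == "__main__":
    # Minimal smoke test: a sketch only. The batch size and token count are arbitrary
    # choices matching the defaults documented above, not values required by AION.
    text_proj = TextProjector()                    # (B, 3072) -> (B, 1024)
    image_proj = CrossAttentionImageProjector()    # (B, N, 768) -> (B, 1024)
    mean_proj = SimpleImageProjector()             # (B, 768) -> (B, 1024)

    text_emb = torch.randn(4, 3072)
    token_emb = torch.randn(4, 576, 768)           # 576 tokens, divisible by the pool stride of 9
    mean_emb = torch.randn(4, 768)

    with torch.no_grad():
        t = text_proj(text_emb)
        i = image_proj(token_emb)
        m = mean_proj(mean_emb)

    # Outputs are L2-normalized, so cosine similarity reduces to a dot product.
    print(t.shape, i.shape, m.shape)               # each torch.Size([4, 1024])
    print((t * i).sum(dim=-1))                     # text-image cosine similarities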