# AION-Search/clip/models/clip_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict
from .projections import TextProjector, CrossAttentionImageProjector, SimpleImageProjector
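
# Note: TextProjector, CrossAttentionImageProjector and SimpleImageProjector are
# defined in clip/models/projections.py (not shown here). Judging from how they are
# constructed below, each is assumed to take (input_dim, output_dim, hidden_dim,
# dropout) and to project a batch of embeddings to shape (batch_size, output_dim).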


class GalaxyClipModel(nn.Module):
    """CLIP model for aligning galaxy images and text descriptions."""

    def __init__(
        self,
        image_input_dim: int = 768,
        text_input_dim: int = 3072,
        embedding_dim: int = 1024,
        image_hidden_dim: int = 768,
        text_hidden_dim: int = 1024,
        dropout: float = 0.1,
        use_mean_embeddings: bool = True
    ):
"""
Initialize CLIP model.
Args:
image_input_dim: AION embedding dimension
text_input_dim: Text embedding dimension
embedding_dim: Shared embedding space dimension
image_hidden_dim: Hidden dimension for image projector
text_hidden_dim: Hidden dimension for text projector
dropout: Dropout rate
use_mean_embeddings: Whether using mean embeddings (True) or full embeddings (False)
"""
        super().__init__()
        self.embedding_dim = embedding_dim
        self.use_mean_embeddings = use_mean_embeddings

        # Choose appropriate image projector based on embedding type
        if use_mean_embeddings:
            # Simple projector for mean embeddings (1D vectors)
            self.image_projector = SimpleImageProjector(
                input_dim=image_input_dim,
                output_dim=embedding_dim,
                hidden_dim=image_hidden_dim,
                dropout=dropout
            )
        else:
            # Cross-attention projector for full embeddings (2D sequences)
            self.image_projector = CrossAttentionImageProjector(
                input_dim=image_input_dim,
                output_dim=embedding_dim,
                hidden_dim=image_hidden_dim,
                dropout=dropout
            )

        self.text_projector = TextProjector(
            input_dim=text_input_dim,
            output_dim=embedding_dim,
            hidden_dim=text_hidden_dim,
            dropout=dropout
        )
        # Learnable logit scale, stored as a log so it stays positive during optimization;
        # initialized to log(1/0.07), i.e. the standard CLIP temperature of 0.07
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1/0.07, dtype=torch.float32)))

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass for CLIP training.

        Args:
            batch: Dictionary containing 'image_embedding' and 'text_embedding'

        Returns:
            Dictionary with projected embeddings and logits
        """
        image_features = batch['image_embedding']
        text_features = batch['text_embedding']

        # Project to the shared space and L2-normalize so the dot products below are
        # cosine similarities (normalizing here is a no-op if the projectors already
        # return unit-norm features)
        image_features = F.normalize(self.image_projector(image_features), dim=-1)
        text_features = F.normalize(self.text_projector(text_features), dim=-1)

        # Compute similarity matrix with the learnable logit scale.
        # Clamp after exp (to the standard CLIP maximum of 100) to preserve gradients.
        logit_scale = self.logit_scale.exp().clamp(max=100)
        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logits_per_image.T

        return {
            'image_features': image_features,
            'text_features': text_features,
            'logits_per_image': logits_per_image,
            'logits_per_text': logits_per_text,
            'logit_scale': logit_scale
        }

    def compute_contrastive_loss(self, outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Compute contrastive loss (InfoNCE).

        Args:
            outputs: Model outputs from forward pass

        Returns:
            Contrastive loss
        """
        logits_per_image = outputs['logits_per_image']
        logits_per_text = outputs['logits_per_text']

        # The i-th image is paired with the i-th text, so the targets are the diagonal indices
        batch_size = logits_per_image.shape[0]
        labels = torch.arange(batch_size, device=logits_per_image.device)

        # Symmetric cross-entropy loss over both retrieval directions
        loss_i2t = F.cross_entropy(logits_per_image, labels)
        loss_t2i = F.cross_entropy(logits_per_text, labels)

        return (loss_i2t + loss_t2i) / 2
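

# Illustrative usage sketch (not part of the original module): a minimal smoke test
# assuming mean AION embeddings (use_mean_embeddings=True), the default dimensions
# above, and that the projectors accept (batch_size, input_dim) tensors. Run it as a
# module, e.g. `python -m clip.models.clip_model`, so the relative import resolves.
if __name__ == "__main__":
    model = GalaxyClipModel()

    # Random stand-ins for AION image embeddings and text-encoder embeddings
    batch = {
        'image_embedding': torch.randn(8, 768),    # (batch_size, image_input_dim)
        'text_embedding': torch.randn(8, 3072),    # (batch_size, text_input_dim)
    }

    outputs = model(batch)
    loss = model.compute_contrastive_loss(outputs)
    print(outputs['logits_per_image'].shape)  # expected: torch.Size([8, 8])
    print(loss.item())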