import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict

from .projections import TextProjector, CrossAttentionImageProjector, SimpleImageProjector
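
# NOTE: The projector classes live in .projections and are not defined in this
# file. From the way they are constructed and called below, the assumed
# interfaces are (shapes are this comment's assumption, not guaranteed here):
#   SimpleImageProjector(input_dim, output_dim, hidden_dim, dropout):
#       (batch, input_dim) mean image embeddings -> (batch, output_dim)
#   CrossAttentionImageProjector(input_dim, output_dim, hidden_dim, dropout):
#       (batch, seq_len, input_dim) full token embeddings -> (batch, output_dim)
#   TextProjector(input_dim, output_dim, hidden_dim, dropout):
#       (batch, input_dim) text embeddings -> (batch, output_dim)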


class GalaxyClipModel(nn.Module):
    """CLIP model for aligning galaxy images and text descriptions."""

    def __init__(
        self,
        image_input_dim: int = 768,
        text_input_dim: int = 3072,
        embedding_dim: int = 1024,
        image_hidden_dim: int = 768,
        text_hidden_dim: int = 1024,
        dropout: float = 0.1,
        use_mean_embeddings: bool = True
    ):
        """
        Initialize CLIP model.

        Args:
            image_input_dim: AION embedding dimension
            text_input_dim: Text embedding dimension
            embedding_dim: Shared embedding space dimension
            image_hidden_dim: Hidden dimension for image projector
            text_hidden_dim: Hidden dimension for text projector
            dropout: Dropout rate
            use_mean_embeddings: Whether using mean embeddings (True) or full embeddings (False)
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.use_mean_embeddings = use_mean_embeddings

        # Choose appropriate image projector based on embedding type
        if use_mean_embeddings:
            # Simple projector for mean embeddings (1D vectors)
            self.image_projector = SimpleImageProjector(
                input_dim=image_input_dim,
                output_dim=embedding_dim,
                hidden_dim=image_hidden_dim,
                dropout=dropout
            )
        else:
            # Cross-attention projector for full embeddings (2D sequences)
            self.image_projector = CrossAttentionImageProjector(
                input_dim=image_input_dim,
                output_dim=embedding_dim,
                hidden_dim=image_hidden_dim,
                dropout=dropout
            )

        self.text_projector = TextProjector(
            input_dim=text_input_dim,
            output_dim=embedding_dim,
            hidden_dim=text_hidden_dim,
            dropout=dropout
        )

        # Learnable logit scale initialized to the standard CLIP temperature 1/0.07,
        # stored in log form for numerical stability
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07, dtype=torch.float32)))

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass for CLIP training.

        Args:
            batch: Dictionary containing 'image_embedding' and 'text_embedding'

        Returns:
            Dictionary with projected embeddings and logits
        """
        image_features = batch['image_embedding']
        text_features = batch['text_embedding']

        # Project to the shared space and L2-normalize so the dot products
        # below are cosine similarities
        image_features = F.normalize(self.image_projector(image_features), dim=-1)
        text_features = F.normalize(self.text_projector(text_features), dim=-1)

        # Compute the similarity matrix with the learnable logit scale,
        # capped at 100 (a temperature floor of 0.01) for numerical stability
        logit_scale = self.logit_scale.exp().clamp(max=100)
        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logits_per_image.T

        return {
            'image_features': image_features,
            'text_features': text_features,
            'logits_per_image': logits_per_image,
            'logits_per_text': logits_per_text,
            'logit_scale': logit_scale
        }
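
    # In the contrastive loss below, row i of `logits_per_image` scores image i
    # against every text in the batch; the matching caption sits on the
    # diagonal, so the target "class" for sample i is index i. The same holds
    # row-wise for `logits_per_text`, which is why both directions share the
    # same arange labels.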
    def compute_contrastive_loss(self, outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Compute contrastive loss (InfoNCE).

        Args:
            outputs: Model outputs from forward pass

        Returns:
            Contrastive loss
        """
        logits_per_image = outputs['logits_per_image']
        logits_per_text = outputs['logits_per_text']

        batch_size = logits_per_image.shape[0]
        labels = torch.arange(batch_size, device=logits_per_image.device)

        # Cross-entropy loss for both directions
        loss_i2t = F.cross_entropy(logits_per_image, labels)
        loss_t2i = F.cross_entropy(logits_per_text, labels)

        return (loss_i2t + loss_t2i) / 2
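

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the training pipeline). It
# assumes batches carry mean embeddings with the constructor's default
# dimensions; because of the relative import at the top of this file, run it
# as a module (e.g. `python -m <package>.<this_module>`).
if __name__ == "__main__":
    model = GalaxyClipModel(use_mean_embeddings=True)

    # Random stand-ins for a batch of 32 AION image embeddings and text embeddings
    batch = {
        'image_embedding': torch.randn(32, 768),
        'text_embedding': torch.randn(32, 3072),
    }

    outputs = model(batch)
    loss = model.compute_contrastive_loss(outputs)
    print(f"logit scale: {outputs['logit_scale'].item():.2f}  loss: {loss.item():.4f}")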