import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict

from .projections import TextProjector, CrossAttentionImageProjector, SimpleImageProjector

class GalaxyClipModel(nn.Module):
    """CLIP model for aligning galaxy images and text descriptions."""
    
    def __init__(
        self,
        image_input_dim: int = 768,
        text_input_dim: int = 3072,
        embedding_dim: int = 1024,
        image_hidden_dim: int = 768,
        text_hidden_dim: int = 1024,
        dropout: float = 0.1,
        use_mean_embeddings: bool = True
    ):
        """
        Initialize CLIP model.
        
        Args:
            image_input_dim: AION embedding dimension
            text_input_dim: Text embedding dimension  
            embedding_dim: Shared embedding space dimension
            image_hidden_dim: Hidden dimension for image projector
            text_hidden_dim: Hidden dimension for text projector
            dropout: Dropout rate
            use_mean_embeddings: Whether using mean embeddings (True) or full embeddings (False)
        """
        super().__init__()
        
        self.embedding_dim = embedding_dim
        self.use_mean_embeddings = use_mean_embeddings
        
        # Choose appropriate image projector based on embedding type
        if use_mean_embeddings:
            # Simple projector for mean embeddings (1D vectors)
            self.image_projector = SimpleImageProjector(
                input_dim=image_input_dim,
                output_dim=embedding_dim,
                hidden_dim=image_hidden_dim,
                dropout=dropout
            )
        else:
            # Cross-attention projector for full embeddings (2D sequences)
            self.image_projector = CrossAttentionImageProjector(
                input_dim=image_input_dim,
                output_dim=embedding_dim,
                hidden_dim=image_hidden_dim,
                dropout=dropout
            )
        
        self.text_projector = TextProjector(
            input_dim=text_input_dim,
            output_dim=embedding_dim,
            hidden_dim=text_hidden_dim,
            dropout=dropout
        )
        
        # Learnable logit scale parameter initialized to standard CLIP temperature 1/0.07
        # Using log parameterization for numerical stability
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1/0.07, dtype=torch.float32)))
        
    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass for CLIP training.
        
        Args:
            batch: Dictionary containing 'image_embedding' and 'text_embedding'
            
        Returns:
            Dictionary with projected embeddings and logits
        """
        image_features = batch['image_embedding']
        text_features = batch['text_embedding']

        # Project to the shared space and L2-normalize so the dot product is a cosine similarity
        image_features = F.normalize(self.image_projector(image_features), dim=-1)
        text_features = F.normalize(self.text_projector(text_features), dim=-1)

        # Compute similarity matrix with the learnable logit scale,
        # capped at 100 (as in the original CLIP) to keep training stable
        logit_scale = self.logit_scale.exp().clamp(max=100)
        logits_per_image = logit_scale * image_features @ text_features.T
        logits_per_text = logits_per_image.T
        
        return {
            'image_features': image_features,
            'text_features': text_features,
            'logits_per_image': logits_per_image,
            'logits_per_text': logits_per_text,
            'logit_scale': logit_scale
        }
    
    def compute_contrastive_loss(self, outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Compute contrastive loss (InfoNCE).
        
        Args:
            outputs: Model outputs from forward pass
            
        Returns:
            Contrastive loss
        """
        logits_per_image = outputs['logits_per_image']
        logits_per_text = outputs['logits_per_text']
        
        batch_size = logits_per_image.shape[0]
        labels = torch.arange(batch_size, device=logits_per_image.device)
        
        # Cross-entropy loss for both directions
        loss_i2t = F.cross_entropy(logits_per_image, labels)
        loss_t2i = F.cross_entropy(logits_per_text, labels)
        
        return (loss_i2t + loss_t2i) / 2
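

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only, not part of the training pipeline).
    # It assumes SimpleImageProjector accepts (batch, image_input_dim) mean embeddings
    # and TextProjector accepts (batch, text_input_dim) vectors; adjust the dummy
    # shapes if the projectors in .projections expect token sequences instead.
    model = GalaxyClipModel(use_mean_embeddings=True)

    batch = {
        'image_embedding': torch.randn(4, 768),   # dummy AION mean embeddings
        'text_embedding': torch.randn(4, 3072),   # dummy text embeddings
    }

    outputs = model(batch)
    loss = model.compute_contrastive_loss(outputs)

    print(outputs['logits_per_image'].shape)  # expected: torch.Size([4, 4])
    print(float(loss))                        # finite scalar InfoNCE loss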