| | from PIL import Image |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import transformers |
| | from transformers import PreTrainedModel |
| |
|
| | from src import loss |
| | from src import vision_model |
| | from src.config import TinyCLIPConfig |
| | from src.config import TinyCLIPTextConfig |
| | from src.config import TinyCLIPVisionConfig |
| |
|
| |
|
class Projection(nn.Module):
    """Residual projection head.

    Computes ``h = linear1(x)`` and returns
    ``LayerNorm(h + Dropout(linear2(GELU(h))))``.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the last dimension of the output.
        p: dropout probability applied to the residual branch only.
    """

    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        # Attribute names and creation order are deliberate: they define the
        # state_dict keys and the order in which parameters consume the RNG.
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.linear1(x)
        # Residual branch: GELU -> linear2 -> dropout, added back onto `projected`.
        residual = self.drop(self.linear2(F.gelu(projected)))
        return self.layer_norm(projected + residual)
| |
|
| |
|
def projection_layers(d_in: int, d_out: int, num_layers: int) -> nn.Module:
    """Stack ``num_layers`` Projection heads into one ``nn.Sequential``.

    The first ``num_layers - 1`` heads map ``d_in -> d_in`` and are each
    followed by a GELU; the last head maps ``d_in -> d_out``. Any
    ``num_layers < 1`` yields the single final head, same as ``num_layers == 1``.
    """
    # Modules are created in the same order they appear in the Sequential
    # (Projection, GELU, Projection, GELU, ..., final Projection).
    hidden = [
        module
        for _ in range(num_layers - 1)
        for module in (Projection(d_in, d_in), nn.GELU())
    ]
    return nn.Sequential(*hidden, Projection(d_in, d_out))
| |
|
| |
|
def mean_pooling(
    text_representation: torch.FloatTensor, attention_mask: torch.LongTensor
) -> torch.FloatTensor:
    """Average token embeddings over the positions selected by ``attention_mask``.

    Args:
        text_representation: token embeddings, assumed (batch, seq, hidden);
            the mask gets a trailing axis and is expanded to this shape.
        attention_mask: per-token weights, presumably (batch, seq) with 1 for
            real tokens and 0 for padding.

    Returns:
        The mask-weighted sum over dim 1 divided by the mask total over dim 1
        ((batch, hidden) for the assumed shapes). The mask is cast with
        ``.float()``, so half-precision inputs come back as float32. The
        denominator is clamped to 1e-9, so fully masked rows give zeros, not NaN.
    """
    weights = attention_mask.unsqueeze(-1).expand_as(text_representation).float()
    summed = (text_representation * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts
| |
|
| |
|
class TinyCLIPTextEncoder(PreTrainedModel):
    """Text tower: pretrained transformer backbone + pooling + projection head.

    The backbone is loaded with ``transformers.AutoModel.from_pretrained``
    from ``config.text_model``. The projection maps the backbone hidden size
    to ``config.embed_dims``, and the output is L2-normalised on the last dim.
    """

    config_class = TinyCLIPTextConfig

    def __init__(self, config: TinyCLIPTextConfig):
        super().__init__(config)
        self.base = transformers.AutoModel.from_pretrained(config.text_model)
        # Truthy -> take the first token's hidden state; falsy -> mean-pool.
        self.cls_type = config.cls_type
        hidden_size = self.base.config.hidden_size
        self.projection = projection_layers(
            hidden_size, config.embed_dims, config.projection_layers
        )

    def forward(self, x: dict[str, torch.Tensor]):
        """Embed a batch of text.

        ``x`` is unpacked as keyword arguments into the backbone (presumably
        tokenizer output). It must contain ``"attention_mask"`` when
        ``cls_type`` is falsy, because mean pooling reads it.
        """
        hidden_states = self.base(**x).last_hidden_state
        # Index 0 on dim 1 assumes batch-first (batch, seq, hidden) output — TODO confirm.
        pooled = (
            hidden_states[:, 0]
            if self.cls_type
            else mean_pooling(hidden_states, x["attention_mask"])
        )
        return F.normalize(self.projection(pooled), dim=-1)
| |
|
| |
|
class TinyCLIPVisionEncoder(PreTrainedModel):
    """Vision tower: backbone from ``vision_model.get_vision_base`` + projection head.

    ``get_vision_base(config)`` returns the backbone module together with its
    output feature size. The projection maps that size to
    ``config.embed_dims``, and the output is L2-normalised on the last dim.
    """

    config_class = TinyCLIPVisionConfig

    def __init__(self, config: TinyCLIPVisionConfig):
        super().__init__(config)
        self.base, num_features = vision_model.get_vision_base(config)
        self.projection = projection_layers(
            num_features, config.embed_dims, config.projection_layers
        )

    def forward(self, images: torch.Tensor):
        # Backbone output is assumed to be (batch, num_features) — TODO confirm
        # against vision_model.get_vision_base.
        features = self.base(images)
        embeddings = self.projection(features)
        return F.normalize(embeddings, dim=-1)
| |
|
| |
|
class TinyCLIP(PreTrainedModel):
    """CLIP-style dual encoder pairing a text tower and a vision tower.

    Each tower returns L2-normalised embeddings. ``forward`` can also score
    them with the loss selected by ``config.loss_type``.

    ``config.freeze_text_base`` / ``config.freeze_vision_base`` freeze only the
    corresponding pretrained backbone (``<encoder>.base``). The projection heads
    stay trainable so the shared embedding space can still be learned. Frozen
    backbones are also kept in eval mode across later ``model.train()`` calls.
    """

    config_class = TinyCLIPConfig

    def __init__(self, config: TinyCLIPConfig):
        super().__init__(config)
        self.text_encoder = TinyCLIPTextEncoder(config.text_config)
        self.vision_encoder = TinyCLIPVisionEncoder(config.vision_config)

        # Freeze only the backbone. Freezing `encoder.parameters()` would also
        # freeze the freshly initialised projection head, leaving that tower
        # (or, with both flags set, the whole model) with nothing to train.
        for base in self._frozen_bases():
            base.requires_grad_(False)
            base.eval()

        self.loss_fn = loss.get_loss(config.loss_type)

    def _frozen_bases(self) -> list[nn.Module]:
        """Return the backbone modules the config asks to keep frozen."""
        bases: list[nn.Module] = []
        if self.config.freeze_text_base:
            bases.append(self.text_encoder.base)
        if self.config.freeze_vision_base:
            bases.append(self.vision_encoder.base)
        return bases

    def train(self, mode: bool = True) -> "TinyCLIP":
        """Set train/eval mode while keeping frozen backbones in eval mode.

        ``nn.Module.train`` recursively flips every submodule. Without this
        override, the first ``model.train()`` would re-enable dropout inside
        a frozen backbone.
        """
        super().train(mode)
        for base in self._frozen_bases():
            base.eval()
        return self

    def forward(
        self,
        text_input: dict[str, torch.Tensor],
        vision_input: torch.Tensor,
        return_loss: bool = False,
    ) -> dict[str, torch.Tensor]:
        """Encode a batch of texts and images.

        Args:
            text_input: keyword arguments for the text backbone (presumably
                tokenizer output). Must include ``"attention_mask"`` when the
                text encoder mean-pools.
            vision_input: batch passed straight to ``TinyCLIPVisionEncoder``,
                whose ``forward`` takes a ``torch.Tensor``. The old
                ``list[Image.Image]`` annotation did not match that.
            return_loss: when true, also compute
                ``loss_fn(vision_output, text_output)``.

        Returns:
            Dict with ``"text_output"`` and ``"vision_output"`` embeddings,
            plus ``"loss"`` when ``return_loss`` is true.
        """
        text_output = self.text_encoder(text_input)
        vision_output = self.vision_encoder(vision_input)

        out = {"text_output": text_output, "vision_output": vision_output}

        if return_loss:
            # Argument order (vision first, text second) is preserved as-is.
            out["loss"] = self.loss_fn(vision_output, text_output)

        return out
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: build a model from the default config and print its structure.
    tiny_clip = TinyCLIP(TinyCLIPConfig())
    print(tiny_clip)
    print("Done!")
| |
|