| from dataclasses import dataclass |
| from typing import Optional, Tuple |
|
|
| import torch |
| from torch import nn |
| from transformers.modeling_outputs import ModelOutput |
| from transformers.modeling_utils import PreTrainedModel |
|
|
| try: |
| from .configuration_f2p_decoder import F2PDecoderConfig |
| from .decoder import GeneralDecoder |
| except ImportError: |
| from configuration_f2p_decoder import F2PDecoderConfig |
| from decoder import GeneralDecoder |
|
|
|
|
@dataclass
class F2PDecoderOutput(ModelOutput):
    """Output container for `F2PDecoderModel.forward` when `return_dict` is truthy.

    All fields default to ``None``; which ones are populated depends on the
    flags passed to ``forward`` and on what the inner decoder returns.
    """

    # De-normalized pixel reconstruction: decoder output unpatchified, then
    # scaled by image_std and shifted by image_mean.
    reconstruction: Optional[torch.FloatTensor] = None
    # Raw patch-level logits from the inner decoder (pre-unpatchify).
    logits: Optional[torch.FloatTensor] = None
    # Hidden states forwarded from the inner decoder; presumably only set
    # when output_hidden_states is requested — confirm against GeneralDecoder.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Attention maps forwarded from the inner decoder; presumably only set
    # when output_attentions is requested — confirm against GeneralDecoder.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
|
|
|
class F2PDecoderModel(PreTrainedModel):
    """Feature-to-pixel decoder for SigLIP2 patch features.

    Wraps a `GeneralDecoder` that maps patch embeddings back to patch pixels,
    then unpatchifies and de-normalizes them into image space.
    """

    config_class = F2PDecoderConfig
    base_model_prefix = "f2p_decoder"
    main_input_name = "hidden_states"
    supports_gradient_checkpointing = True

    def __init__(self, config: F2PDecoderConfig):
        """Register de-normalization buffers and build the inner decoder."""
        super().__init__(config)
        # (1, C, 1, 1) so the stats broadcast over batch and spatial dims.
        stats_shape = (1, config.num_channels, 1, 1)
        # Buffers, not parameters: they follow .to()/.cuda() and land in the
        # state dict, but are never updated by the optimizer.
        self.register_buffer(
            "image_mean",
            torch.tensor(config.image_mean, dtype=torch.float32).view(stats_shape),
        )
        self.register_buffer(
            "image_std",
            torch.tensor(config.image_std, dtype=torch.float32).view(stats_shape),
        )
        self.decoder = GeneralDecoder(config, num_patches=config.num_patches)

    def _set_gradient_checkpointing(self, module, value=False):
        # Only the inner decoder participates in gradient checkpointing.
        if isinstance(module, GeneralDecoder):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: Optional[torch.Tensor] = None,
        zs: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Decode SigLIP2 features into a pixel reconstruction.

        Args:
            hidden_states: SigLIP2 patch features (``zs`` is an accepted alias;
                ``hidden_states`` wins when both are given).
            zs: Alias for ``hidden_states``.
            output_attentions: Falls back to ``config.output_attentions``.
            output_hidden_states: Falls back to ``config.output_hidden_states``.
            return_dict: When truthy, return an `F2PDecoderOutput`; otherwise
                (including when left as ``None``) return the reconstruction
                tensor alone.

        Raises:
            ValueError: If neither ``hidden_states`` nor ``zs`` is provided.
        """
        features = zs if hidden_states is None else hidden_states
        if features is None:
            raise ValueError("Pass SigLIP2 features as hidden_states or zs.")

        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        decoded = self.decoder(
            features,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            drop_cls_token=self.config.drop_cls_token,
        )
        # Unpatchify to image layout, then undo the dataset normalization
        # (multiply by std, add mean) so values are in original image range.
        pixels = self.decoder.unpatchify(decoded.logits)
        pixels = pixels * self.image_std + self.image_mean

        if not return_dict:
            return pixels
        return F2PDecoderOutput(
            reconstruction=pixels,
            logits=decoded.logits,
            hidden_states=decoded.hidden_states,
            attentions=decoded.attentions,
        )

    @torch.no_grad()
    def infer(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Gradient-free convenience wrapper returning only the pixel tensor."""
        return self.forward(hidden_states, return_dict=False)
|
|