import math
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
from transformers import PretrainedConfig, PreTrainedModel


class GeLU(nn.Module):
    """Exact (erf-based) GELU activation."""

    def __init__(self) -> None:
        """
        This is the gelu implementation from the original ESM repo.
        Using F.gelu yields subtly wrong results.
        """
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x * 0.5 * (1 + erf(x / sqrt(2)))`` element-wise."""
        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


@dataclass
class RotaryEmbeddingConfig:
    """
    Parameters to initialize the RotaryEmbedding layer. The rescaling factor allows
    adapting the rotary embeddings to longer sequences than those used for training.
    One such strategy is presented in the YaRN paper:
    https://arxiv.org/pdf/2309.00071.pdf. # noqa

    Args:
        rescaling_factor: When ``None`` the default base frequency (10000) is used.
            Otherwise the base is enlarged to
            ``10000 * rescaling_factor ** (dim / (dim - 2))``.
    """

    rescaling_factor: Optional[float]


class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).

    Query and keys are transformed by rotation matrices which depend on their
    relative positions.

    Args:
        dim: Size of the last (per-head) axis of the queries/keys. Element ``i`` is
            rotated together with element ``i + dim // 2``.
        rotary_embedding_config: Provides the optional base rescaling factor.
    """

    def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfig):
        super().__init__()

        # Extract argument from the config
        self.rescaling_factor = rotary_embedding_config.rescaling_factor
        self.upper_freq = 10000
        self.dim = dim

        # Most recently built tables; they are recomputed on every forward call.
        self._seq_len_cached: Optional[int] = None
        self._cos_cached: Optional[torch.Tensor] = None
        self._sin_cached: Optional[torch.Tensor] = None

    def _apply_rotary_pos_emb(
        self,
        heads: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """
        Rotate ``heads`` by the angles encoded in ``cos``/``sin``.

        The last axis is split into halves ``(x1, x2)`` which are mapped to
        ``(x1 * cos - x2 * sin, x2 * cos + x1 * sin)``; ``cos``/``sin`` must
        broadcast against each half.
        """
        half = heads.shape[-1] // 2
        x_first, x_second = heads[..., :half], heads[..., half:]
        first_part = x_first * cos - x_second * sin
        second_part = x_second * cos + x_first * sin
        return torch.cat((first_part, second_part), dim=-1)

    def _compute_cos_sin_tables(
        self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Build the cos/sin tables for the sequence length of ``x``.

        Args:
            x: Tensor whose ``seq_dimension`` axis holds the sequence positions.
            inv_freq: Inverse frequencies, one per rotated pair.
            seq_dimension: Sequence axis of ``x``.

        Returns:
            ``(cos, sin)``, each of shape ``(1, seq_len, 1, len(inv_freq))`` and
            located on ``x``'s device.
        """
        seq_len = x.shape[seq_dimension]
        # Tables are rebuilt on every call, so changes of sequence length or
        # device (possibly due to tracing for instance) are always picked up.
        self._seq_len_cached = seq_len
        # Keep everything on x's device so the tables can be combined with q/k.
        inv_freq = inv_freq.to(x.device)
        positions = torch.arange(seq_len, device=x.device).type_as(inv_freq)
        freqs = torch.einsum("i, j -> ij", positions, inv_freq)
        self._cos_cached = torch.cos(freqs)[None, :, None, :]
        self._sin_cached = torch.sin(freqs)[None, :, None, :]
        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply rotary embeddings to queries and keys.

        Args:
            q: Queries with the sequence on axis -3 and the rotated features on the
                last axis (e.g. ``(batch, seq, heads, dim)``).
            k: Keys, laid out like ``q``; the tables are sized from ``q``'s
                sequence length, so both must have the same length.

        Returns:
            The rotated ``(q, k)``.
        """
        if self.rescaling_factor is None:
            base = self.upper_freq
        else:
            base = self.upper_freq * (
                self.rescaling_factor ** (self.dim / (self.dim - 2))
            )
        # Build the frequencies on q's device; creating them on CPU would leave the
        # cos/sin tables on CPU while q/k live on an accelerator.
        exponents = torch.arange(0, self.dim, 2, device=q.device).float() / self.dim
        inv_freq = 1.0 / (base**exponents)

        cos, sin = self._compute_cos_sin_tables(q, inv_freq, seq_dimension=-3)
        return (
            self._apply_rotary_pos_emb(q, cos, sin),
            self._apply_rotary_pos_emb(k, cos, sin),
        )


class ResidualConvBlock(nn.Module):
    """
    Conv Block with Residual connection.
    """

    def __init__(self, dim_in: int, dim_out: int, seq_len: int, kernel_size: int = 1):
        super().__init__()
        self.conv_block = ConvBlock(
            dim_in=dim_in, dim_out=dim_out, seq_len=seq_len, kernel_size=kernel_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``conv_block(x)`` plus ``x`` reshaped to the block's output shape."""
        y = self.conv_block(x)
        return x.reshape(y.shape) + y


class ConvBlock(nn.Module):
    """
    Conv Block: LayerNorm over the last axis, then Conv1d, then tanh-GELU.

    Args:
        dim_in: Input channels of the convolution.
        dim_out: Output channels of the convolution.
        seq_len: Size of the input's last axis (normalized by the LayerNorm).
        kernel_size: Convolution kernel size; "same" padding keeps the length.
    """

    def __init__(self, dim_in: int, dim_out: int, seq_len: int, kernel_size: int = 1):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=kernel_size,
            padding="same",
        )
        self.layer_norm = nn.LayerNorm(seq_len, eps=1e-5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor whose last axis has size ``seq_len``; every axis after the
                second is flattened before the convolution.

        Returns:
            Tensor of shape ``(x.shape[0], dim_out, flattened_length)``.
        """
        x = self.layer_norm(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = self.conv(x)
        x = F.gelu(x, approximate="tanh")
        return x


class ResidualDeConvBlock(nn.Module):
    """
    DeConv Block with Residual connection.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        seq_len: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        self.deconv_block = DeConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            seq_len=seq_len,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``deconv_block(x)`` plus ``x`` reshaped to the block's output shape."""
        y = self.deconv_block(x)
        return x.reshape(y.shape) + y


class DeConvBlock(nn.Module):
    """
    DeConv Block: LayerNorm over the last axis, then ConvTranspose1d, then tanh-GELU.

    Args:
        dim_in: Input channels of the transposed convolution.
        dim_out: Output channels of the transposed convolution.
        seq_len: Size of the input's last axis (normalized by the LayerNorm).
        kernel_size: Transposed-convolution kernel size.
        stride: Transposed-convolution stride (up-sampling factor).
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        seq_len: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
        )
        self.layer_norm = nn.LayerNorm(seq_len)
        self.kernel_size = kernel_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor whose last axis has size ``seq_len``; every axis after the
                second is flattened before the transposed convolution.

        Returns:
            Tensor of shape ``(x.shape[0], dim_out, output_length)``.
        """
        x = self.layer_norm(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = self.deconv(x)
        if self.kernel_size == 5:
            # handle the special case where haiku
            # deconv removes padding automatically
            x = x[:, :, 1:-2]
        x = F.gelu(x, approximate="tanh")
        return x


class SpatialEncoding(nn.Module):
    """
    Spatial coordinates encoding module.

    Each of ``num_scales`` scales encodes the ``(x, y)`` coordinates (last axis of
    the input) as ``(cos(x / c), sin(y / c))``; the per-scale encodings are
    concatenated and passed through a linear layer.

    NOTE(review): the concatenated encoding has ``2 * num_scales`` features while
    ``fc_layer`` expects ``embed_dim`` inputs, so ``forward`` only works when
    ``embed_dim == 2 * num_scales`` — confirm the intended layer width.
    """

    def __init__(
        self,
        embed_dim: int,
        num_scales: int = 10,
        sigma_min: float = 1.0,
        sigma_max: float = 10.0,
    ):
        super().__init__()
        self.num_scales = num_scales
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.g = sigma_max / sigma_min
        self.scales = torch.linspace(sigma_min, sigma_max, num_scales)
        self.fc_layer = nn.Linear(embed_dim, embed_dim)

    def scale_specific_encoder(
        self, coordinates: torch.Tensor, scale: float
    ) -> torch.Tensor:
        """Encode ``coordinates[..., 0:2]`` at one scale into 2 features."""
        x, y = coordinates[..., 0], coordinates[..., 1]
        constant = self.sigma_min * (self.g ** (scale / (self.num_scales - 1)))
        x_transform = torch.cos(x / constant)
        y_transform = torch.sin(y / constant)
        transformed_coordinates = torch.stack([x_transform, y_transform], dim=-1)
        return transformed_coordinates

    def forward(self, coordinates: torch.Tensor) -> torch.Tensor:
        """Concatenate the encodings at every scale and project them linearly."""
        transformed_coordinates = [
            self.scale_specific_encoder(coordinates, scale) for scale in self.scales
        ]
        transformed_coordinates = torch.cat(transformed_coordinates, dim=-1)
        return self.fc_layer(transformed_coordinates)


class ConvTowerBlock(nn.Module):
    """
    One down-sampling stage: ConvBlock, residual 1x1 ConvBlock, then average
    pooling that halves the last axis.

    Args:
        dim_in: Input channels.
        dim_out: Output channels.
        seq_len: Per-cell length normalized by the LayerNorms of this stage.
        kernel_size: Kernel size of the first convolution.
        num_cells: Number of cells the flattened sequence axis is split into.
    """

    def __init__(
        self, dim_in: int, dim_out: int, seq_len: int, kernel_size: int, num_cells: int
    ) -> None:
        super().__init__()
        self.conv_layer = ConvBlock(
            dim_in=dim_in, dim_out=dim_out, seq_len=seq_len, kernel_size=kernel_size
        )
        self.res_conv = ResidualConvBlock(
            dim_in=dim_out, dim_out=dim_out, seq_len=seq_len, kernel_size=1
        )
        self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2)
        self.num_cells = num_cells

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Tensor reshaped here to ``(batch, channels, num_cells, -1)``.

        Returns:
            ``(downsampled, residual)`` where ``residual`` is the unmodified input
            (used later as a skip connection).
        """
        residual = x
        x = x.reshape(x.shape[0], x.shape[1], self.num_cells, -1)  # noqa: FKA100
        x = self.conv_layer(x)
        x = x.reshape((x.shape[0], x.shape[1], self.num_cells, -1))
        x = self.res_conv(x)
        x = self.avg_pool(x)
        return x, residual


class DeConvTowerBlock(nn.Module):
    """
    One up-sampling stage: DeConvBlock, residual 1x1 DeConvBlock, then addition
    of the matching skip connection.

    Args:
        dim_in: Input channels.
        dim_out: Output channels.
        kernel_size: Kernel size of the transposed convolution.
        seq_len: Per-cell length of the input; the residual block normalizes
            ``2 * seq_len``.
        stride: Up-sampling stride.
        num_cells: Number of cells the flattened sequence axis is split into.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        kernel_size: int,
        seq_len: int,
        stride: int = 2,
        num_cells: int = 1,
    ):
        super().__init__()
        self.deconv_block = DeConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            seq_len=seq_len,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.res_deconv_block = ResidualDeConvBlock(
            dim_in=dim_out, dim_out=dim_out, seq_len=seq_len * 2, kernel_size=1
        )
        self.num_cells = num_cells

    def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
        """Up-sample ``x`` and add the skip connection ``res`` (same shape)."""
        x = x.reshape((x.shape[0], x.shape[1], self.num_cells, -1))
        x = self.deconv_block(x)
        x = x.reshape((x.shape[0], x.shape[1], self.num_cells, -1))
        x = self.res_deconv_block(x)
        x = x + res
        return x


class MultiHeadAttention(nn.Module):
    """
    Multi-head attention with optional rotary position embeddings.

    Args:
        num_heads: Number of attention heads.
        key_size: Per-head size of queries and keys.
        rotary_embedding_config: When given, rotary embeddings are applied to the
            query and key heads.
        add_bias_kv: Stored on the module; not used by this implementation.
        value_size: Per-head value size (defaults to ``key_size``).
        model_size: Input/output feature size (defaults to ``key_size``).
        name: Optional name stored on the module.
    """

    def __init__(
        self,
        num_heads: int,
        key_size: int,
        rotary_embedding_config: Optional[RotaryEmbeddingConfig] = None,
        add_bias_kv: bool = False,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        name: Optional[str] = None,
    ):
        super().__init__()
        if not model_size:
            model_size = key_size
        if not value_size:
            value_size = key_size
        self.model_size = model_size
        self.key_size = key_size
        self.value_size = value_size
        self.add_bias_kv = add_bias_kv
        self.name = name
        self.num_heads = num_heads
        self._rotary_embedding_config = rotary_embedding_config

        self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
        self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
        if self._rotary_embedding_config:
            self._rotary_embedding = RotaryEmbedding(
                self.key_size, self._rotary_embedding_config
            )

    def apply_rotary_embeddings(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding to the per-head queries and keys."""
        query, key = self._rotary_embedding(query, key)
        return query, key

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Compute multi-head attention.

        Args:
            query: Tensor of shape ``(..., t, model_size)``.
            key: Tensor of shape ``(..., T, model_size)``.
            value: Tensor of shape ``(..., T, model_size)``.
            attention_mask: Optional boolean tensor broadcastable to
                ``(..., num_heads, t, T)``; logits where it is False are replaced
                by -1e30 before the softmax.
            attention_weight_bias: Optional tensor added to the scaled logits
                before the softmax.

        Returns:
            dictionary containing attention weights (post-softmax,
            ``(..., num_heads, t, T)``) and outputs (``"embeddings"``,
            ``(..., t, model_size)``).
        """
        key_heads = self.w_k(key).reshape(
            (*key.shape[:-1], self.num_heads, self.key_size)
        )
        query_heads = self.w_q(query).reshape(
            (*query.shape[:-1], self.num_heads, self.key_size)
        )
        value_heads = self.w_v(value).reshape(
            (*value.shape[:-1], self.num_heads, self.value_size)
        )
        if self._rotary_embedding_config:
            query_heads, key_heads = self.apply_rotary_embeddings(
                query_heads, key_heads
            )
        attention_weights = torch.einsum(
            "...thd, ...Thd -> ...htT", query_heads, key_heads
        )
        sqrt_key_size = np.sqrt(self.key_size)
        attention_weights = attention_weights / sqrt_key_size

        # Compare against None explicitly: truth-testing a tensor with more than
        # one element raises a RuntimeError.
        if attention_mask is not None:
            attention_weights = torch.where(attention_mask, attention_weights, -1e30)
        if attention_weight_bias is not None:
            attention_weights = attention_weights + attention_weight_bias
        attention_weights = F.softmax(attention_weights, dim=-1)

        value_out = torch.einsum(
            "...htT, ...Thd->...thd", attention_weights, value_heads
        )
        # Merge the heads back into a single feature axis.
        value_out = value_out.reshape((*value_out.shape[:-2], -1))
        embeddings = self.output(value_out)

        return {"attention_weights": attention_weights, "embeddings": embeddings}


class SelfAttentionBlock(nn.Module):
    """
    Transformer layer: self-attention followed by a feed-forward network, each
    with a residual connection and a LayerNorm (pre- or post-norm).

    Args:
        num_heads: Number of attention heads.
        embed_dim: Model width.
        ffn_embed_dim: Hidden width of the feed-forward network.
        key_size: Per-head key size; defaults to ``embed_dim // num_heads``.
        add_bias_kv: Forwarded to MultiHeadAttention.
        add_bias_fnn: Whether the feed-forward linear layers have biases.
        ffn_activation_name: ``"swish"``, ``"gelu-no-approx"`` (exact GELU), or
            the name of a ``torch.nn`` activation class (e.g. ``"ReLU"``).
        use_glu_in_ffn: Use a gated linear unit in the feed-forward network.
        layer_norm_eps: Epsilon of both LayerNorms.
        pre_layer_norm: Apply the LayerNorms before (True) or after (False) the
            residual sub-blocks.
        name: Unused; kept for interface compatibility.
        rotary_embedding_config: Forwarded to MultiHeadAttention.

    Raises:
        ValueError: If ``key_size`` is None and ``embed_dim`` is not divisible by
            ``num_heads``.
    """

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        ffn_embed_dim: int,
        key_size: Optional[int] = None,
        add_bias_kv: bool = False,
        add_bias_fnn: bool = True,
        ffn_activation_name: str = "gelu-no-approx",
        use_glu_in_ffn: bool = False,
        layer_norm_eps: float = 1e-5,  # this is the default haiku value
        pre_layer_norm: bool = True,
        name: Optional[str] = None,
        rotary_embedding_config: Optional[RotaryEmbeddingConfig] = None,
    ):
        super().__init__()
        if key_size is None:
            if embed_dim % num_heads != 0:
                raise ValueError(
                    f"The embedding dimension should be divisible by the number of "
                    f"heads, however provided embedding dimension is {embed_dim} and "
                    f"the number of heads is {num_heads}."
                )
            else:
                key_size = embed_dim // num_heads

        # Get ffn activation function
        self._pre_layer_norm = pre_layer_norm
        self._use_glu_in_fnn = use_glu_in_ffn
        # Define layers
        if use_glu_in_ffn:
            # user should multiply ffn_embed_dim by 2/3 when using GLU
            # to keep total number of parameters equal
            # see https://arxiv.org/pdf/2002.05202.pdf. for more details
            # we multiply by 2 here as the output will be split in 2 for GLU
            self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn)
        else:
            self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn)

        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)

        self.layer_norm_self_attention = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.layer_norm_mlp = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        if ffn_activation_name == "swish":
            self._ffn_activation_fn = nn.SiLU()
        elif ffn_activation_name == "gelu-no-approx":
            # ``torch.nn`` has no ``GeLU``; the option name asks for the exact
            # (erf-based) GELU rather than the tanh approximation.
            self._ffn_activation_fn = nn.GELU(approximate="none")
        else:
            # Instantiate the activation class (e.g. "ReLU" -> nn.ReLU()); the bare
            # class cannot be applied to a tensor.
            self._ffn_activation_fn = getattr(torch.nn, ffn_activation_name)()

        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            key_size=key_size,
            add_bias_kv=add_bias_kv,
            model_size=embed_dim,
            name="self_attention",
            rotary_embedding_config=rotary_embedding_config,
        )

    def mlp(self, embed: torch.Tensor) -> torch.Tensor:
        """
        Feed-forward sub-block.

        With pre-norm the input is normalized first and the raw FFN output is
        returned (the caller adds the residual). With post-norm the residual is
        added here and the sum is normalized.
        """
        if self._pre_layer_norm:
            x = self.layer_norm_mlp(embed)
        else:
            x = embed

        if self._use_glu_in_fnn:
            x = self.fc1(x)
            x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
            x = self._ffn_activation_fn(x1) * x2
        else:
            x = self._ffn_activation_fn(self.fc1(x))
        x = self.fc2(x)

        if not self._pre_layer_norm:
            x = self.layer_norm_mlp(x + embed)
        return x

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the layer on ``x`` of shape ``(..., seq, embed_dim)``.

        Returns:
            The attention output dictionary with ``"embeddings"`` replaced by the
            layer output (``"attention_weights"`` is kept).
        """
        res = x
        if self._pre_layer_norm:
            x = self.layer_norm_self_attention(x)

        output = self.mha(
            x,
            x,
            x,
            attention_mask=attention_mask,
            attention_weight_bias=attention_weight_bias,
        )

        if not self._pre_layer_norm:
            output["embeddings"] = self.layer_norm_self_attention(
                output["embeddings"] + res
            )

            x = output["embeddings"]
        else:
            x = output["embeddings"]
            x = res + x

        # MLP
        if not self._pre_layer_norm:
            x = self.mlp(x)
        else:
            x = x + self.mlp(x)

        output["embeddings"] = x
        return output


class LMHead(nn.Module):
    """
    Prediction head: tanh-GELU, then ``num_hidden_layers`` linear layers each
    followed by tanh-GELU, then a final linear projection to ``dim_out``.

    Args:
        dim_in: Input feature size.
        embed_dim: Hidden size of the head.
        dim_out: Output size (e.g. the alphabet size).
        num_hidden_layers: Number of hidden linear layers (at least 1).
    """

    def __init__(
        self, dim_in: int, embed_dim: int, dim_out: int, num_hidden_layers: int
    ) -> None:
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.linear_layers = nn.ModuleList([nn.Linear(dim_in, embed_dim)])
        # Each extra hidden layer must be a Module; wrapping it in a list (as
        # before) made ModuleList raise TypeError for num_hidden_layers > 1.
        self.linear_layers.extend(
            [nn.Linear(embed_dim, embed_dim) for _ in range(num_hidden_layers - 1)]
        )
        self.linear_out = nn.Linear(embed_dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features of size ``dim_in`` on the last axis to ``dim_out`` logits."""
        x = F.gelu(x, approximate="tanh")
        for layer in self.linear_layers:
            x = layer(x)
            x = F.gelu(x, approximate="tanh")
        out = self.linear_out(x)
        return out


class sCTConfig(PretrainedConfig):  # noqa: N801
    """
    Configuration of the sCT model; every attribute is read from ``kwargs`` with
    the default shown in ``__init__``.
    """

    model_type = "sCT"

    def __init__(self, **kwargs):  # type: ignore
        # PretrainedConfig.__init__ reassigns ``pad_token_id`` from its kwargs
        # (defaulting to None), so our default must be forwarded to it explicitly.
        kwargs.setdefault("pad_token_id", 5)

        self.alphabet_size = kwargs.get("alphabet_size", 7)
        self.pad_token_id = kwargs["pad_token_id"]
        self.mask_token_id = kwargs.get("mask_token_id", 6)
        self.cell_len = kwargs.get("cell_len", 19968)

        self.num_downsamples = kwargs.get("num_downsamples", 8)
        self.attention_heads = kwargs.get("attention_heads", 16)
        self.key_size = kwargs.get("key_size", None)
        self.token_embed_dim = kwargs.get("token_embed_dim", 16)

        self.embed_dim = kwargs.get("embed_dim", 1024)
        self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 2048)
        self.num_layers = kwargs.get("num_layers", 4)
        self.layer_norm_eps = kwargs.get("layer_norm_eps", 1e-5)
        self.interpolation_method = kwargs.get("interpolation_method", "nearest")

        # bad hack to satisfy cellnt_celltype_annotation.py:312
        self.max_positions: int = kwargs.get("max_positions", 20480)
        self.num_cells: int = kwargs.get("num_cells", 50)
        self.num_hidden_layers_head: int = kwargs.get("num_hidden_layers_head", 1)

        self.use_skip_connection: bool = kwargs.get("use_skip_connection", True)

        # logging
        self.use_gradient_checkpointing: bool = False

        # return
        self.embeddings_layers_to_save: Tuple[int, ...] = kwargs.get(
            "embeddings_layers_to_save", ()
        )
        self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
            "attention_maps_to_save", []
        )

        # Spatial info configuration
        self.use_spatial_information: bool = kwargs.get(
            "use_spatial_information", False
        )
        self.num_scales: int = kwargs.get("num_scales", 10)
        self.sigma_min: float = kwargs.get("sigma_min", 1.0)
        self.sigma_max: float = kwargs.get("sigma_max", 10.0)

        super().__init__(**kwargs)

        # This class defines its own __init__, so nothing else invokes the
        # validation hook; run it explicitly.
        self.__post_init__()

    def __post_init__(self) -> None:  # type: ignore # noqa: N807
        """
        Checks that the given values are compatible.

        Raises:
            ValueError: If ``key_size`` is None and ``embed_dim`` is not divisible
                by ``attention_heads``.
        """
        if self.key_size is None:
            if not self.embed_dim % self.attention_heads == 0:
                raise ValueError(
                    f"When no key size is provided, the embedding dimension "
                    f"should be divisible by the number of heads, however "
                    f"provided embedding dimension is {self.embed_dim} and "
                    f"the number of heads is {self.attention_heads}."
                )
            self.key_size = self.embed_dim // self.attention_heads


class sCT(PreTrainedModel):  # noqa: N801
    """
    Single-cell model: token embedding, stem convolution, convolutional
    down-sampling tower, transformer stack, de-convolutional up-sampling tower
    with skip connections, and an LM head producing per-position logits.

    Raises:
        ValueError: If ``config.attention_maps_to_save`` references a layer
            beyond ``config.num_layers``.
    """

    config_class = sCTConfig

    def __init__(self, config: sCTConfig):
        super().__init__(config=config)
        if config.use_spatial_information:
            # NOTE(review): this layer is constructed (so its weights belong to
            # the state dict) but ``forward`` never uses it.
            self.spatial_embed_layer = SpatialEncoding(
                embed_dim=config.token_embed_dim,
                num_scales=config.num_scales,
                sigma_min=config.sigma_min,
                sigma_max=config.sigma_max,
            )
        self.cell_len = config.cell_len

        self.token_embed = nn.Embedding(config.alphabet_size, config.token_embed_dim)

        # Layers are numbered from 1 in ``attention_maps_to_save``.
        attention_maps_to_save = config.attention_maps_to_save
        self._attention_layers_to_save = list({t[0] for t in attention_maps_to_save})

        self._attention_maps_per_layer_to_save = {
            layer: [t[1] for t in attention_maps_to_save if t[0] == layer]
            for layer in self._attention_layers_to_save
        }

        max_layer = max(self._attention_layers_to_save + [0])
        if max_layer > config.num_layers:
            raise ValueError(
                f"You are requiring attention maps for layer {max_layer}, "
                f"while the model has {config.num_layers} layers only."
            )

        # Channel widths of the towers: linearly spaced from token_embed_dim to
        # embed_dim, rounded up to multiples of 32.
        filter_list = np.linspace(
            config.token_embed_dim,
            config.embed_dim,
            config.num_downsamples + 1,
        )
        filter_list = np.ceil(filter_list / 32) * 32
        filter_list = filter_list.astype(int).tolist()

        self._filter_list = filter_list
        self._rotary_embedding_config = RotaryEmbeddingConfig(rescaling_factor=None)

        self.stem_conv = nn.Sequential(
            nn.Conv1d(
                in_channels=config.token_embed_dim,
                out_channels=config.token_embed_dim,
                kernel_size=15,
                padding="same",
            ),
            nn.GELU(approximate="tanh"),
        )
        # Per-cell length seen by each conv stage (halved at every stage).
        downsampled_seq_lens = [
            self.cell_len // (2**i) for i in range(len(filter_list) - 1)
        ]

        self.conv_tower = nn.ModuleList(
            [
                ConvTowerBlock(
                    dim_in=self._filter_list[i],
                    dim_out=self._filter_list[i + 1],
                    kernel_size=5,
                    seq_len=seq_len,
                    num_cells=config.num_cells,
                )
                for i, seq_len in zip(range(len(filter_list) - 1), downsampled_seq_lens)
            ]
        )

        self.deconv_tower = nn.ModuleList(
            [
                DeConvTowerBlock(
                    dim_in=filter_list[-1 - i],
                    dim_out=filter_list[-1 - i - 1],
                    kernel_size=5,
                    stride=2,
                    seq_len=seq_len // 2,
                    num_cells=config.num_cells,
                )
                for i, seq_len in zip(
                    range(len(filter_list) - 1), downsampled_seq_lens[::-1]
                )
            ]
        )
        self.transformer_layers = nn.ModuleList(
            [
                SelfAttentionBlock(
                    num_heads=config.attention_heads,
                    embed_dim=config.embed_dim,
                    ffn_embed_dim=config.ffn_embed_dim,
                    key_size=config.key_size,
                    add_bias_kv=False,
                    add_bias_fnn=False,
                    ffn_activation_name="swish",
                    use_glu_in_ffn=True,
                    layer_norm_eps=1e-5,  # this is the default haiku value
                    pre_layer_norm=True,
                    name=f"attention_layer_{layer_idx}",
                    rotary_embedding_config=self._rotary_embedding_config,
                )
                for layer_idx in range(config.num_layers)
            ]
        )

        self.lm_head = LMHead(
            dim_in=config.token_embed_dim,
            embed_dim=config.embed_dim,
            dim_out=config.alphabet_size,
            num_hidden_layers=config.num_hidden_layers_head,
        )

    def forward(self, input_ids: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Args:
            input_ids: 2-D integer token ids (the ``permute(0, 2, 1)`` below
                requires rank-3 embeddings); presumably ``(batch, seq_len)`` with
                ``seq_len = num_cells * cell_len`` — TODO confirm.

        Returns:
            Dictionary with ``"logits"`` (last axis of size ``alphabet_size``),
            plus ``"embeddings_{layer}"`` and
            ``"attention_map_layer_{layer}_number_{map}"`` entries for the layers
            requested in the config.
        """
        outs: dict[str, torch.Tensor] = {}
        embeddings = self.token_embed(input_ids)
        # Move features to axis 1 for the Conv1d layers.
        x = embeddings.permute(0, 2, 1)
        x = self.stem_conv(x)

        residuals = []
        for conv_block in self.conv_tower:
            x, res = conv_block(x)
            residuals.append(res)
        # The deconv tower consumes the skip connections deepest-first.
        residuals = residuals[::-1]

        x = x.permute(0, 2, 1)
        for layer_idx, transformer in enumerate(self.transformer_layers):
            output = transformer(x)
            x = output["embeddings"]
            layer_number = layer_idx + 1  # config layer numbers start at 1
            if layer_number in self.config.embeddings_layers_to_save:
                outs[f"embeddings_{layer_number}"] = output["embeddings"]
            if layer_number in self._attention_layers_to_save:
                for map_number in self._attention_maps_per_layer_to_save[layer_number]:
                    dkey = f"attention_map_layer_{layer_number}_number_{map_number}"
                    # NOTE(review): selects head ``map_number + 1`` along the head
                    # axis — confirm whether map numbers are 0- or 1-based.
                    outs[dkey] = output["attention_weights"][:, map_number + 1]

        x = x.permute(0, 2, 1)
        for deconv_block, res in zip(self.deconv_tower, residuals):
            x = deconv_block(x, res)
        x = x.permute(0, 2, 1)

        logits = self.lm_head(x)
        outs["logits"] = logits

        return outs