# sCellTransformer / pytorch_sct.py
# Uploaded by Yanisadel ("Upload sCT", commit b917499, verified).
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from transformers import PretrainedConfig, PreTrainedModel
class GeLU(nn.Module):
    """
    Exact, erf-based GELU activation, taken from the original ESM repo.

    Computes ``x * 0.5 * (1 + erf(x / sqrt(2)))``. The original authors note
    that ``F.gelu`` gave them subtly different results, hence this explicit
    formulation.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaled = x / math.sqrt(2.0)
        cdf_term = 1.0 + torch.erf(scaled)
        return x * 0.5 * cdf_term
@dataclass
class RotaryEmbeddingConfig:
    """
    Settings used to build a ``RotaryEmbedding`` layer.

    Attributes:
        rescaling_factor: Optional factor that enlarges the rotary frequency
            base, so the embeddings can be used on sequences longer than those
            seen during training. The YaRN paper
            (https://arxiv.org/pdf/2309.00071.pdf) presents one such strategy.
            ``None`` means no rescaling.
    """

    rescaling_factor: Optional[float]
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).

    Queries and keys are rotated by angles that depend on their position.
    The last dimension is split into two halves, which are rotated against
    each other: the first half is paired with the second half, rather than
    using interleaved pairs.

    Args:
        dim: Size of the last dimension of the queries/keys (per-head size).
        rotary_embedding_config: Provides ``rescaling_factor``. When it is not
            None, the frequency base becomes
            ``10000 * rescaling_factor ** (dim / (dim - 2))``.
    """

    def __init__(
        self, dim: int, rotary_embedding_config: "RotaryEmbeddingConfig"
    ) -> None:
        super().__init__()
        # Extract argument from the config
        self.rescaling_factor = rotary_embedding_config.rescaling_factor
        self.upper_freq = 10000
        self.dim = dim
        # Tables from the most recent call, kept as attributes. They are
        # rebuilt on every forward call; there is no reuse across calls.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _apply_rotary_pos_emb(
        self,
        heads: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """
        Rotate ``heads`` by the angles encoded in ``cos``/``sin``.

        The last dimension is split into halves ``(x1, x2)`` and mapped to
        ``(x1 * cos - x2 * sin, x2 * cos + x1 * sin)``.
        """
        half = heads.shape[-1] // 2
        x_first, x_second = heads[..., :half], heads[..., half:]
        first_part = x_first * cos - x_second * sin
        second_part = x_second * cos + x_first * sin
        return torch.cat((first_part, second_part), dim=-1)

    def _compute_cos_sin_tables(
        self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Build cos/sin tables for every position along ``seq_dimension`` of x.

        Returns two tensors of shape ``(1, seq_len, 1, len(inv_freq))``. Both
        are placed on ``x``'s device and cast to ``x``'s dtype. This lets them
        be combined with the heads on GPU and in half precision without
        device errors or silent upcasting.
        """
        seq_len = x.shape[seq_dimension]
        self._seq_len_cached = seq_len
        # Positions and frequencies must live on the same device as x.
        # Previously `type_as(inv_freq)` moved them back to the CPU.
        inv_freq = inv_freq.to(device=x.device)
        t = torch.arange(seq_len, device=x.device, dtype=inv_freq.dtype)
        freqs = torch.einsum("i, j -> ij", t, inv_freq)
        self._cos_cached = torch.cos(freqs)[None, :, None, :].to(dtype=x.dtype)
        self._sin_cached = torch.sin(freqs)[None, :, None, :].to(dtype=x.dtype)
        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the rotary embedding to queries and keys.

        Positions are read along dimension -3 of ``q``. This assumes a
        ``(..., seq, num_heads, dim)`` layout (TODO confirm for new callers).
        The same tables are applied to ``k``.
        """
        if self.rescaling_factor is None:
            base = self.upper_freq
        else:
            base = self.upper_freq * (
                self.rescaling_factor ** (self.dim / (self.dim - 2))
            )
        exponents = torch.arange(0, self.dim, 2, device=q.device).float() / self.dim
        inv_freq = 1.0 / (base**exponents)
        self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
            q,
            inv_freq,
            seq_dimension=-3,
        )
        return (
            self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
class ResidualConvBlock(nn.Module):
    """
    A ``ConvBlock`` wrapped with an identity skip connection.

    Before the addition, the input is reshaped to the block's output shape.
    The input must therefore hold as many elements as the output, e.g.
    ``dim_in == dim_out``.
    """

    def __init__(self, dim_in: int, dim_out: int, seq_len: int, kernel_size: int = 1):
        super().__init__()
        self.conv_block = ConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            seq_len=seq_len,
            kernel_size=kernel_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv_block(x)
        skip = x.reshape(out.shape)
        return skip + out
class ConvBlock(nn.Module):
    """
    LayerNorm over the last axis, then a Conv1d, then tanh-approximated GELU.

    Pipeline: LayerNorm -> flatten all trailing axes into one length axis ->
    Conv1d with "same" padding -> GELU.

    Args:
        dim_in: Input channels of the convolution.
        dim_out: Output channels of the convolution.
        seq_len: Size of the last input axis; this is what the LayerNorm
            normalises.
        kernel_size: Convolution kernel width.
    """

    def __init__(self, dim_in: int, dim_out: int, seq_len: int, kernel_size: int = 1):
        super().__init__()
        self.conv = nn.Conv1d(dim_in, dim_out, kernel_size, padding="same")
        self.layer_norm = nn.LayerNorm(seq_len, eps=1e-5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.layer_norm(x)
        batch, channels = normed.shape[0], normed.shape[1]
        flat = normed.reshape(batch, channels, -1)
        return F.gelu(self.conv(flat), approximate="tanh")
class ResidualDeConvBlock(nn.Module):
    """
    A ``DeConvBlock`` wrapped with an identity skip connection.

    Before the addition, the input is reshaped to the block's output shape.
    The input must therefore hold as many elements as the output, e.g.
    ``dim_in == dim_out`` with ``kernel_size=1`` and ``stride=1``.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        seq_len: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        block_kwargs = dict(
            dim_in=dim_in,
            dim_out=dim_out,
            seq_len=seq_len,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.deconv_block = DeConvBlock(**block_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.deconv_block(x)
        skip = x.reshape(out.shape)
        return skip + out
class DeConvBlock(nn.Module):
    """
    LayerNorm over the last axis, then a ConvTranspose1d, then tanh-approximated GELU.

    Pipeline: LayerNorm -> flatten trailing axes into one length axis ->
    ConvTranspose1d with no padding -> GELU.

    For ``kernel_size == 5`` the output is cropped to ``[1:-2]`` along the
    length axis. The original comment explains this as mimicking Haiku's
    deconv, which removes the padding automatically.

    Args:
        dim_in: Input channels of the transposed convolution.
        dim_out: Output channels of the transposed convolution.
        seq_len: Size of the last input axis (normalised by the LayerNorm).
        kernel_size: Kernel width.
        stride: Stride (upsampling factor).
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        seq_len: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(
            dim_in, dim_out, kernel_size, stride=stride, padding=0
        )
        self.layer_norm = nn.LayerNorm(seq_len)
        self.kernel_size = kernel_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.layer_norm(x)
        flat = normed.reshape(normed.shape[0], normed.shape[1], -1)
        upsampled = self.deconv(flat)
        # Haiku special case: drop one leading and two trailing positions.
        if self.kernel_size == 5:
            upsampled = upsampled[:, :, 1:-2]
        return F.gelu(upsampled, approximate="tanh")
class SpatialEncoding(nn.Module):
    """
    Spatial coordinates encoding module.

    For each value ``scale`` in ``self.scales``, the x coordinate is encoded
    as ``cos(x / c)`` and the y coordinate as ``sin(y / c)``, where
    ``c = sigma_min * (sigma_max / sigma_min) ** (scale / (num_scales - 1))``.
    The resulting ``2 * num_scales`` features are projected to ``embed_dim``
    by a linear layer.

    Args:
        embed_dim: Output feature size.
        num_scales: Number of scales. Should be > 1, because
            ``num_scales - 1`` is used as a divisor.
        sigma_min: Lower bound used to build the scales.
        sigma_max: Upper bound used to build the scales.
    """

    def __init__(
        self,
        embed_dim: int,
        num_scales: int = 10,
        sigma_min: float = 1.0,
        sigma_max: float = 10.0,
    ):
        super().__init__()
        self.num_scales = num_scales
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.g = sigma_max / sigma_min
        # NOTE(review): these values run from sigma_min to sigma_max, yet they
        # are used as the exponent index in scale_specific_encoder. The
        # divisors therefore do not span exactly [sigma_min, sigma_max].
        # Kept as-is for numerical compatibility; confirm the intent.
        self.scales = torch.linspace(sigma_min, sigma_max, num_scales)
        # Each scale contributes two features (cos of x, sin of y), so the
        # projection input is 2 * num_scales wide. The previous
        # Linear(embed_dim, embed_dim) only worked when
        # embed_dim == 2 * num_scales; for that case the shape is unchanged.
        self.fc_layer = nn.Linear(2 * num_scales, embed_dim)

    def scale_specific_encoder(
        self, coordinates: torch.Tensor, scale: float
    ) -> torch.Tensor:
        """
        Encode ``coordinates[..., 0]`` (x) and ``coordinates[..., 1]`` (y) at
        one scale.

        Returns a tensor with a trailing axis of size 2:
        ``(cos(x / c), sin(y / c))``.
        """
        x, y = coordinates[..., 0], coordinates[..., 1]
        constant = self.sigma_min * (self.g ** (scale / (self.num_scales - 1)))
        x_transform = torch.cos(x / constant)
        y_transform = torch.sin(y / constant)
        return torch.stack([x_transform, y_transform], dim=-1)

    def forward(self, coordinates: torch.Tensor) -> torch.Tensor:
        """Map ``(..., 2)`` coordinates to ``(..., embed_dim)`` features."""
        transformed_coordinates = torch.cat(
            [self.scale_specific_encoder(coordinates, scale) for scale in self.scales],
            dim=-1,
        )
        return self.fc_layer(transformed_coordinates)
class ConvTowerBlock(nn.Module):
    """
    One downsampling stage of the convolutional tower.

    Pipeline: ConvBlock (``kernel_size``) -> ResidualConvBlock (kernel 1) ->
    AvgPool1d(2, stride 2), which halves the length axis. The stage input is
    returned unchanged next to the output, so the caller can use it as a
    skip connection.

    Before each conv block, the flat length axis is viewed as
    ``(num_cells, -1)``. Each cell's positions are then normalised by the
    blocks' LayerNorm(seq_len). This assumes x is ``(batch, channels, length)``
    (TODO confirm).
    """

    def __init__(
        self, dim_in: int, dim_out: int, seq_len: int, kernel_size: int, num_cells: int
    ) -> None:
        super().__init__()
        self.conv_layer = ConvBlock(
            dim_in=dim_in, dim_out=dim_out, seq_len=seq_len, kernel_size=kernel_size
        )
        self.res_conv = ResidualConvBlock(
            dim_in=dim_out, dim_out=dim_out, seq_len=seq_len, kernel_size=1
        )
        self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2)
        self.num_cells = num_cells

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        skip = x
        per_cell = x.reshape(x.shape[0], x.shape[1], self.num_cells, -1)  # noqa: FKA100
        hidden = self.conv_layer(per_cell)
        hidden = hidden.reshape(hidden.shape[0], hidden.shape[1], self.num_cells, -1)
        pooled = self.avg_pool(self.res_conv(hidden))
        return pooled, skip
class DeConvTowerBlock(nn.Module):
    """
    One upsampling stage of the transposed-convolution tower.

    Pipeline: DeConvBlock (``kernel_size``, ``stride``) -> ResidualDeConvBlock
    (kernel 1, per-cell length ``2 * seq_len``) -> add the skip tensor ``res``.
    Before each block, the flat length axis is viewed as ``(num_cells, -1)``.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        kernel_size: int,
        seq_len: int,
        stride: int = 2,
        num_cells: int = 1,
    ):
        super().__init__()
        self.deconv_block = DeConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            seq_len=seq_len,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.res_deconv_block = ResidualDeConvBlock(
            dim_in=dim_out, dim_out=dim_out, seq_len=seq_len * 2, kernel_size=1
        )
        self.num_cells = num_cells

    def _split_cells(self, t: torch.Tensor) -> torch.Tensor:
        # View the flat length axis as (num_cells, per-cell length).
        return t.reshape((t.shape[0], t.shape[1], self.num_cells, -1))

    def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
        upsampled = self.deconv_block(self._split_cells(x))
        refined = self.res_deconv_block(self._split_cells(upsampled))
        return refined + res
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with optional rotary embeddings.

    Args:
        num_heads: Number of attention heads.
        key_size: Per-head size of queries and keys. It is also the default
            for ``value_size`` and ``model_size``.
        rotary_embedding_config: When given, rotary position embeddings are
            applied to the query and key heads.
        add_bias_kv: Stored on the module but not used by this implementation.
        value_size: Per-head size of values; falsy values fall back to
            ``key_size``.
        model_size: Input/output feature size; falsy values fall back to
            ``key_size``.
        name: Optional label stored on the module.
    """

    def __init__(
        self,
        num_heads: int,
        key_size: int,
        rotary_embedding_config: Optional["RotaryEmbeddingConfig"] = None,
        add_bias_kv: bool = False,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        name: Optional[str] = None,
    ):
        super().__init__()
        if not model_size:
            model_size = key_size
        if not value_size:
            value_size = key_size
        self.model_size = model_size
        self.key_size = key_size
        self.value_size = value_size
        self.add_bias_kv = add_bias_kv
        self.name = name
        self.num_heads = num_heads
        self._rotary_embedding_config = rotary_embedding_config
        self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
        self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
        if self._rotary_embedding_config is not None:
            self._rotary_embedding = RotaryEmbedding(
                self.key_size, self._rotary_embedding_config
            )

    def apply_rotary_embeddings(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding layer to the query and key heads."""
        query, key = self._rotary_embedding(query, key)
        return query, key

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Compute multi-head attention.

        Args:
            query, key, value: Tensors whose last dimension is ``model_size``.
            attention_mask: Optional boolean tensor that broadcasts against the
                ``(..., heads, t, T)`` logits. True keeps a position; False
                replaces its logit with -1e30 before the softmax.
            attention_weight_bias: Optional tensor added to the logits before
                the softmax.

        Returns:
            Dictionary with ``"attention_weights"`` (softmaxed, shape
            ``(..., heads, t, T)``) and ``"embeddings"`` (shape
            ``(..., t, model_size)``).
        """
        key_heads = self.w_k(key).reshape(
            (*key.shape[:-1], self.num_heads, self.key_size)
        )
        query_heads = self.w_q(query).reshape(
            (*query.shape[:-1], self.num_heads, self.key_size)
        )
        value_heads = self.w_v(value).reshape(
            (*value.shape[:-1], self.num_heads, self.value_size)
        )
        if self._rotary_embedding_config is not None:
            query_heads, key_heads = self.apply_rotary_embeddings(
                query_heads, key_heads
            )
        attention_weights = torch.einsum(
            "...thd, ...Thd -> ...htT", query_heads, key_heads
        )
        sqrt_key_size = np.sqrt(self.key_size)
        attention_weights = attention_weights / sqrt_key_size
        # Compare against None explicitly: a multi-element tensor has no
        # truth value, so the old `if attention_mask:` raised RuntimeError.
        if attention_mask is not None:
            attention_weights = torch.where(attention_mask, attention_weights, -1e30)
        if attention_weight_bias is not None:
            attention_weights = attention_weights + attention_weight_bias
        attention_weights = F.softmax(attention_weights, dim=-1)
        value_out = torch.einsum(
            "...htT, ...Thd->...thd", attention_weights, value_heads
        )
        # Merge the heads back into a single feature axis.
        value_out = value_out.reshape((*value_out.shape[:-2], -1))
        embeddings = self.output(value_out)
        return {"attention_weights": attention_weights, "embeddings": embeddings}
class SelfAttentionBlock(nn.Module):
    """
    Transformer layer: multi-head self-attention followed by a feed-forward
    MLP, each with a residual connection and a LayerNorm.

    With ``pre_layer_norm=True`` the LayerNorms are applied to the sub-layer
    inputs. Otherwise they are applied to each residual sum.

    Args:
        num_heads: Number of attention heads.
        embed_dim: Model width.
        ffn_embed_dim: Hidden width of the MLP.
        key_size: Per-head key size. Defaults to ``embed_dim // num_heads``,
            which requires divisibility.
        add_bias_kv: Forwarded to MultiHeadAttention.
        add_bias_fnn: Whether the MLP linear layers have biases.
        ffn_activation_name: ``"swish"`` (SiLU), ``"gelu-no-approx"`` (exact
            erf GELU), or the name of a ``torch.nn`` activation class, which
            is instantiated with no arguments.
        use_glu_in_ffn: Use a gated linear unit: the first projection is
            twice as wide and is split into activation and gate halves.
        layer_norm_eps: Epsilon used by both LayerNorms.
        pre_layer_norm: Pre-LN (True) or post-LN (False) layout.
        name: Unused label.
        rotary_embedding_config: Forwarded to MultiHeadAttention.

    Raises:
        ValueError: If ``key_size`` is None and ``embed_dim`` is not
            divisible by ``num_heads``.
        AttributeError: If ``ffn_activation_name`` names no ``torch.nn``
            attribute.
    """

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        ffn_embed_dim: int,
        key_size: Optional[int] = None,
        add_bias_kv: bool = False,
        add_bias_fnn: bool = True,
        ffn_activation_name: str = "gelu-no-approx",
        use_glu_in_ffn: bool = False,
        layer_norm_eps: float = 1e-5,  # this is the default haiku value
        pre_layer_norm: bool = True,
        name: Optional[str] = None,
        rotary_embedding_config: Optional["RotaryEmbeddingConfig"] = None,
    ):
        super().__init__()
        if key_size is None:
            if embed_dim % num_heads != 0:
                raise ValueError(
                    f"The embedding dimension should be divisible by the number of "
                    f"heads, however provided embedding dimension is {embed_dim} and "
                    f"the number of heads is {num_heads}."
                )
            key_size = embed_dim // num_heads
        self._pre_layer_norm = pre_layer_norm
        self._use_glu_in_fnn = use_glu_in_ffn
        # Define layers
        if use_glu_in_ffn:
            # user should multiply ffn_embed_dim by 2/3 when using GLU
            # to keep total number of parameters equal
            # see https://arxiv.org/pdf/2002.05202.pdf. for more details
            # we multiply by 2 here as the output will be split in 2 for GLU
            self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn)
        else:
            self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn)
        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)
        # layer_norm_eps used to be ignored (its default equals LayerNorm's).
        self.layer_norm_self_attention = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.layer_norm_mlp = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        if ffn_activation_name == "swish":
            self._ffn_activation_fn = nn.SiLU()
        elif ffn_activation_name == "gelu-no-approx":
            # `nn.GeLU` does not exist, and "no-approx" means the exact erf
            # form rather than the tanh approximation.
            self._ffn_activation_fn = nn.GELU(approximate="none")
        else:
            # Instantiate the activation module; calling the bare class on a
            # tensor would construct a module instead of applying it.
            self._ffn_activation_fn = getattr(nn, ffn_activation_name)()
        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            key_size=key_size,
            add_bias_kv=add_bias_kv,
            model_size=embed_dim,
            name="self_attention",
            rotary_embedding_config=rotary_embedding_config,
        )

    def mlp(self, embed: torch.Tensor) -> torch.Tensor:
        """
        Feed-forward sub-layer.

        Applies fc1 -> activation (gated when GLU is on) -> fc2. In post-LN
        mode the residual sum and LayerNorm are applied here too. In pre-LN
        mode the caller adds the residual.
        """
        if self._pre_layer_norm:
            x = self.layer_norm_mlp(embed)
        else:
            x = embed
        if self._use_glu_in_fnn:
            x = self.fc1(x)
            x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
            x = self._ffn_activation_fn(x1) * x2
        else:
            x = self._ffn_activation_fn(self.fc1(x))
        x = self.fc2(x)
        if not self._pre_layer_norm:
            x = self.layer_norm_mlp(x + embed)
        return x

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run attention and then the MLP.

        Returns the MultiHeadAttention output dict, with ``"embeddings"``
        replaced by the layer's output.
        """
        res = x
        if self._pre_layer_norm:
            x = self.layer_norm_self_attention(x)
        output = self.mha(
            x,
            x,
            x,
            attention_mask=attention_mask,
            attention_weight_bias=attention_weight_bias,
        )
        if not self._pre_layer_norm:
            output["embeddings"] = self.layer_norm_self_attention(
                output["embeddings"] + res
            )
            x = output["embeddings"]
        else:
            x = output["embeddings"]
            x = res + x
        # MLP
        if not self._pre_layer_norm:
            x = self.mlp(x)
        else:
            x = x + self.mlp(x)
        output["embeddings"] = x
        return output
class LMHead(nn.Module):
    """
    Per-position prediction head.

    Pipeline: GELU on the input -> ``linear_layers``, each followed by GELU ->
    ``linear_out``. The first hidden layer maps ``dim_in -> embed_dim`` and
    the remaining ``num_hidden_layers - 1`` map ``embed_dim -> embed_dim``.
    At least one hidden layer is always created.
    """

    def __init__(
        self, dim_in: int, embed_dim: int, dim_out: int, num_hidden_layers: int
    ) -> None:
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.linear_layers = nn.ModuleList([nn.Linear(dim_in, embed_dim)])
        # Each extra layer must be a Module itself. Previously each one was
        # wrapped in a list, which nn.ModuleList rejects with a TypeError.
        self.linear_layers.extend(
            nn.Linear(embed_dim, embed_dim) for _ in range(num_hidden_layers - 1)
        )
        self.linear_out = nn.Linear(embed_dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.gelu(x, approximate="tanh")
        for layer in self.linear_layers:
            x = layer(x)
            x = F.gelu(x, approximate="tanh")
        return self.linear_out(x)
class sCTConfig(PretrainedConfig):  # noqa: N801
    """
    Configuration for the sCT model.

    Each hyper-parameter is read from the keyword arguments, falling back to
    the default shown below. All keyword arguments are also forwarded to
    ``PretrainedConfig.__init__``. After construction, ``__post_init__``
    validates ``key_size`` and derives it when missing.

    This class is intentionally not a ``@dataclass``. With the hand-written
    ``__init__`` and no annotated class fields, the decorator only replaced
    ``__repr__``/``__eq__`` with field-less versions, so any two configs
    compared equal. Its ``__post_init__`` hook was also never invoked.
    """

    model_type = "sCT"

    def __init__(self, **kwargs):  # type: ignore
        self.alphabet_size = kwargs.get("alphabet_size", 7)
        self.pad_token_id = kwargs.get("pad_token_id", 5)
        self.mask_token_id = kwargs.get("mask_token_id", 6)
        self.cell_len = kwargs.get("cell_len", 19968)
        self.num_downsamples = kwargs.get("num_downsamples", 8)
        self.attention_heads = kwargs.get("attention_heads", 16)
        self.key_size = kwargs.get("key_size", None)
        self.token_embed_dim = kwargs.get("token_embed_dim", 16)
        self.embed_dim = kwargs.get("embed_dim", 1024)
        self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 2048)
        self.num_layers = kwargs.get("num_layers", 4)
        self.layer_norm_eps = kwargs.get("layer_norm_eps", 1e-5)
        self.interpolation_method = kwargs.get("interpolation_method", "nearest")
        # bad hack to satisfy cellnt_celltype_annotation.py:312
        self.max_positions: int = kwargs.get("max_positions", 20480)
        self.num_cells: int = kwargs.get("num_cells", 50)
        self.num_hidden_layers_head: int = kwargs.get("num_hidden_layers_head", 1)
        self.use_skip_connection: bool = kwargs.get("use_skip_connection", True)
        # logging
        self.use_gradient_checkpointing: bool = False
        # return
        self.embeddings_layers_to_save: Tuple[int, ...] = kwargs.get(
            "embeddings_layers_to_save", ()
        )
        self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
            "attention_maps_to_save", []
        )
        # Spatial info configuration
        self.use_spatial_information: bool = kwargs.get(
            "use_spatial_information", False
        )
        self.num_scales: int = kwargs.get("num_scales", 10)
        self.sigma_min: float = kwargs.get("sigma_min", 1.0)
        self.sigma_max: float = kwargs.get("sigma_max", 10.0)
        # PretrainedConfig.__init__ assigns pad_token_id from its own kwargs
        # and uses None when it is absent. Forward our default explicitly so
        # it is not overwritten.
        kwargs["pad_token_id"] = self.pad_token_id
        super().__init__(**kwargs)
        self.__post_init__()

    def __post_init__(self) -> None:  # type: ignore # noqa: N807
        """
        Checks that the given values are compatible.

        When ``key_size`` is None, it is set to
        ``embed_dim // attention_heads``.

        Raises:
            ValueError: If ``key_size`` is None and ``embed_dim`` is not
                divisible by ``attention_heads``.
        """
        if self.key_size is None:
            if not self.embed_dim % self.attention_heads == 0:
                raise ValueError(
                    f"When no key size is provided, the embedding dimension "
                    f"should be divisible by the number of heads, however "
                    f"provided embedding dimension is {self.embed_dim} and "
                    f"the number of heads is {self.attention_heads}."
                )
            self.key_size = self.embed_dim // self.attention_heads
class sCT(PreTrainedModel):  # noqa: N801
    """
    sCellTransformer model.

    Pipeline:
    - token embedding and a stem Conv1d (kernel 15) + GELU;
    - a tower of ``ConvTowerBlock`` downsampling stages, whose inputs are
      kept as skip tensors;
    - a stack of ``SelfAttentionBlock`` layers;
    - a mirrored tower of ``DeConvTowerBlock`` upsampling stages, each
      consuming the matching skip tensor;
    - an ``LMHead`` that produces per-position logits over
      ``config.alphabet_size``.

    ``forward`` returns a dict with ``"logits"``. It also contains the
    embeddings and attention maps requested via
    ``config.embeddings_layers_to_save`` and ``config.attention_maps_to_save``.
    Layers in those settings are 1-indexed.
    """

    config_class = sCTConfig

    def __init__(self, config: sCTConfig):
        super().__init__(config=config)
        if config.use_spatial_information:
            # NOTE(review): this layer is built but never used by forward();
            # confirm whether spatial coordinates are meant to be consumed.
            self.spatial_embed_layer = SpatialEncoding(
                embed_dim=config.token_embed_dim,
                num_scales=config.num_scales,
                sigma_min=config.sigma_min,
                sigma_max=config.sigma_max,
            )
        self.cell_len = config.cell_len
        self.token_embed = nn.Embedding(config.alphabet_size, config.token_embed_dim)

        # Group the requested (layer, map) pairs by layer.
        requested_maps = config.attention_maps_to_save
        self._attention_layers_to_save = list({pair[0] for pair in requested_maps})
        self._attention_maps_per_layer_to_save = {}
        for layer in self._attention_layers_to_save:
            self._attention_maps_per_layer_to_save[layer] = [
                pair[1] for pair in requested_maps if pair[0] == layer
            ]
        highest_layer = max([*self._attention_layers_to_save, 0])
        if highest_layer > config.num_layers:
            raise ValueError(
                f"You are requiring attention maps for layer {highest_layer}, "
                f"while the model has {config.num_layers} layers only."
            )

        # Channel widths per stage: evenly spaced from token_embed_dim to
        # embed_dim, each rounded up to a multiple of 32.
        # NOTE(review): the first width is also rounded up (e.g. 16 -> 32),
        # while the stem conv outputs token_embed_dim channels. Confirm that
        # the checkpoint config uses a token_embed_dim that is a multiple of 32.
        raw_widths = np.linspace(
            config.token_embed_dim,
            config.embed_dim,
            config.num_downsamples + 1,
        )
        self._filter_list = (np.ceil(raw_widths / 32) * 32).astype(int).tolist()
        filter_list = self._filter_list
        self._rotary_embedding_config = RotaryEmbeddingConfig(rescaling_factor=None)

        self.stem_conv = nn.Sequential(
            nn.Conv1d(
                in_channels=config.token_embed_dim,
                out_channels=config.token_embed_dim,
                kernel_size=15,
                padding="same",
            ),
            nn.GELU(approximate="tanh"),
        )

        num_stages = len(filter_list) - 1
        # Per-cell length seen by each downsampling stage (halved per stage).
        stage_seq_lens = [self.cell_len // (2**stage) for stage in range(num_stages)]
        self.conv_tower = nn.ModuleList(
            ConvTowerBlock(
                dim_in=filter_list[stage],
                dim_out=filter_list[stage + 1],
                kernel_size=5,
                seq_len=stage_seq_lens[stage],
                num_cells=config.num_cells,
            )
            for stage in range(num_stages)
        )
        # Mirror of the conv tower, from the widest stage back to the narrowest.
        self.deconv_tower = nn.ModuleList(
            DeConvTowerBlock(
                dim_in=filter_list[-1 - stage],
                dim_out=filter_list[-2 - stage],
                kernel_size=5,
                stride=2,
                seq_len=seq_len // 2,
                num_cells=config.num_cells,
            )
            for stage, seq_len in enumerate(reversed(stage_seq_lens))
        )
        self.transformer_layers = nn.ModuleList(
            SelfAttentionBlock(
                num_heads=config.attention_heads,
                embed_dim=config.embed_dim,
                ffn_embed_dim=config.ffn_embed_dim,
                key_size=config.key_size,
                add_bias_kv=False,
                add_bias_fnn=False,
                ffn_activation_name="swish",
                use_glu_in_ffn=True,
                layer_norm_eps=1e-5,  # this is the default haiku value
                pre_layer_norm=True,
                name=f"attention_layer_{layer_idx}",
                rotary_embedding_config=self._rotary_embedding_config,
            )
            for layer_idx in range(config.num_layers)
        )
        self.lm_head = LMHead(
            dim_in=config.token_embed_dim,
            embed_dim=config.embed_dim,
            dim_out=config.alphabet_size,
            num_hidden_layers=config.num_hidden_layers_head,
        )

    def forward(self, input_ids: torch.Tensor) -> dict[str, torch.Tensor]:
        outs = {}
        # (batch, length, channels) -> (batch, channels, length) for Conv1d.
        hidden = self.token_embed(input_ids).permute(0, 2, 1)
        hidden = self.stem_conv(hidden)

        skips = []
        for down_block in self.conv_tower:
            hidden, skip = down_block(hidden)
            skips.append(skip)

        # Attention runs on (batch, length, channels).
        hidden = hidden.permute(0, 2, 1)
        for depth, attention_layer in enumerate(self.transformer_layers, start=1):
            output = attention_layer(hidden)
            hidden = output["embeddings"]
            if depth in self.config.embeddings_layers_to_save:
                outs[f"embeddings_{depth}"] = output["embeddings"]
            if depth in self._attention_layers_to_save:
                for map_number in self._attention_maps_per_layer_to_save[depth]:
                    out_key = f"attention_map_layer_{depth}_number_{map_number}"
                    # NOTE(review): head index map_number + 1 is kept as-is;
                    # confirm the intended numbering of attention maps.
                    outs[out_key] = output["attention_weights"][:, map_number + 1]

        hidden = hidden.permute(0, 2, 1)
        # The deepest skip pairs with the first upsampling stage.
        for up_block, skip in zip(self.deconv_tower, reversed(skips)):
            hidden = up_block(hidden, skip)
        hidden = hidden.permute(0, 2, 1)
        outs["logits"] = self.lm_head(hidden)
        return outs