import math
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel


class GeLU(nn.Module):
    """
    This is the GELU implementation from the original ESM repo.
    Using F.gelu yields subtly wrong results.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


@dataclass
class RotaryEmbeddingConfig:
    """
    Parameters to initialize the RotaryEmbedding layer. The rescaling factor makes it
    possible to adapt the rotary embeddings to sequence lengths larger than those seen
    during training. One such strategy is presented in the YaRN paper:
    https://arxiv.org/pdf/2309.00071.pdf.  # noqa

    Args:
        rescaling_factor: factor used to rescale the rotary frequencies; set to None
            to keep the standard RoPE frequencies.
    """

    rescaling_factor: Optional[float]


class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).
    Queries and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfig):
        super().__init__()

        self.rescaling_factor = rotary_embedding_config.rescaling_factor
        self.upper_freq = 10000
        self.dim = dim

        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _apply_rotary_pos_emb(
        self,
        heads: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """
        Rotates the first and second halves of the head dimension using the
        precomputed cos/sin tables.
        """
        x_first, x_second = (
            heads[..., : heads.shape[-1] // 2],
            heads[..., heads.shape[-1] // 2 :],
        )

        first_part = x_first * cos - x_second * sin
        second_part = x_second * cos + x_first * sin

        return torch.cat((first_part, second_part), dim=-1)

    def _compute_cos_sin_tables(
        self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
    ) -> tuple[torch.Tensor, torch.Tensor]:
        seq_len = x.shape[seq_dimension]

        self._seq_len_cached = seq_len
        t = torch.arange(seq_len, device=x.device).type_as(inv_freq)

        freqs = torch.einsum("i, j -> ij", t, inv_freq)

        self._cos_cached = torch.cos(freqs)[None, :, None, :]
        self._sin_cached = torch.sin(freqs)[None, :, None, :]

        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.rescaling_factor is None:
            inv_freq = 1.0 / (
                self.upper_freq ** (torch.arange(0, self.dim, 2).float() / self.dim)
            )
        else:
            updated_base = self.upper_freq * (
                self.rescaling_factor ** (self.dim / (self.dim - 2))
            )
            inv_freq = 1.0 / (
                updated_base ** (torch.arange(0, self.dim, 2).float() / self.dim)
            )

        self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
            q,
            inv_freq,
            seq_dimension=-3,
        )

        return (
            self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )


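# Illustrative sketch (not part of the original file): the tensor shapes below are
# assumptions chosen to match the (batch, seq_len, num_heads, head_dim) layout this
# layer expects (the cos/sin tables are built over dimension -3).
#
#   rotary = RotaryEmbedding(
#       dim=64, rotary_embedding_config=RotaryEmbeddingConfig(rescaling_factor=None)
#   )
#   q = torch.randn(1, 128, 8, 64)
#   k = torch.randn(1, 128, 8, 64)
#   q_rot, k_rot = rotary(q, k)  # both keep shape (1, 128, 8, 64)

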
class ResidualConvBlock(nn.Module):
    """
    Conv Block with Residual connection.
    """

    def __init__(self, dim_in: int, dim_out: int, seq_len: int, kernel_size: int = 1):
        super().__init__()
        self.conv_block = ConvBlock(
            dim_in=dim_in, dim_out=dim_out, seq_len=seq_len, kernel_size=kernel_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv_block(x)
        return x.reshape(y.shape) + y


class ConvBlock(nn.Module):
    """
    Conv Block.
    """

    def __init__(self, dim_in: int, dim_out: int, seq_len: int, kernel_size: int = 1):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=kernel_size,
            padding="same",
        )
        self.layer_norm = nn.LayerNorm(seq_len, eps=1e-5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer_norm(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = self.conv(x)
        x = F.gelu(x, approximate="tanh")
        return x


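# Illustrative sketch (not part of the original file): shapes are assumptions. The
# LayerNorm is applied over the last axis (of size seq_len); the per-cell axes are then
# flattened before the 1D convolution.
#
#   block = ConvBlock(dim_in=16, dim_out=32, seq_len=128, kernel_size=5)
#   x = torch.randn(2, 16, 4, 128)  # (batch, channels, cells, positions per cell)
#   y = block(x)                    # (2, 32, 4 * 128)

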
class ResidualDeConvBlock(nn.Module):
    """
    DeConv Block with Residual connection.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        seq_len: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        self.deconv_block = DeConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            seq_len=seq_len,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.deconv_block(x)
        return x.reshape(y.shape) + y


class DeConvBlock(nn.Module):
    """
    DeConv Block.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        seq_len: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
        )
        self.layer_norm = nn.LayerNorm(seq_len)
        self.kernel_size = kernel_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer_norm(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = self.deconv(x)
        if self.kernel_size == 5:
            # Crop the extra positions introduced by the unpadded transposed
            # convolution so that, with stride 2, the output length is exactly
            # twice the input length.
            x = x[:, :, 1:-2]
        x = F.gelu(x, approximate="tanh")
        return x


class SpatialEncoding(nn.Module):
    """
    Spatial coordinates encoding module.
    """

    def __init__(
        self,
        embed_dim: int,
        num_scales: int = 10,
        sigma_min: float = 1.0,
        sigma_max: float = 10.0,
    ):
        super().__init__()
        self.num_scales = num_scales
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.g = sigma_max / sigma_min
        self.scales = torch.linspace(sigma_min, sigma_max, num_scales)
        # The concatenated multi-scale encoding has size 2 * num_scales, which is what
        # this linear layer receives, so embed_dim is expected to equal 2 * num_scales.
        self.fc_layer = nn.Linear(embed_dim, embed_dim)

    def scale_specific_encoder(
        self, coordinates: torch.Tensor, scale: float
    ) -> torch.Tensor:
        x, y = coordinates[..., 0], coordinates[..., 1]
        constant = self.sigma_min * (self.g ** (scale / (self.num_scales - 1)))
        x_transform = torch.cos(x / constant)
        y_transform = torch.sin(y / constant)
        transformed_coordinates = torch.stack([x_transform, y_transform], dim=-1)
        return transformed_coordinates

    def forward(self, coordinates: torch.Tensor) -> torch.Tensor:
        transformed_coordinates = [
            self.scale_specific_encoder(coordinates, scale) for scale in self.scales
        ]
        transformed_coordinates = torch.cat(transformed_coordinates, dim=-1)
        return self.fc_layer(transformed_coordinates)


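# Illustrative sketch (not part of the original file): the values are assumptions,
# chosen so that embed_dim == 2 * num_scales, which the final linear layer requires.
#
#   encoder = SpatialEncoding(embed_dim=20, num_scales=10)
#   coords = torch.randn(4, 100, 2)  # (batch, tokens, x/y coordinates)
#   enc = encoder(coords)            # (4, 100, 20)

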
class ConvTowerBlock(nn.Module):
    def __init__(
        self, dim_in: int, dim_out: int, seq_len: int, kernel_size: int, num_cells: int
    ) -> None:
        super().__init__()
        self.conv_layer = ConvBlock(
            dim_in=dim_in, dim_out=dim_out, seq_len=seq_len, kernel_size=kernel_size
        )
        self.res_conv = ResidualConvBlock(
            dim_in=dim_out, dim_out=dim_out, seq_len=seq_len, kernel_size=1
        )
        self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2)
        self.num_cells = num_cells

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        residual = x
        x = x.reshape(x.shape[0], x.shape[1], self.num_cells, -1)
        x = self.conv_layer(x)
        x = x.reshape((x.shape[0], x.shape[1], self.num_cells, -1))
        x = self.res_conv(x)
        x = self.avg_pool(x)
        return x, residual


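# Illustrative sketch (not part of the original file): the sizes are assumptions. The
# block halves the flattened sequence length and also returns its input as a residual
# for the matching DeConvTowerBlock.
#
#   block = ConvTowerBlock(dim_in=16, dim_out=32, seq_len=128, kernel_size=5, num_cells=4)
#   x = torch.randn(2, 16, 4 * 128)
#   y, residual = block(x)  # y: (2, 32, 4 * 128 // 2), residual: (2, 16, 4 * 128)

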
class DeConvTowerBlock(nn.Module):
    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        kernel_size: int,
        seq_len: int,
        stride: int = 2,
        num_cells: int = 1,
    ):
        super().__init__()
        self.deconv_block = DeConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            seq_len=seq_len,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.res_deconv_block = ResidualDeConvBlock(
            dim_in=dim_out, dim_out=dim_out, seq_len=seq_len * 2, kernel_size=1
        )
        self.num_cells = num_cells

    def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
        x = x.reshape((x.shape[0], x.shape[1], self.num_cells, -1))
        x = self.deconv_block(x)
        x = x.reshape((x.shape[0], x.shape[1], self.num_cells, -1))
        x = self.res_deconv_block(x)

        x = x + res
        return x


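# Illustrative sketch (not part of the original file): sizes are assumptions. The block
# doubles the flattened sequence length and adds the residual saved by the matching
# ConvTowerBlock.
#
#   block = DeConvTowerBlock(dim_in=32, dim_out=16, kernel_size=5, seq_len=128, stride=2, num_cells=4)
#   x = torch.randn(2, 32, 4 * 128)
#   res = torch.randn(2, 16, 4 * 256)
#   y = block(x, res)  # (2, 16, 4 * 256)

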
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        key_size: int,
        rotary_embedding_config: Optional[RotaryEmbeddingConfig] = None,
        add_bias_kv: bool = False,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        name: Optional[str] = None,
    ):
        super().__init__()
        if not model_size:
            model_size = key_size
        if not value_size:
            value_size = key_size
        self.model_size = model_size
        self.key_size = key_size
        self.value_size = value_size
        self.add_bias_kv = add_bias_kv
        self.name = name
        self.num_heads = num_heads
        self._rotary_embedding_config = rotary_embedding_config

        self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
        self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
        if self._rotary_embedding_config:
            self._rotary_embedding = RotaryEmbedding(
                self.key_size, self._rotary_embedding_config
            )

    def apply_rotary_embeddings(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Applies rotary position embeddings to the query and key heads.
        """
        query, key = self._rotary_embedding(query, key)
        return query, key

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Returns:
            Dictionary containing the attention weights and the output embeddings.
        """
        key_heads = self.w_k(key).reshape(
            (*key.shape[:-1], self.num_heads, self.key_size)
        )
        query_heads = self.w_q(query).reshape(
            (*query.shape[:-1], self.num_heads, self.key_size)
        )
        value_heads = self.w_v(value).reshape(
            (*value.shape[:-1], self.num_heads, self.value_size)
        )
        if self._rotary_embedding_config:
            query_heads, key_heads = self.apply_rotary_embeddings(
                query_heads, key_heads
            )
        attention_weights = torch.einsum(
            "...thd, ...Thd -> ...htT", query_heads, key_heads
        )
        sqrt_key_size = np.sqrt(self.key_size)
        attention_weights = attention_weights / sqrt_key_size
        if attention_mask is not None:
            attention_weights = torch.where(attention_mask, attention_weights, -1e30)
        if attention_weight_bias is not None:
            attention_weights = F.softmax(
                attention_weights + attention_weight_bias, dim=-1
            )
        else:
            attention_weights = F.softmax(attention_weights, dim=-1)
        value_out = torch.einsum(
            "...htT, ...Thd->...thd", attention_weights, value_heads
        )
        value_out = value_out.reshape((*value_out.shape[:-2], -1))
        embeddings = self.output(value_out)

        return {"attention_weights": attention_weights, "embeddings": embeddings}


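# Illustrative sketch (not part of the original file): parameter values and shapes are
# assumptions. Queries, keys and values of shape (batch, tokens, model_size) are
# projected into heads, optionally rotated, and combined via scaled dot-product
# attention.
#
#   mha = MultiHeadAttention(
#       num_heads=8,
#       key_size=32,
#       model_size=256,
#       rotary_embedding_config=RotaryEmbeddingConfig(rescaling_factor=None),
#   )
#   x = torch.randn(2, 100, 256)
#   out = mha(x, x, x)
#   out["embeddings"].shape         # (2, 100, 256)
#   out["attention_weights"].shape  # (2, 8, 100, 100)

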
class SelfAttentionBlock(nn.Module):
    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        ffn_embed_dim: int,
        key_size: Optional[int] = None,
        add_bias_kv: bool = False,
        add_bias_fnn: bool = True,
        ffn_activation_name: str = "gelu-no-approx",
        use_glu_in_ffn: bool = False,
        layer_norm_eps: float = 1e-5,
        pre_layer_norm: bool = True,
        name: Optional[str] = None,
        rotary_embedding_config: Optional[RotaryEmbeddingConfig] = None,
    ):
        super().__init__()
        if key_size is None:
            if embed_dim % num_heads != 0:
                raise ValueError(
                    f"The embedding dimension should be divisible by the number of "
                    f"heads, however provided embedding dimension is {embed_dim} and "
                    f"the number of heads is {num_heads}."
                )
            key_size = embed_dim // num_heads

        self._pre_layer_norm = pre_layer_norm
        self._use_glu_in_fnn = use_glu_in_ffn

        if use_glu_in_ffn:
            # With a gated linear unit, fc1 produces both the gate and the value,
            # hence the doubled output dimension.
            self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn)
        else:
            self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn)

        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)

        self.layer_norm_self_attention = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.layer_norm_mlp = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        if ffn_activation_name == "swish":
            self._ffn_activation_fn = nn.SiLU()
        elif ffn_activation_name == "gelu-no-approx":
            # Exact (erf-based) GELU, i.e. no tanh approximation.
            self._ffn_activation_fn = nn.GELU()
        else:
            self._ffn_activation_fn = getattr(torch.nn, ffn_activation_name)()

        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            key_size=key_size,
            add_bias_kv=add_bias_kv,
            model_size=embed_dim,
            name="self_attention",
            rotary_embedding_config=rotary_embedding_config,
        )

    def mlp(self, embed: torch.Tensor) -> torch.Tensor:
        if self._pre_layer_norm:
            x = self.layer_norm_mlp(embed)
        else:
            x = embed

        if self._use_glu_in_fnn:
            x = self.fc1(x)
            x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
            x = self._ffn_activation_fn(x1) * x2
        else:
            x = self._ffn_activation_fn(self.fc1(x))
        x = self.fc2(x)

        if not self._pre_layer_norm:
            x = self.layer_norm_mlp(x + embed)
        return x

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        res = x
        if self._pre_layer_norm:
            x = self.layer_norm_self_attention(x)

        output = self.mha(
            x,
            x,
            x,
            attention_mask=attention_mask,
            attention_weight_bias=attention_weight_bias,
        )

        if not self._pre_layer_norm:
            output["embeddings"] = self.layer_norm_self_attention(
                output["embeddings"] + res
            )
            x = output["embeddings"]
        else:
            x = output["embeddings"]
            x = res + x

        if not self._pre_layer_norm:
            x = self.mlp(x)
        else:
            x = x + self.mlp(x)

        output["embeddings"] = x
        return output


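# Illustrative sketch (not part of the original file): values are assumptions. With
# pre_layer_norm=True the block applies LayerNorm -> self-attention -> residual, then
# LayerNorm -> feed-forward -> residual.
#
#   block = SelfAttentionBlock(
#       num_heads=8,
#       embed_dim=256,
#       ffn_embed_dim=512,
#       rotary_embedding_config=RotaryEmbeddingConfig(rescaling_factor=None),
#   )
#   out = block(torch.randn(2, 100, 256))
#   out["embeddings"].shape  # (2, 100, 256)

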
class LMHead(nn.Module):
    def __init__(
        self, dim_in: int, embed_dim: int, dim_out: int, num_hidden_layers: int
    ) -> None:
        """
        Language-model head: a stack of hidden linear layers followed by an output
        projection to dim_out logits.
        """
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.linear_layers = nn.ModuleList([nn.Linear(dim_in, embed_dim)])
        self.linear_layers.extend(
            [nn.Linear(embed_dim, embed_dim) for _ in range(num_hidden_layers - 1)]
        )
        self.linear_out = nn.Linear(embed_dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.gelu(x, approximate="tanh")
        for layer in self.linear_layers:
            x = layer(x)
            x = F.gelu(x, approximate="tanh")
        out = self.linear_out(x)
        return out


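# Illustrative sketch (not part of the original file): sizes are assumptions.
#
#   head = LMHead(dim_in=16, embed_dim=64, dim_out=7, num_hidden_layers=1)
#   logits = head(torch.randn(2, 100, 16))  # (2, 100, 7)

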
@dataclass
class sCTConfig(PretrainedConfig):
    model_type = "sCT"

    def __init__(self, **kwargs):
        self.alphabet_size = kwargs.get("alphabet_size", 7)
        self.pad_token_id = kwargs.get("pad_token_id", 5)
        self.mask_token_id = kwargs.get("mask_token_id", 6)
        self.cell_len = kwargs.get("cell_len", 19968)

        self.num_downsamples = kwargs.get("num_downsamples", 8)
        self.attention_heads = kwargs.get("attention_heads", 16)
        self.key_size = kwargs.get("key_size", None)
        self.token_embed_dim = kwargs.get("token_embed_dim", 16)

        self.embed_dim = kwargs.get("embed_dim", 1024)
        self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 2048)
        self.num_layers = kwargs.get("num_layers", 4)
        self.layer_norm_eps = kwargs.get("layer_norm_eps", 1e-5)
        self.interpolation_method = kwargs.get("interpolation_method", "nearest")

        self.max_positions: int = kwargs.get("max_positions", 20480)
        self.num_cells: int = kwargs.get("num_cells", 50)
        self.num_hidden_layers_head: int = kwargs.get("num_hidden_layers_head", 1)

        self.use_skip_connection: bool = kwargs.get("use_skip_connection", True)

        self.use_gradient_checkpointing: bool = False

        self.embeddings_layers_to_save: Tuple[int, ...] = kwargs.get(
            "embeddings_layers_to_save", ()
        )
        self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
            "attention_maps_to_save", []
        )

        self.use_spatial_information: bool = kwargs.get(
            "use_spatial_information", False
        )
        self.num_scales: int = kwargs.get("num_scales", 10)
        self.sigma_min: float = kwargs.get("sigma_min", 1.0)
        self.sigma_max: float = kwargs.get("sigma_max", 10.0)

        super().__init__(**kwargs)

    def __post_init__(self) -> None:
        """
        Checks that the given values are compatible.
        """
        if self.key_size is None:
            if not self.embed_dim % self.attention_heads == 0:
                raise ValueError(
                    f"When no key size is provided, the embedding dimension "
                    f"should be divisible by the number of heads, however "
                    f"provided embedding dimension is {self.embed_dim} and "
                    f"the number of heads is {self.attention_heads}."
                )
            self.key_size = self.embed_dim // self.attention_heads


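# Illustrative sketch (not part of the original file): keyword arguments override the
# defaults above; anything not passed keeps its default value.
#
#   config = sCTConfig(num_layers=2, embeddings_layers_to_save=(2,))
#   config.embed_dim  # 1024 (default)

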
class sCT(PreTrainedModel):
    config_class = sCTConfig

    def __init__(self, config: sCTConfig):
        super().__init__(config=config)
        if config.use_spatial_information:
            self.spatial_embed_layer = SpatialEncoding(
                embed_dim=config.token_embed_dim,
                num_scales=config.num_scales,
                sigma_min=config.sigma_min,
                sigma_max=config.sigma_max,
            )
        self.cell_len = config.cell_len

        self.token_embed = nn.Embedding(config.alphabet_size, config.token_embed_dim)

        attention_maps_to_save = config.attention_maps_to_save
        self._attention_layers_to_save = list({t[0] for t in attention_maps_to_save})

        self._attention_maps_per_layer_to_save = {
            layer: [t[1] for t in attention_maps_to_save if t[0] == layer]
            for layer in self._attention_layers_to_save
        }

        max_layer = max(self._attention_layers_to_save + [0])
        if max_layer > config.num_layers:
            raise ValueError(
                f"You are requesting attention maps for layer {max_layer}, "
                f"while the model has {config.num_layers} layers only."
            )

        filter_list = np.linspace(
            config.token_embed_dim,
            config.embed_dim,
            config.num_downsamples + 1,
        )
        # Round the filter sizes up to multiples of 32.
        filter_list = np.ceil(filter_list / 32) * 32
        filter_list = filter_list.astype(int).tolist()

        self._filter_list = filter_list
        self._rotary_embedding_config = RotaryEmbeddingConfig(rescaling_factor=None)

        self.stem_conv = nn.Sequential(
            nn.Conv1d(
                in_channels=config.token_embed_dim,
                out_channels=config.token_embed_dim,
                kernel_size=15,
                padding="same",
            ),
            nn.GELU(approximate="tanh"),
        )
        downsampled_seq_lens = [
            self.cell_len // (2**i) for i in range(len(filter_list) - 1)
        ]

        self.conv_tower = nn.ModuleList(
            [
                ConvTowerBlock(
                    dim_in=self._filter_list[i],
                    dim_out=self._filter_list[i + 1],
                    kernel_size=5,
                    seq_len=seq_len,
                    num_cells=config.num_cells,
                )
                for i, seq_len in zip(range(len(filter_list) - 1), downsampled_seq_lens)
            ]
        )

        self.deconv_tower = nn.ModuleList(
            [
                DeConvTowerBlock(
                    dim_in=filter_list[-1 - i],
                    dim_out=filter_list[-1 - i - 1],
                    kernel_size=5,
                    stride=2,
                    seq_len=seq_len // 2,
                    num_cells=config.num_cells,
                )
                for i, seq_len in zip(
                    range(len(filter_list) - 1), downsampled_seq_lens[::-1]
                )
            ]
        )
        self.transformer_layers = nn.ModuleList(
            [
                SelfAttentionBlock(
                    num_heads=config.attention_heads,
                    embed_dim=config.embed_dim,
                    ffn_embed_dim=config.ffn_embed_dim,
                    key_size=config.key_size,
                    add_bias_kv=False,
                    add_bias_fnn=False,
                    ffn_activation_name="swish",
                    use_glu_in_ffn=True,
                    layer_norm_eps=1e-5,
                    pre_layer_norm=True,
                    name=f"attention_layer_{layer_idx}",
                    rotary_embedding_config=self._rotary_embedding_config,
                )
                for layer_idx in range(config.num_layers)
            ]
        )

        self.lm_head = LMHead(
            dim_in=config.token_embed_dim,
            embed_dim=config.embed_dim,
            dim_out=config.alphabet_size,
            num_hidden_layers=config.num_hidden_layers_head,
        )

    def forward(self, input_ids: torch.Tensor) -> dict[str, torch.Tensor]:
        outs = {}
        embeddings = self.token_embed(input_ids)
        x = embeddings.permute(0, 2, 1)
        x = self.stem_conv(x)
        residuals = []
        for _idx, conv_block in enumerate(self.conv_tower):
            x, res = conv_block(x)
            residuals.append(res)
        residuals = residuals[::-1]
        x = x.permute(0, 2, 1)

        for layer_idx, transformer in enumerate(self.transformer_layers):
            output = transformer(x)
            x = output["embeddings"]
            if (layer_idx + 1) in self.config.embeddings_layers_to_save:
                outs[f"embeddings_{(layer_idx + 1)}"] = output["embeddings"]
            if (layer_idx + 1) in self._attention_layers_to_save:
                for map_number in self._attention_maps_per_layer_to_save[layer_idx + 1]:
                    dkey = f"attention_map_layer_{layer_idx + 1}_number_{map_number}"
                    outs[dkey] = output["attention_weights"][:, map_number + 1]
        x = x.permute(0, 2, 1)
        for deconv_block, res in zip(self.deconv_tower, residuals):
            x = deconv_block(x, res)
        x = x.permute(0, 2, 1)
        logits = self.lm_head(x)
        outs["logits"] = logits

        return outs