| """ |
| Lightning Module for the binding model. |
| """ |
|
|
| import torch |
| from torch import nn |
| from lightning import LightningModule |
| from dpacman.utils.models import set_seed |
| from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits |
|
|
| set_seed() |
|
|
| class LocalCNN(nn.Module): |
    """Residual 1D convolution block over the sequence dimension, followed by LayerNorm."""

    def __init__(self, dim: int = 256, kernel_size: int = 3, dropout: float = 0.1):
        super().__init__()
        padding = kernel_size // 2  # "same" padding keeps the sequence length unchanged
        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
        self.act = nn.GELU()
        self.ln = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, L, dim); Conv1d expects (batch, dim, L)
        out = self.conv(x.transpose(1, 2))
        out = self.act(out)
        out = self.dropout(out)
        out = out.transpose(1, 2)
        return self.ln(out + x)
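# Illustrative shape check for LocalCNN (not part of the model; assumes the default
# dim=256 and a batch of 2 sequences of length 100 -- the residual block preserves shape):
#
#   x = torch.randn(2, 100, 256)
#   assert LocalCNN(dim=256)(x).shape == (2, 100, 256)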


class CrossModalBlock(nn.Module):
    """Cross-modal block: per-modality self-attention, reciprocal cross-attention, and a final gLM <- binder attention."""

    def __init__(self, dim: int = 256, heads: int = 8, dropout: float = 0.1):
        super().__init__()

        # Per-modality self-attention + feed-forward refinement.
        self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
        self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
        self.do_sa_b = nn.Dropout(dropout)
        self.do_sa_g = nn.Dropout(dropout)

        self.ln_b1 = nn.LayerNorm(dim)
        self.ln_g1 = nn.LayerNorm(dim)

        self.ffn_b1 = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
        )
        self.ffn_g1 = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
        )
        self.do_ffn_b1 = nn.Dropout(dropout)
        self.do_ffn_g1 = nn.Dropout(dropout)

        self.ln_b2 = nn.LayerNorm(dim)
        self.ln_g2 = nn.LayerNorm(dim)

        # Reciprocal cross-attention (RCA): gLM queries attend to binder keys/values.
        self.cross_g2b_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
        self.do_rca_g = nn.Dropout(dropout)
        self.ln_g3_RCA = nn.LayerNorm(dim)
        self.ffn_g2_RCA = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
        )
        self.do_ffn_g2 = nn.Dropout(dropout)
        self.ln_g4_RCA = nn.LayerNorm(dim)

        # Reciprocal cross-attention (RCA): binder queries attend to gLM keys/values.
        self.cross_b2g_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
        self.do_rca_b = nn.Dropout(dropout)
        self.ln_b3_RCA = nn.LayerNorm(dim)
        self.ffn_b2_RCA = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
        )
        self.do_ffn_b2 = nn.Dropout(dropout)
        self.ln_b4_RCA = nn.LayerNorm(dim)

        # Final gLM <- binder cross-attention + feed-forward before the prediction head.
        self.cross_g2b_2 = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.do_g2b2 = nn.Dropout(dropout)
        self.ln_g5 = nn.LayerNorm(dim)
        self.ffn_g3 = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
        )
        self.do_ffn_g3 = nn.Dropout(dropout)
        self.ln_g6 = nn.LayerNorm(dim)

    def forward(self, binder: torch.Tensor, glm: torch.Tensor, binder_kpm_mask=None, glm_kpm_mask=None):
| """ |
| binder: (batch, Lb, dim) |
| glm: (batch, Lg, dim) -- has passed through its local CNN beforehand |
| returns: updated binder representation (batch, Lb, dim) and gLM representation |
| """ |
| |
| |
| b = binder |
| b_sa, _ = self.sa_binder(b, b, b, key_padding_mask=binder_kpm_mask) |
| b = self.ln_b1(b + self.do_sa_b(b_sa)) |
| b_ff = self.ffn_b1(b) |
| b = self.ln_b2(b + self.do_ffn_b1(b_ff)) |
|
|
| |
| g = glm |
| g_sa, _ = self.sa_glm(g, g, g, key_padding_mask=glm_kpm_mask) |
| g = self.ln_g1(g + self.do_sa_g(g_sa)) |
| g_ff = self.ffn_g1(g) |
| g = self.ln_g2(g + self.do_ffn_g1(g_ff)) |
| |
| |
| |
| |
| g_ca, _ = self.cross_g2b_1_RCA( |
| g, b, b, key_padding_mask=binder_kpm_mask |
| |
| |
| |
| ) |
| g = self.ln_g3_RCA(g + self.do_rca_g(g_ca)) |
| g = self.ln_g4_RCA(g + self.do_ffn_g2(self.ffn_g2_RCA(g))) |
|
|
| |
| b_ca, _ = self.cross_b2g_1_RCA( |
| b, g, g, key_padding_mask=glm_kpm_mask |
| |
| ) |
| b = self.ln_b3_RCA(b + self.do_rca_b(b_ca)) |
| b = self.ln_b4_RCA(b + self.do_ffn_b2(self.ffn_b2_RCA(b))) |
|
|
| |
| g_to_b_ca, _ = self.cross_g2b_2(g, b, b, key_padding_mask=binder_kpm_mask) |
| g = self.ln_g5(g + self.do_g2b2(g_to_b_ca)) |
| g_ff = self.ffn_g3(g) |
| g = self.ln_g6(g + self.do_ffn_g3(g_ff)) |
| return b, g |
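# Illustrative shape check for CrossModalBlock (not part of the model; assumes the
# defaults dim=256, heads=8 and unpadded inputs, so no key-padding masks are passed):
#
#   blk = CrossModalBlock()
#   b, g = blk(torch.randn(2, 30, 256), torch.randn(2, 80, 256))
#   assert b.shape == (2, 30, 256) and g.shape == (2, 80, 256)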


class DimCompressor(nn.Module):
    """
    Learnable per-token compressor: maps in_dim (>= out_dim) down to out_dim (default 256).
    If in_dim == out_dim, behaves as the identity.
    """

    def __init__(self, in_dim: int, out_dim: int = 256):
        super().__init__()
        if in_dim == out_dim:
            self.net = nn.Identity()
        else:
            # Two-layer MLP; the hidden width is at least 2 * out_dim.
            hidden = max(out_dim * 2, (in_dim + out_dim) // 2)
            self.net = nn.Sequential(
                nn.LayerNorm(in_dim),
                nn.Linear(in_dim, hidden),
                nn.GELU(),
                nn.Linear(hidden, out_dim),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
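# Illustrative behaviour of DimCompressor (not part of the model):
#
#   assert DimCompressor(1280, 256)(torch.randn(2, 10, 1280)).shape == (2, 10, 256)
#   assert isinstance(DimCompressor(256, 256).net, nn.Identity)  # identity when dims match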


class BindPredictor(LightningModule):
    def __init__(
        self,
        binder_input_dim: int = 1280,
        glm_input_dim: int = 256,
        compressed_dim: int = 256,
        hidden_dim: int = 256,
        heads: int = 8,
        num_layers: int = 4,
        lr: float = 1e-4,
        alpha: float = 20.0,
        gamma: float = 20.0,
        dropout: float = 0.0,
        use_local_cnn_on_glm: bool = True,
        weight_decay: float = 0.01,
        loss_type: str = "mixed",
    ):
        super().__init__()
        self.save_hyperparameters()

        # Binder branch: compress binder embeddings, then project to the shared hidden_dim.
        self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
        self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
        self.dropout_b1 = nn.Dropout(dropout)
        self.act = nn.GELU()

        # gLM branch: project gLM embeddings and optionally refine them with a local CNN.
        self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)
        self.dropout_g1 = nn.Dropout(dropout)

        self.use_local_cnn = use_local_cnn_on_glm
        self.local_cnn = LocalCNN(hidden_dim, dropout=self.hparams.dropout) if use_local_cnn_on_glm else nn.Identity()

        # Stack of cross-modal attention blocks.
        self.layers = nn.ModuleList(
            [CrossModalBlock(hidden_dim, heads, self.hparams.dropout) for _ in range(num_layers)]
        )

        # Per-nucleotide prediction head over the gLM representation.
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, binder_emb, glm_emb, binder_mask, glm_mask):
| """ |
| binder_emb: (B, Lb, binder_input_dim) |
| glm_emb: (B, Lg, glm_input_dim) |
| Returns per-nucleotide logits for the GLM sequence: (B, Lg) |
| """ |
| |
| b = self.binder_compress(binder_emb) |
| b = self.proj_binder(b) |
| b = self.dropout_b1(self.act(b)) |
|
|
| |
| g = self.proj_glm(glm_emb) |
| g = self.dropout_g1(self.act(g)) |
| if self.use_local_cnn: |
| g = self.local_cnn(g) |
|
|
| |
| for layer in self.layers: |
| b, g = layer(b, g, binder_mask, glm_mask) |
|
|
| |
| |
| logits = self.head(g).squeeze( |
| -1 |
| ) |
| return logits |
| |
| |
| def training_step(self, batch, batch_idx): |
| """ |
| Training step taken by PyTorch-Lightning trainer. Uses batch returned by data collator. |
| Colator returns a dictionary with: |
| "binder_emb" # [B, Lb_max, Db] |
| "binder_kpm" # [B, Lb_max] |
| "glm_emb" # [B, Lg_max, Dg] |
| "glm_kpm" # [B, Lg_max] |
| "labels" # [B, Lg_max] |
| "ID" |
| "tr_sequence" |
| "dna_sequence" |
| } |
| """ |
| logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"]) |
| loss = calculate_loss( |
| logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type |
| ) |
| self.log( |
| "train/loss", |
| loss, |
| on_step=True, |
| on_epoch=True, |
| prog_bar=True, |
| batch_size=logits.size(0), |
| ) |
| |
| |
| ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits( |
| logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99 |
| ) |
| auc, n_pos, n_neg, tpr, fpr, thresolds, tp, fp = auroc_zeros_vs_ones_from_logits( |
| logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99 |
| ) |
| |
| self.log("train/auprc_0v1", |
| ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device), |
| on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0)) |
| self.log("train/auroc_0v1", |
| auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device), |
| on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0)) |
| |
| self.log("train/n_pos_0v1", float(n_pos), on_step=False, on_epoch=True, sync_dist=True) |
| self.log("train/n_neg_0v1", float(n_neg), on_step=False, on_epoch=True, sync_dist=True) |
|
|
| return loss |
|
|
| def validation_step(self, batch, batch_idx): |
        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
        loss = calculate_loss(
            logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
        )
        self.log(
            "val/loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            batch_size=logits.size(0),
        )

        ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )

        self.log("val/auprc_0v1",
                 ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        self.log("val/auroc_0v1",
                 auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))

        return loss

    def test_step(self, batch, batch_idx):
        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
        loss = calculate_loss(
            logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
        )
        self.log(
            "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
        )

        ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )

        self.log("test/auprc_0v1",
                 ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        self.log("test/auroc_0v1",
                 auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))

        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # forward() already returns (B, Lg) logits, so no extra squeeze is needed here.
        logits = self.forward(batch["binder_emb"], batch["glm_emb"],
                              batch["binder_kpm"], batch["glm_kpm"])
        # glm_kpm is True at padded positions, so its negation marks valid nucleotides.
        valid = ~batch["glm_kpm"]
        return {
            "ids": batch["ID"],
            "logits": logits.detach().cpu(),
            "valid": valid.detach().cpu(),
            "labels": batch["labels"].detach().cpu(),
        }
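    # Example of consuming the predict_step output (illustrative; `out` is one returned dict,
    # and applying a sigmoid to the logits is an assumption about downstream use):
    #
    #   probs = torch.sigmoid(out["logits"])   # per-nucleotide binding probabilities
    #   probs_valid = probs[out["valid"]]      # drop padded positions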

    def on_before_optimizer_step(self, optimizer):
        # Log the global L2 norm of all parameter gradients just before the optimizer step.
        grads = []
        for p in self.parameters():
            if p.grad is not None:
                grads.append(p.grad.detach().float().norm(2))
        if grads:
            total_norm = torch.norm(torch.stack(grads), p=2)
            self.log("train/grad_norm", total_norm, on_step=True, prog_bar=False, logger=True)

    def on_after_backward(self):
        # Same global gradient norm, measured immediately after the backward pass.
        grads = [p.grad.detach().float().norm(2)
                 for p in self.parameters() if p.grad is not None]
        if grads:
            total_norm = torch.norm(torch.stack(grads), p=2)
            self.log("train/grad_norm_back", total_norm, on_step=True, prog_bar=False)
    # Epoch-level AUPRC/AUROC are aggregated via `self.log(..., on_epoch=True)` in the
    # step methods above, so no separate epoch-end hooks are needed.

    def configure_optimizers(self):
        # AdamW with cosine annealing over the full training run, stepped once per epoch.
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=max(self.trainer.max_epochs, 1)
        )
        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": sch, "interval": "epoch"},
        }
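

# Minimal smoke test (illustrative sketch, not part of the training pipeline): builds a
# BindPredictor with default hyperparameters and runs a forward pass on random tensors
# whose shapes follow the forward() docstring. The boolean key-padding masks are all
# False, i.e. no padded positions.
if __name__ == "__main__":
    model = BindPredictor()
    B, Lb, Lg = 2, 50, 120
    binder_emb = torch.randn(B, Lb, 1280)  # (B, Lb, binder_input_dim)
    glm_emb = torch.randn(B, Lg, 256)      # (B, Lg, glm_input_dim)
    binder_kpm = torch.zeros(B, Lb, dtype=torch.bool)
    glm_kpm = torch.zeros(B, Lg, dtype=torch.bool)
    logits = model(binder_emb, glm_emb, binder_kpm, glm_kpm)
    print(logits.shape)  # expected: torch.Size([2, 120])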
|