""" Lightning Module for the binding model. """ import torch from torch import nn from lightning import LightningModule from dpacman.utils.models import set_seed from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits set_seed() class LocalCNN(nn.Module): def __init__(self, dim: int = 256, kernel_size: int = 3): super().__init__() padding = kernel_size // 2 self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding) self.act = nn.GELU() self.ln = nn.LayerNorm(dim) def forward(self, x: torch.Tensor): # x: (batch, L, dim) out = self.conv(x.transpose(1, 2)) # → (batch, dim, L) out = self.act(out) out = out.transpose(1, 2) # → (batch, L, dim) return self.ln(out + x) # residual class CrossModalBlock(nn.Module): def __init__(self, dim: int = 256, heads: int = 8, dropout: float = 0.0): super().__init__() # self-attention for both sides self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True) self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True) # first layer norms self.ln_b1 = nn.LayerNorm(dim) self.ln_g1 = nn.LayerNorm(dim) # first feed forward networks self.ffn_b1 = nn.Sequential( nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim) ) self.ffn_g1 = nn.Sequential( nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim) ) self.ln_b2 = nn.LayerNorm(dim) self.ln_g2 = nn.LayerNorm(dim) # 2) reciprocal cross-attn: g<-b and b<-g # DNA/GLM updated by attending to Binder self.cross_g2b_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout) self.ln_g3_RCA = nn.LayerNorm(dim) self.ffn_g2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim)) self.ln_g4_RCA = nn.LayerNorm(dim) # Binder updated by attending to DNA/GLM self.cross_b2g_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout) self.ln_b3_RCA = nn.LayerNorm(dim) self.ffn_b2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim)) self.ln_b4_RCA = nn.LayerNorm(dim) # cross attention (binder queries, glm keys/values) # so the NDA path is updated by the transcriptoin factors self.cross_g2b_2 = nn.MultiheadAttention(dim, heads, batch_first=True) self.ln_g5 = nn.LayerNorm(dim) self.ffn_g3 = nn.Sequential( nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim) ) self.ln_g6 = nn.LayerNorm(dim) def forward(self, binder: torch.Tensor, glm: torch.Tensor, binder_kpm_mask=None, glm_kpm_mask=None): """ binder: (batch, Lb, dim) glm: (batch, Lg, dim) -- has passed through its local CNN beforehand returns: updated binder representation (batch, Lb, dim) """ # 1) Self-attentino and feed-forward networks for binder and DNA # binder: self-attn + ffn b = binder b_sa, _ = self.sa_binder(b, b, b, key_padding_mask=binder_kpm_mask) b = self.ln_b1(b + b_sa) b_ff = self.ffn_b1(b) b = self.ln_b2(b + b_ff) # glm: self-attn + ffn g = glm g_sa, _ = self.sa_glm(g, g, g, key_padding_mask=glm_kpm_mask) g = self.ln_g1(g + g_sa) g_ff = self.ffn_g1(g) g = self.ln_g2(g + g_ff) # 2a) Reciprocal Cross-Attention: # DNA updated by attending to Binder (Q=g, K=b, V=b) # Binder updated by attending to DNA (Q=b, K=g, V=g) g_ca, _ = self.cross_g2b_1_RCA( g, b, b, key_padding_mask=binder_kpm_mask # torch MultiheadAttention expects key_padding_mask=True for PADs; # invert if your mask is True=keep: # key_padding_mask=(~binder_mask.bool()) if binder_mask is not None else None ) g = self.ln_g3_RCA(g + g_ca) g = self.ln_g4_RCA(g + self.ffn_g2_RCA(g)) # 2b) Binder updated by attending to DNA/GLM (Q=b, K=g, V=g) b_ca, _ = 
        b_ca, _ = self.cross_b2g_1_RCA(
            b, g, g,
            key_padding_mask=glm_kpm_mask
            # key_padding_mask=(~glm_mask.bool()) if glm_mask is not None else None
        )
        b = self.ln_b3_RCA(b + b_ca)
        b = self.ln_b4_RCA(b + self.ffn_b2_RCA(b))

        # cross-attention: glm queries binder and the glm embeddings are updated
        g_to_b_ca, _ = self.cross_g2b_2(g, b, b, key_padding_mask=binder_kpm_mask)
        g = self.ln_g5(g + g_to_b_ca)
        g_ff = self.ffn_g3(g)
        g = self.ln_g6(g + g_ff)

        return b, g  # (batch, Lb, dim), (batch, Lg, dim)


class DimCompressor(nn.Module):
    """
    Learnable per-token compressor: maps any in_dim >= out_dim to out_dim (default 256).
    If in_dim == out_dim, behaves as identity.
    """

    def __init__(self, in_dim: int, out_dim: int = 256):
        super().__init__()
        if in_dim == out_dim:
            self.net = nn.Identity()
        else:
            hidden = max(out_dim * 2, (in_dim + out_dim) // 2)
            self.net = nn.Sequential(
                nn.LayerNorm(in_dim),
                nn.Linear(in_dim, hidden),
                nn.GELU(),
                nn.Linear(hidden, out_dim),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, L, in_dim)
        return self.net(x)


class BindPredictor(LightningModule):
    def __init__(
        self,
        # input_dim: int = 256,          # OLD: single input dim
        binder_input_dim: int = 1280,    # NEW: TF (binder) original dim (e.g., 1280)
        glm_input_dim: int = 256,        # NEW: DNA/GLM original dim (e.g., 256)
        compressed_dim: int = 256,       # NEW: learnable compressed dim
        hidden_dim: int = 256,
        heads: int = 8,
        num_layers: int = 4,
        lr: float = 1e-4,
        alpha: float = 20,
        gamma: float = 20,
        dropout: float = 0,
        use_local_cnn_on_glm: bool = True,
        weight_decay: float = 0.01,
        loss_type="mixed",
    ):
        # Init
        super(BindPredictor, self).__init__()
        self.save_hyperparameters()

        # Learnable compressor for binder -> 256, then project to hidden
        self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
        self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
        # GLM side stays 256 -> hidden
        self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)

        self.use_local_cnn = use_local_cnn_on_glm
        self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()

        self.layers = nn.ModuleList(
            [CrossModalBlock(hidden_dim, heads, self.hparams.dropout) for _ in range(num_layers)]
        )
        # self.ln_out = nn.LayerNorm(hidden_dim)
        # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())  # OLD: returned probabilities
        self.head = nn.Linear(hidden_dim, 1)  # NEW: return logits (safe for AMP)

    def forward(self, binder_emb, glm_emb, binder_mask, glm_mask):
        """
        binder_emb: (B, Lb, binder_input_dim)
        glm_emb:    (B, Lg, glm_input_dim)
        Returns per-nucleotide logits for the GLM sequence: (B, Lg)
        """
        # Binder: learnable compression → 256 → hidden
        b = self.binder_compress(binder_emb)  # (B, Lb, 256)
        b = self.proj_binder(b)               # (B, Lb, hidden_dim)

        # GLM: project → hidden, add local CNN context
        g = self.proj_glm(glm_emb)            # (B, Lg, hidden_dim)
        if self.use_local_cnn:
            g = self.local_cnn(g)

        # Cross-modal blocks: binder and GLM states update each other
        for layer in self.layers:
            b, g = layer(b, g, binder_mask, glm_mask)  # (B, Lb, hidden_dim), (B, Lg, hidden_dim)

        # Predict per-nucleotide logits on the GLM tokens:
        # return self.head(g).squeeze(-1)  # OLD: probabilities (with Sigmoid in head)
        logits = self.head(g).squeeze(-1)
        return logits

    # ----- Lightning hooks -----
    def training_step(self, batch, batch_idx):
        """
        Training step called by the PyTorch Lightning trainer.
        Uses the batch returned by the data collator.
        The collator returns a dictionary with:
            "binder_emb"    # [B, Lb_max, Db]
            "binder_kpm"    # [B, Lb_max]
            "glm_emb"       # [B, Lg_max, Dg]
            "glm_kpm"       # [B, Lg_max]
            "labels"        # [B, Lg_max]
            "ID"
            "tr_sequence"
            "dna_sequence"
        """
        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
        loss = calculate_loss(
            logits,
            batch["labels"],
            batch["binder_kpm"],
            batch["glm_kpm"],
            alpha=self.hparams.alpha,
            gamma=self.hparams.gamma,
            loss_type=self.hparams.loss_type,
        )
        self.log(
            "train/loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            batch_size=logits.size(0),
        )

        # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
        ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        # per-batch AP (the epoch mean is a decent summary); sync across GPUs if using DDP
        self.log("train/auprc_0v1", ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        self.log("train/auroc_0v1", auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        # (optional) also log class counts so you can sanity-check the balance
        self.log("train/n_pos_0v1", float(n_pos), on_step=False, on_epoch=True, sync_dist=True)
        self.log("train/n_neg_0v1", float(n_neg), on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
        loss = calculate_loss(
            logits,
            batch["labels"],
            batch["binder_kpm"],
            batch["glm_kpm"],
            alpha=self.hparams.alpha,
            gamma=self.hparams.gamma,
            loss_type=self.hparams.loss_type,
        )
        self.log(
            "val/loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            batch_size=logits.size(0),
        )

        # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
        ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        # per-batch AP (the epoch mean is a decent summary); sync across GPUs if using DDP
        self.log("val/auprc_0v1", ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        self.log("val/auroc_0v1", auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        return loss

    def test_step(self, batch, batch_idx):
        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
        loss = calculate_loss(
            logits,
            batch["labels"],
            batch["binder_kpm"],
            batch["glm_kpm"],
            alpha=self.hparams.alpha,
            gamma=self.hparams.gamma,
            loss_type=self.hparams.loss_type,
        )
        self.log(
            "test/loss",
            loss,
            on_step=False,
            on_epoch=True,
            batch_size=logits.size(0),
        )

        # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
        ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        # per-batch AP (the epoch mean is a decent summary); sync across GPUs if using DDP
        self.log("test/auprc_0v1", ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        self.log("test/auroc_0v1", auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        logits = self.forward(
            batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"]
        ).squeeze(-1)              # (B, L)
        valid = ~batch["glm_kpm"]  # (B, L), True where the token is real (not padding)
        return {
            "ids": batch["ID"],                        # list[str]
            "logits": logits.detach().cpu(),           # (B, Lmax) padded
            "valid": valid.detach().cpu(),             # (B, Lmax) booleans
            "labels": batch["labels"].detach().cpu(),  # (B, Lmax) padded
        }

    def on_before_optimizer_step(self, optimizer):
        # Compute the global L2 norm of all parameter gradients (ignores None grads)
        grads = []
        for p in self.parameters():
            if p.grad is not None:
                # .detach() avoids autograd tracking; .float() avoids fp16 overflow in norms
                grads.append(p.grad.detach().float().norm(2))
        if grads:
            total_norm = torch.norm(torch.stack(grads), p=2)
            self.log("train/grad_norm", total_norm, on_step=True, prog_bar=False, logger=True)

    def on_after_backward(self):
        grads = [p.grad.detach().float().norm(2) for p in self.parameters() if p.grad is not None]
        if grads:
            total_norm = torch.norm(torch.stack(grads), p=2)
            self.log("train/grad_norm_back", total_norm, on_step=True, prog_bar=False)

    def on_train_epoch_end(self):
        if False:  # disabled: the epoch-level AUROC accumulator (self.train_auc) is not defined in this module
            if self.train_auc.compute() is not None:
                self.log("train/auroc", self.train_auc.compute(), prog_bar=True)
            self.train_auc.reset()

    def on_validation_epoch_end(self):
        if False:  # disabled: the epoch-level AUROC accumulator (self.val_auc) is not defined in this module
            if self.val_auc.compute() is not None:
                self.log("val/auroc", self.val_auc.compute(), prog_bar=True)
            self.val_auc.reset()

    def on_test_epoch_end(self):
        if False:  # disabled: the epoch-level AUROC accumulator (self.test_auc) is not defined in this module
            if self.test_auc.compute() is not None:
                self.log("test/auroc", self.test_auc.compute(), prog_bar=True)
            self.test_auc.reset()

    def configure_optimizers(self):
        # AdamW + cosine annealing as a sensible default
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        # Scheduler is optional; comment it out if you prefer a fixed LR
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=max(self.trainer.max_epochs, 1)
        )
        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": sch, "interval": "epoch"},
        }
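

# -----------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only, not part of the training
# pipeline). The batch size, sequence lengths, and the "True = padding"
# key_padding_mask convention below are assumptions inferred from the shape
# comments above; adjust them to match your collator. Because of the relative
# import of `.loss` at the top, run this as a module inside its package
# (python -m <package>.<module>) rather than as a standalone script.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    B, Lb, Lg = 2, 10, 50  # assumed batch size, binder length, GLM length

    # Small configuration so the forward pass is cheap to sanity-check
    model = BindPredictor(
        binder_input_dim=1280,
        glm_input_dim=256,
        hidden_dim=64,
        heads=4,
        num_layers=1,
    )

    binder_emb = torch.randn(B, Lb, 1280)               # TF (binder) embeddings
    glm_emb = torch.randn(B, Lg, 256)                    # DNA/GLM embeddings
    binder_kpm = torch.zeros(B, Lb, dtype=torch.bool)    # True marks padded positions
    glm_kpm = torch.zeros(B, Lg, dtype=torch.bool)
    binder_kpm[:, -2:] = True                            # pretend the last two binder tokens are padding

    logits = model(binder_emb, glm_emb, binder_kpm, glm_kpm)
    print(logits.shape)  # expected: torch.Size([2, 50]) -- one logit per GLM token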