# DPACMAN / dpacman/classifier/model.py
# Last commit 9da03b7 (svincoff): added dropout and overfit prevention
"""
Lightning Module for the binding model.
"""
import torch
from torch import nn
from lightning import LightningModule
from dpacman.utils.models import set_seed
from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits
# Import-time call to the project helper `set_seed` with its default arguments,
# presumably seeding global RNGs so that module construction below is repeatable.
# NOTE(review): this is a side effect of importing this module; the seed value and
# which RNGs are seeded are defined in dpacman.utils.models — confirm there.
set_seed()
class LocalCNN(nn.Module):
    """
    Residual 1-D convolution over the sequence axis.

    Computes ``LayerNorm(x + Dropout(GELU(Conv1d(x))))``. The convolution mixes
    each token with its ``kernel_size`` neighbours along the sequence. It uses
    ``padding="same"``, so the output length always equals the input length and
    the residual add is valid for odd AND even kernel sizes. Odd kernels pad
    symmetrically, exactly as ``padding=kernel_size // 2`` would. Even kernels
    pad asymmetrically, which PyTorch handles internally.

    Args:
        dim: feature width of both input and output.
        kernel_size: convolution window length.
        dropout: dropout probability applied to the activated conv output.
    """

    def __init__(self, dim: int = 256, kernel_size: int = 3, dropout: float = 0.1):
        super().__init__()
        # "same" keeps L unchanged for any kernel size. The old
        # `kernel_size // 2` grew L by one for even kernels and broke the
        # residual add.
        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding="same")
        self.act = nn.GELU()
        self.ln = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, L, dim) token features. No mask is taken, so any padding
                tokens in ``x`` are convolved like real tokens.

        Returns:
            (batch, L, dim) features after the residual conv block.
        """
        # Conv1d is channels-first: (batch, L, dim) -> (batch, dim, L)
        h = self.conv(x.transpose(1, 2))
        h = self.dropout(self.act(h))  # dropout before the residual LayerNorm
        # Back to (batch, L, dim) so LayerNorm normalises the feature axis.
        return self.ln(h.transpose(1, 2) + x)
class CrossModalBlock(nn.Module):
    """
    One layer of joint binder / gLM (DNA) processing.

    A call runs, in order:
      1. self-attention + FFN on the binder stream and on the gLM stream;
      2. reciprocal cross-attention:
         a. gLM attends to binder (Q=g, K=V=b) + FFN,
         b. binder attends to the gLM states produced by 2a (Q=b, K=V=g) + FFN;
      3. a second gLM-to-binder cross-attention (Q=g, K=V=b) + FFN.

    Every sub-layer is residual with post-LayerNorm. Dropout is applied to each
    sub-layer output before the residual add, inside every FFN hidden layer,
    and to the attention weights of every attention module.

    Args:
        dim: model width shared by both streams (must be divisible by ``heads``).
        heads: number of attention heads per attention module.
        dropout: dropout probability used everywhere in the block.
    """

    def __init__(self, dim: int = 256, heads: int = 8, dropout: float = 0.1):
        super().__init__()
        # 1) Self-attention for both streams.
        self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
        self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
        self.do_sa_b = nn.Dropout(dropout)
        self.do_sa_g = nn.Dropout(dropout)
        # First layer norms (after self-attention).
        self.ln_b1 = nn.LayerNorm(dim)
        self.ln_g1 = nn.LayerNorm(dim)
        # First feed-forward networks.
        self.ffn_b1 = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
        )
        self.ffn_g1 = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
        )
        self.do_ffn_b1 = nn.Dropout(dropout)
        self.do_ffn_g1 = nn.Dropout(dropout)
        self.ln_b2 = nn.LayerNorm(dim)
        self.ln_g2 = nn.LayerNorm(dim)

        # 2a) Reciprocal cross-attention: gLM updated by attending to binder.
        self.cross_g2b_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
        self.do_rca_g = nn.Dropout(dropout)
        self.ln_g3_RCA = nn.LayerNorm(dim)
        self.ffn_g2_RCA = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
        )
        self.do_ffn_g2 = nn.Dropout(dropout)
        self.ln_g4_RCA = nn.LayerNorm(dim)

        # 2b) Reciprocal cross-attention: binder updated by attending to gLM.
        self.cross_b2g_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
        self.do_rca_b = nn.Dropout(dropout)
        self.ln_b3_RCA = nn.LayerNorm(dim)
        self.ffn_b2_RCA = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
        )
        self.do_ffn_b2 = nn.Dropout(dropout)
        self.ln_b4_RCA = nn.LayerNorm(dim)

        # 3) Second cross-attention: gLM queries, binder keys/values, so the
        #    DNA stream is updated again by the transcription-factor stream.
        #    Attention dropout now matches every other attention module here;
        #    it was previously the only one built without it.
        self.cross_g2b_2 = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
        self.do_g2b2 = nn.Dropout(dropout)
        self.ln_g5 = nn.LayerNorm(dim)
        self.ffn_g3 = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
        )
        self.do_ffn_g3 = nn.Dropout(dropout)
        self.ln_g6 = nn.LayerNorm(dim)

    def forward(
        self,
        binder: torch.Tensor,
        glm: torch.Tensor,
        binder_kpm_mask=None,
        glm_kpm_mask=None,
    ):
        """
        Args:
            binder: (batch, Lb, dim) binder token states.
            glm: (batch, Lg, dim) gLM token states. The caller applies any
                local CNN beforehand.
            binder_kpm_mask: optional (batch, Lb) mask forwarded as PyTorch's
                ``key_padding_mask`` whenever binder tokens are keys. In that
                convention a True entry marks a key to ignore (padding). A
                True=keep mask must be inverted by the caller.
            glm_kpm_mask: same as ``binder_kpm_mask``, for gLM tokens (batch, Lg).

        Returns:
            ``(b, g)``: updated binder states (batch, Lb, dim) and gLM states
            (batch, Lg, dim).
        """
        # 1) Per-stream self-attention + FFN.
        b = binder
        b_sa, _ = self.sa_binder(b, b, b, key_padding_mask=binder_kpm_mask)
        b = self.ln_b1(b + self.do_sa_b(b_sa))
        b = self.ln_b2(b + self.do_ffn_b1(self.ffn_b1(b)))

        g = glm
        g_sa, _ = self.sa_glm(g, g, g, key_padding_mask=glm_kpm_mask)
        g = self.ln_g1(g + self.do_sa_g(g_sa))
        g = self.ln_g2(g + self.do_ffn_g1(self.ffn_g1(g)))

        # 2a) gLM attends to binder (Q=g, K=V=b).
        g_ca, _ = self.cross_g2b_1_RCA(g, b, b, key_padding_mask=binder_kpm_mask)
        g = self.ln_g3_RCA(g + self.do_rca_g(g_ca))
        g = self.ln_g4_RCA(g + self.do_ffn_g2(self.ffn_g2_RCA(g)))

        # 2b) Binder attends to gLM (Q=b, K=V=g), using the gLM states from 2a.
        b_ca, _ = self.cross_b2g_1_RCA(b, g, g, key_padding_mask=glm_kpm_mask)
        b = self.ln_b3_RCA(b + self.do_rca_b(b_ca))
        b = self.ln_b4_RCA(b + self.do_ffn_b2(self.ffn_b2_RCA(b)))

        # 3) gLM queries binder once more, then FFN.
        g_to_b_ca, _ = self.cross_g2b_2(g, b, b, key_padding_mask=binder_kpm_mask)
        g = self.ln_g5(g + self.do_g2b2(g_to_b_ca))
        g = self.ln_g6(g + self.do_ffn_g3(self.ffn_g3(g)))

        return b, g  # (batch, Lb, dim), (batch, Lg, dim)
class DimCompressor(nn.Module):
    """
    Learnable per-token projection from ``in_dim`` to ``out_dim`` features.

    If the two widths are equal the module is a pass-through (``nn.Identity``).
    Otherwise it is LayerNorm -> Linear -> GELU -> Linear, with a hidden width
    of ``max(2 * out_dim, (in_dim + out_dim) // 2)``. It is intended for
    ``in_dim >= out_dim`` (compression), although the MLP does not enforce it.
    """

    def __init__(self, in_dim: int, out_dim: int = 256):
        super().__init__()
        self.net = nn.Identity() if in_dim == out_dim else self._build_mlp(in_dim, out_dim)

    @staticmethod
    def _build_mlp(in_dim: int, out_dim: int) -> nn.Sequential:
        """Assemble the LayerNorm / Linear / GELU / Linear compression stack."""
        width = max(2 * out_dim, (in_dim + out_dim) // 2)
        stages = [
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, width),
            nn.GELU(),
            nn.Linear(width, out_dim),
        ]
        return nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (B, L, in_dim) features to (B, L, out_dim)."""
        return self.net(x)
class BindPredictor(LightningModule):
    """
    Lightning module producing per-nucleotide binding logits for a DNA (gLM)
    sequence conditioned on a binder (transcription-factor) sequence.

    Binder embeddings are compressed by ``DimCompressor`` and projected to
    ``hidden_dim``. gLM embeddings are projected to ``hidden_dim`` and optionally
    refined by a residual ``LocalCNN``. Both streams then pass through
    ``num_layers`` ``CrossModalBlock`` layers, and a linear head maps every gLM
    token to a single logit.

    Args:
        binder_input_dim: width of incoming binder embeddings.
        glm_input_dim: width of incoming gLM/DNA embeddings.
        compressed_dim: output width of the learnable binder compressor.
        hidden_dim: shared model width used by the cross-modal layers.
        heads: attention heads per attention module.
        num_layers: number of stacked ``CrossModalBlock`` layers.
        lr: AdamW learning rate.
        alpha, gamma: forwarded unchanged to ``calculate_loss``.
        dropout: dropout probability used throughout the network.
        use_local_cnn_on_glm: apply ``LocalCNN`` to the projected gLM stream.
        weight_decay: AdamW weight decay.
        loss_type: forwarded unchanged to ``calculate_loss``.
    """

    def __init__(
        self,
        binder_input_dim: int = 1280,
        glm_input_dim: int = 256,
        compressed_dim: int = 256,
        hidden_dim: int = 256,
        heads: int = 8,
        num_layers: int = 4,
        lr: float = 1e-4,
        alpha: float = 20,
        gamma: float = 20,
        dropout: float = 0,
        use_local_cnn_on_glm: bool = True,
        weight_decay: float = 0.01,
        loss_type: str = "mixed",
    ):
        super().__init__()
        self.save_hyperparameters()

        # Binder path: learnable compression -> compressed_dim -> hidden_dim.
        self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
        self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
        self.dropout_b1 = nn.Dropout(dropout)
        self.act = nn.GELU()

        # gLM path: glm_input_dim -> hidden_dim, then optional local conv context.
        self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)
        self.dropout_g1 = nn.Dropout(dropout)
        self.use_local_cnn = use_local_cnn_on_glm
        self.local_cnn = (
            LocalCNN(hidden_dim, dropout=dropout) if use_local_cnn_on_glm else nn.Identity()
        )
        self.layers = nn.ModuleList(
            [CrossModalBlock(hidden_dim, heads, dropout) for _ in range(num_layers)]
        )
        # The head emits raw logits (no Sigmoid); the sigmoid belongs in the
        # loss, which keeps mixed precision numerically safe.
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, binder_emb, glm_emb, binder_mask, glm_mask):
        """
        Args:
            binder_emb: (B, Lb, binder_input_dim) binder embeddings.
            glm_emb: (B, Lg, glm_input_dim) gLM embeddings.
            binder_mask: (B, Lb) key-padding mask for binder tokens (True = pad,
                PyTorch ``key_padding_mask`` convention).
            glm_mask: (B, Lg) key-padding mask for gLM tokens (True = pad).

        Returns:
            (B, Lg) per-nucleotide logits for the gLM sequence.
        """
        b = self.binder_compress(binder_emb)  # (B, Lb, compressed_dim)
        b = self.dropout_b1(self.act(self.proj_binder(b)))  # (B, Lb, hidden_dim)

        g = self.dropout_g1(self.act(self.proj_glm(glm_emb)))  # (B, Lg, hidden_dim)
        if self.use_local_cnn:
            g = self.local_cnn(g)

        for layer in self.layers:
            b, g = layer(b, g, binder_mask, glm_mask)

        return self.head(g).squeeze(-1)  # (B, Lg, 1) -> (B, Lg)

    # ----- shared step helpers -----

    def _batch_logits(self, batch):
        """Run ``forward`` on a collated batch and return (B, Lg_max) logits."""
        return self.forward(
            batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"]
        )

    def _batch_loss(self, logits, batch):
        """Compute the training objective using the loss hyper-parameters saved in ``hparams``."""
        return calculate_loss(
            logits,
            batch["labels"],
            batch["binder_kpm"],
            batch["glm_kpm"],
            alpha=self.hparams.alpha,
            gamma=self.hparams.gamma,
            loss_type=self.hparams.loss_type,
        )

    def _log_ranking_metrics(self, stage, logits, batch):
        """
        Log ``{stage}/auprc_0v1`` and ``{stage}/auroc_0v1`` as epoch means.

        Both metrics are computed by the ``.loss`` helpers on labels split at
        ``pos_thresh=0.99``. Per the original intent, that means zeros vs. labels
        above 0.99; see ``.loss`` for the exact rule.

        Returns:
            ``(n_pos, n_neg)`` as reported by the AUROC helper.
        """
        scores = logits.detach()
        labels = batch["labels"]
        glm_kpm = batch["glm_kpm"]
        ap, *_ = auprc_zeros_vs_ones_from_logits(scores, labels, glm_kpm, pos_thresh=0.99)
        auc, n_pos, n_neg, *_ = auroc_zeros_vs_ones_from_logits(
            scores, labels, glm_kpm, pos_thresh=0.99
        )
        batch_size = logits.size(0)
        for metric, value in (("auprc_0v1", ap), ("auroc_0v1", auc)):
            # NOTE(review): non-finite per-batch values (presumably batches that
            # lack one of the classes) are logged as 0.0, which pulls the epoch
            # mean down rather than skipping the batch.
            safe = value if torch.isfinite(value) else torch.tensor(0.0, device=value.device)
            self.log(
                f"{stage}/{metric}",
                safe,
                on_step=False,
                on_epoch=True,
                prog_bar=True,
                sync_dist=True,
                batch_size=batch_size,
            )
        return n_pos, n_neg

    def _global_grad_norm(self):
        """Return the L2 norm over all parameter gradients, or None if no parameter has one."""
        # .float() avoids fp16 overflow when squaring large gradient entries.
        norms = [p.grad.detach().float().norm(2) for p in self.parameters() if p.grad is not None]
        if not norms:
            return None
        return torch.norm(torch.stack(norms), p=2)

    # ----- Lightning hooks -----

    def training_step(self, batch, batch_idx):
        """
        Training step called by the Lightning trainer.

        The collator returns a dict with:
            "binder_emb"    # [B, Lb_max, Db]
            "binder_kpm"    # [B, Lb_max]
            "glm_emb"       # [B, Lg_max, Dg]
            "glm_kpm"       # [B, Lg_max]
            "labels"        # [B, Lg_max]
            "ID"
            "tr_sequence"
            "dna_sequence"
        """
        logits = self._batch_logits(batch)
        loss = self._batch_loss(logits, batch)
        batch_size = logits.size(0)
        self.log(
            "train/loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            batch_size=batch_size,
        )
        n_pos, n_neg = self._log_ranking_metrics("train", logits, batch)
        # Class counts let you sanity-check label balance per epoch.
        self.log("train/n_pos_0v1", float(n_pos), on_step=False, on_epoch=True,
                 sync_dist=True, batch_size=batch_size)
        self.log("train/n_neg_0v1", float(n_neg), on_step=False, on_epoch=True,
                 sync_dist=True, batch_size=batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step: epoch-level loss plus AUPRC/AUROC."""
        logits = self._batch_logits(batch)
        loss = self._batch_loss(logits, batch)
        self.log(
            "val/loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            batch_size=logits.size(0),
        )
        self._log_ranking_metrics("val", logits, batch)
        return loss

    def test_step(self, batch, batch_idx):
        """Test step: epoch-level loss plus AUPRC/AUROC."""
        logits = self._batch_logits(batch)
        loss = self._batch_loss(logits, batch)
        self.log("test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0))
        self._log_ranking_metrics("test", logits, batch)
        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Return padded per-token logits plus everything needed to unpad them.

        Returns a dict of CPU tensors (and the ID list):
            "ids": batch["ID"],
            "logits": (B, Lg_max) padded logits,
            "valid": (B, Lg_max) bool, True where the gLM token is real,
            "labels": (B, Lg_max) padded labels.
        """
        # forward() already returns (B, Lg). A second squeeze(-1) would
        # collapse a batch whose padded length is 1 down to shape (B,).
        logits = self._batch_logits(batch)
        valid = ~batch["glm_kpm"]  # glm_kpm is True on padding
        return {
            "ids": batch["ID"],
            "logits": logits.detach().cpu(),
            "valid": valid.detach().cpu(),
            "labels": batch["labels"].detach().cpu(),
        }

    def on_before_optimizer_step(self, optimizer):
        """Log the global gradient L2 norm just before the optimizer step."""
        total_norm = self._global_grad_norm()
        if total_norm is not None:
            self.log("train/grad_norm", total_norm, on_step=True, prog_bar=False, logger=True)

    def on_after_backward(self):
        """Log the global gradient L2 norm right after backward."""
        # NOTE(review): under fp16 AMP the gradients may still be loss-scaled
        # here; compare against train/grad_norm before reading too much into it.
        total_norm = self._global_grad_norm()
        if total_norm is not None:
            self.log("train/grad_norm_back", total_norm, on_step=True, prog_bar=False)

    def configure_optimizers(self):
        """AdamW with a per-epoch cosine-annealed learning rate."""
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        # max(None, 1) raises TypeError, so fall back to 1 when max_epochs is unset.
        # NOTE(review): with max_epochs == -1 (step-limited training) this also
        # gives T_max=1, i.e. the cosine cycles every couple of epochs.
        max_epochs = self.trainer.max_epochs
        t_max = max(max_epochs, 1) if max_epochs is not None else 1
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=t_max)
        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": sch, "interval": "epoch"},
        }