| """ |
| Code for baseline model to compare the classifier to |
| """ |
|
|
| from lightning import LightningModule |
| import torch |
| import torch.nn as nn |
| from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits |
| from .model import DimCompressor |
|
|
| class BaselineBindPredictor(LightningModule): |
| """ |
| Baseline predictor: simple MLP that just concatenates the embeddings and outputs per-token predictions. |
| """ |
    def __init__(
        self,
        binder_input_dim: int = 1280,
        glm_input_dim: int = 256,
        compressed_dim: int = 256,
        hidden_dim: int = 256,
        lr: float = 1e-4,
        alpha: float = 20,
        gamma: float = 20,
        dropout: float = 0,
        weight_decay: float = 0.01,
        loss_type: str = "mixed",
    ):
        # Note: compressed_dim must equal glm_input_dim, since the compressed
        # binder embeddings are concatenated with the GLM embeddings along the
        # sequence axis in forward().
        super().__init__()
        self.save_hyperparameters()
|
|
| |
        # Project binder embeddings down to the shared feature width.
        self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)

        # Per-token scoring head. No activation follows the final layer so the
        # outputs are unbounded logits, as required by the logit-based losses
        # in calculate_loss (a trailing ReLU here would clamp logits to >= 0).
        self.mlp = nn.Sequential(
            nn.Linear(compressed_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),  # regularization between hidden and output layers
            nn.Linear(hidden_dim, 1),
        )
| |
    def forward(self, binder_emb, glm_emb, binder_mask, glm_mask):
        """
        binder_emb: (B, Lb, binder_input_dim)
        glm_emb:    (B, Lg, glm_input_dim)
        binder_mask / glm_mask: padding masks, accepted for interface
            compatibility but unused by this baseline.
        Returns per-nucleotide logits for the GLM sequence: (B, Lg)
        """
        # Compress binder embeddings to the shared width: (B, Lb, compressed_dim)
        b = self.binder_compress(binder_emb)

        # Concatenate along the sequence axis; GLM tokens come first.
        lg = glm_emb.shape[1]
        concat_embeddings = torch.cat((glm_emb, b), dim=1)

        # Score every token, then keep only the GLM positions:
        # (B, Lg + Lb, 1) -> (B, Lg)
        logits = self.mlp(concat_embeddings)
        logits = logits[:, :lg, :].squeeze(-1)
        return logits
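
    # Shape walk-through for forward (illustrative numbers, default dims):
    #   binder_emb (2, 50, 1280) --binder_compress--> (2, 50, 256)
    #   torch.cat with glm_emb (2, 300, 256), dim=1 -> (2, 350, 256)
    #   mlp                                         -> (2, 350, 1)
    #   keep first Lg=300 positions, squeeze(-1)    -> (2, 300)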
| |
| |
    def training_step(self, batch, batch_idx):
        """
        Training step invoked by the PyTorch Lightning trainer. Uses the batch
        returned by the data collator, a dictionary with:
        {
            "binder_emb":   [B, Lb_max, Db]
            "binder_kpm":   [B, Lb_max]
            "glm_emb":      [B, Lg_max, Dg]
            "glm_kpm":      [B, Lg_max]
            "labels":       [B, Lg_max]
            "ID"
            "tr_sequence"
            "dna_sequence"
        }
        """
| logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"]) |
        loss = calculate_loss(
            logits,
            batch["labels"],
            batch["binder_kpm"],
            batch["glm_kpm"],
            alpha=self.hparams.alpha,
            gamma=self.hparams.gamma,
            loss_type=self.hparams.loss_type,
        )
| self.log( |
| "train/loss", |
| loss, |
| on_step=True, |
| on_epoch=True, |
| prog_bar=True, |
| batch_size=logits.size(0), |
| ) |
| |
| |
        # Ranking metrics on 0-vs-1 labels; pos_thresh binarizes (possibly
        # soft) labels into positives and negatives.
        ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
| |
| self.log("train/auprc_0v1", |
| ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device), |
| on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0)) |
| self.log("train/auroc_0v1", |
| auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device), |
| on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0)) |
| |
| |
| self.log("train/n_pos_0v1", float(n_pos), on_step=False, on_epoch=True, sync_dist=True) |
| self.log("train/n_neg_0v1", float(n_neg), on_step=False, on_epoch=True, sync_dist=True) |
|
|
| return loss |
|
|
| def validation_step(self, batch, batch_idx): |
| logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"]) |
        loss = calculate_loss(
            logits,
            batch["labels"],
            batch["binder_kpm"],
            batch["glm_kpm"],
            alpha=self.hparams.alpha,
            gamma=self.hparams.gamma,
            loss_type=self.hparams.loss_type,
        )
| self.log( |
| "val/loss", |
| loss, |
| on_step=False, |
| on_epoch=True, |
| prog_bar=True, |
| batch_size=logits.size(0), |
| ) |
| |
| |
| ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits( |
| logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99 |
| ) |
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
| logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99 |
| ) |
| |
| self.log("val/auprc_0v1", |
| ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device), |
| on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0)) |
| self.log("val/auroc_0v1", |
| auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device), |
| on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0)) |
| return loss |
|
|
| def test_step(self, batch, batch_idx): |
| logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"]) |
        loss = calculate_loss(
            logits,
            batch["labels"],
            batch["binder_kpm"],
            batch["glm_kpm"],
            alpha=self.hparams.alpha,
            gamma=self.hparams.gamma,
            loss_type=self.hparams.loss_type,
        )
| self.log( |
| "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0) |
| ) |
| |
| |
| ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits( |
| logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99 |
| ) |
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
| logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99 |
| ) |
| |
| self.log("test/auprc_0v1", |
| ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device), |
| on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0)) |
| self.log("test/auroc_0v1", |
| auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device), |
| on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0)) |
| return loss |
| |
    def on_before_optimizer_step(self, optimizer):
        # Global L2 norm of all gradients, logged just before optimizer.step().
        grads = []
        for p in self.parameters():
            if p.grad is not None:
                grads.append(p.grad.detach().float().norm(2))
        if grads:
            total_norm = torch.norm(torch.stack(grads), p=2)
            self.log("train/grad_norm", total_norm, on_step=True, prog_bar=False, logger=True)

    def on_after_backward(self):
        # Global L2 gradient norm immediately after backward, logged separately
        # for comparison with train/grad_norm.
        grads = [p.grad.detach().float().norm(2)
                 for p in self.parameters() if p.grad is not None]
        if grads:
            total_norm = torch.norm(torch.stack(grads), p=2)
            self.log("train/grad_norm_back", total_norm, on_step=True, prog_bar=False)
|
|
    def configure_optimizers(self):
        # AdamW with cosine annealing of the learning rate over the full run.
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=max(self.trainer.max_epochs, 1)
        )
        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": sch, "interval": "epoch"},
        }