"""
Script for running a trained model checkpoint in inference-only mode.
"""
from typing import Any, Dict, List, Optional, Tuple
import hydra
from hydra.core.hydra_config import HydraConfig
import torch
import rootutils
import lightning as L
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
from pathlib import Path
import pandas as pd
import pickle

root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# Project imports come after setup_root() so the repository root is on sys.path.
from dpacman.classifier.loss import (
    calculate_loss,
    auprc_zeros_vs_ones_from_logits,
    auroc_zeros_vs_ones_from_logits,
)
from dpacman.utils import (
RankedLogger,
extras,
get_metric_value,
instantiate_callbacks,
instantiate_loggers,
log_hyperparameters,
task_wrapper,
)
log = RankedLogger(__name__, rank_zero_only=True)
def h100_settings():
    # Use TensorFloat-32 for float32 matmuls: big speedup with a tiny accuracy tradeoff.
    torch.set_float32_matmul_precision("high")  # or "medium" for even more speed
    # Older PyTorch toggles for the same behavior (optional).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
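# Note: "high" lets float32 matmuls run in TF32 (reduced-precision mantissa) on
# Ampere-and-newer GPUs such as the H100. A quick, illustrative way to see the
# tradeoff on a CUDA machine:
#
#     a = torch.randn(1024, 1024, device="cuda")
#     b = torch.randn(1024, 1024, device="cuda")
#     torch.set_float32_matmul_precision("highest"); exact = a @ b
#     torch.set_float32_matmul_precision("high"); fast = a @ b
#     print((exact - fast).abs().max())  # small TF32 rounding error, large speedup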
def flatten_preds(pred_batches):
    """
    Flatten the model's prediction batches into one dict per sample. Each batch
    contains:
        "ids":    batch["ID"]            # list[str]
        "logits": logits.detach().cpu()  # (B, Lmax), padded
        "valid":  valid.detach().cpu()   # (B, Lmax) booleans marking real positions
        "labels": per-position labels    # (B, Lmax), padded
    """
    out = []
    for b in pred_batches:
        ids, logits, valid, labels = b["ids"], b["logits"], b["valid"], b["labels"]
        for i, id_ in enumerate(ids):
            L = int(valid[i].sum().item())  # number of real positions; strips padding
            trim_logits = logits[i, :L].numpy()
            out.append({"ID": id_, "logits": trim_logits, "labels": labels[i, :L].numpy()})
    return out
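# Usage sketch (hypothetical shapes; "seq_a"/"seq_b" are made-up IDs): for a
# batch of two sequences with true lengths 3 and 2, padded to Lmax = 4,
#
#     batch = {
#         "ids": ["seq_a", "seq_b"],
#         "logits": torch.zeros(2, 4),
#         "valid": torch.tensor([[True, True, True, False],
#                                [True, True, False, False]]),
#         "labels": torch.zeros(2, 4),
#     }
#     flat = flatten_preds([batch])
#     # flat[0]["logits"].shape == (3,) and flat[1]["logits"].shape == (2,)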
@task_wrapper
def predict(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Runs test and predict steps for a checkpointed model on a datamodule's test set.

    This method is wrapped in the optional @task_wrapper decorator, which controls
    the behavior during failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Tuple[dict, dict] with metrics and a dict with all instantiated objects.
    """
    # Set seed for random number generators in pytorch, numpy and python.random.
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)
log.info(f"Instantiating datamodule <{cfg.data_module._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data_module)
log.info(f"Instantiating model <{cfg.model._target_}>")
model: LightningModule = hydra.utils.instantiate(cfg.model)
log.info("Instantiating callbacks...")
callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))
log.info("Instantiating loggers...")
logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(
cfg.trainer, callbacks=callbacks, logger=logger
)
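    # For reference, cfg.trainer typically resolves to a Lightning Trainer config;
    # an illustrative (not this project's actual) trainer YAML might look like:
    #
    #     _target_: lightning.pytorch.trainer.Trainer
    #     accelerator: gpu
    #     devices: 1
    #     precision: 32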
    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)
if cfg.get("test"):
log.info("Starting testing!")
ckpt_path = cfg.ckpt_path
if ckpt_path == "":
log.warning("No ckpt path was passed! Cannot continue")
return
trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
pred_batches = trainer.predict(model, datamodule=datamodule, ckpt_path=ckpt_path, return_predictions=True)
out = flatten_preds(pred_batches)
# make output dir
output_dir = Path(HydraConfig.get().run.dir)
save_path = output_dir / "predictions.pkl"
with open(save_path, "wb") as f:
pickle.dump(out, f)
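    # The pickle can be inspected later, e.g. (illustrative):
    #
    #     with open("predictions.pkl", "rb") as f:
    #         preds = pickle.load(f)
    #     preds[0]["ID"], preds[0]["logits"].shape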
    # Recompute loss, AUPRC, and AUROC per sample, but only when real labels exist
    # (i.e. the user actually passed scores rather than fake placeholders).
    if not datamodule.test_dataset.fake_scores:
        for i, d in enumerate(out):
            logits = torch.tensor(d["logits"])
            labels = torch.tensor(d["labels"])
            mask = torch.zeros(labels.shape, dtype=torch.bool)  # all-False: no positions excluded

            loss = calculate_loss(
                logits, labels, None, None, alpha=cfg.model.alpha, gamma=cfg.model.gamma
            )

            # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
            ap, n_pos, n_neg, precision, recall, ap_thresholds = auprc_zeros_vs_ones_from_logits(
                logits, labels, mask, pos_thresh=0.99
            )
            auc, n_pos, n_neg, tpr, fpr, auc_thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
                logits, labels, mask, pos_thresh=0.99
            )

            out[i]["loss"] = loss.item() if loss.numel() > 0 else None
            out[i]["auprc"] = ap.item() if ap.numel() > 0 else None
            out[i]["auroc"] = auc.item() if auc.numel() > 0 else None
            out[i]["n_pos"] = n_pos
            out[i]["n_neg"] = n_neg
            out[i]["precision"] = precision.numpy() if precision.numel() > 0 else None
            out[i]["recall"] = recall.numpy() if recall.numel() > 0 else None
            out[i]["auprc_thresholds"] = ap_thresholds.numpy() if ap_thresholds.numel() > 0 else None
            out[i]["auc_thresholds"] = auc_thresholds.numpy() if auc_thresholds.numel() > 0 else None
            out[i]["tpr"] = tpr
            out[i]["fpr"] = fpr
    # Summary CSV of scalar metrics only (no big arrays inside).
    summary_rows = []
    for d in out:
        summary_rows.append({
            "ID": d["ID"],
            "loss": d.get("loss"),
            "auprc": d.get("auprc"),
            "auroc": d.get("auroc"),
            "n_pos": d.get("n_pos"),
            "n_neg": d.get("n_neg"),
        })
    summary_path = output_dir / "summary.csv"
    pd.DataFrame(summary_rows).to_csv(summary_path, index=False)
    log.info(f"Saved eval/predict results to {output_dir}")
    # Collect the metrics produced by trainer.test() above.
    test_metrics = trainer.callback_metrics
    metric_dict = {**test_metrics}
    return metric_dict, object_dict
@hydra.main(
version_base="1.3", config_path=str(root / "configs"), config_name="eval.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    """Main entry point for evaluation.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # Apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.).
    extras(cfg)
    h100_settings()  # try settings for faster inference on H100s

    # Run test/predict with the checkpointed model.
    metric_dict, _ = predict(cfg)

    # Safely retrieve metric value for hydra-based hyperparameter optimization.
    metric_value = get_metric_value(
        metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
    )

    # Return optimized metric.
    return metric_value
if __name__ == "__main__":
    main()