"""
Script for running a trained model in inference / evaluation mode: loads a
checkpoint, optionally runs trainer.test, collects per-sequence predictions,
and writes predictions.pkl plus a summary.csv of per-sequence metrics.
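
Example invocation (path illustrative; ``ckpt_path`` and ``test`` are config
keys this script actually reads):

    python scripts/eval.py ckpt_path=/path/to/model.ckpt test=true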
"""
from typing import Any, Dict, List, Tuple
import hydra
from hydra.core.hydra_config import HydraConfig
import torch
import rootutils
import lightning as L
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
from pathlib import Path
import pandas as pd
from dpacman.classifier.loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits
import pickle
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from dpacman.utils import (
RankedLogger,
extras,
get_metric_value,
instantiate_callbacks,
instantiate_loggers,
log_hyperparameters,
task_wrapper,
)
log = RankedLogger(__name__, rank_zero_only=True)
def h100_settings():
    """Enable TF32 math for faster float32 matmuls on Ampere/Hopper (e.g. H100) GPUs."""
    # Use TensorFloat-32 for float32 matmuls: big speedup with a tiny accuracy tradeoff.
    torch.set_float32_matmul_precision("high")  # or "medium" for even more speed
    # Older PyTorch toggles for the same behavior (optional).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
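# For reference, torch.set_float32_matmul_precision accepts "highest" (strict FP32),
# "high" (TF32 allowed), and "medium" (even lower-precision internal math allowed);
# this is a documented PyTorch API, not specific to this repo.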
def flatten_preds(pred_batches):
    """
    Flatten the model's prediction batches into one record per sequence.

    Each batch dict contains:
        "ids":    batch["ID"]            # list[str] or list of sequence IDs
        "logits": logits.detach().cpu()  # (B, Lmax), padded
        "valid":  valid.detach().cpu()   # (B, Lmax) booleans marking real positions
        "labels": per-position labels    # (B, Lmax), padded
    """
    out = []
    for b in pred_batches:
        ids, logits, valid, labels = b["ids"], b["logits"], b["valid"], b["labels"]
        for i, id_ in enumerate(ids):
            seq_len = int(valid[i].sum().item())  # unpadded length; strip padding
            trim_logits = logits[i, :seq_len].numpy()
            out.append({"ID": id_, "logits": trim_logits, "labels": labels[i, :seq_len].numpy()})
    return out
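# Shape of one flattened record, for orientation (ID value illustrative):
#   {"ID": "seq_0001",
#    "logits": np.ndarray of shape (seq_len,),   # per-position logits, padding removed
#    "labels": np.ndarray of shape (seq_len,)}   # per-position labels, same length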
@task_wrapper
def predict(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""trains model given checkpoint on a datamodule train set.
This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
failure. Useful for multiruns, saving info about the crash, etc.
:param cfg: DictConfig configuration composed by Hydra.
:return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
"""
# set seed for random number generators in pytorch, numpy and python.random
if cfg.get("seed"):
L.seed_everything(cfg.seed, workers=True)
log.info(f"Instantiating datamodule <{cfg.data_module._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data_module)
log.info(f"Instantiating model <{cfg.model._target_}>")
model: LightningModule = hydra.utils.instantiate(cfg.model)
log.info("Instantiating callbacks...")
callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))
log.info("Instantiating loggers...")
logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(
cfg.trainer, callbacks=callbacks, logger=logger
)
object_dict = {
"cfg": cfg,
"datamodule": datamodule,
"model": model,
"callbacks": callbacks,
"logger": logger,
"trainer": trainer,
}
if logger:
log.info("Logging hyperparameters!")
log_hyperparameters(object_dict)
if cfg.get("test"):
log.info("Starting testing!")
ckpt_path = cfg.ckpt_path
if ckpt_path == "":
log.warning("No ckpt path was passed! Cannot continue")
return
trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
pred_batches = trainer.predict(model, datamodule=datamodule, ckpt_path=ckpt_path, return_predictions=True)
out = flatten_preds(pred_batches)
    # save per-sequence predictions to Hydra's run directory
    output_dir = Path(HydraConfig.get().run.dir)
    save_path = output_dir / "predictions.pkl"
    with open(save_path, "wb") as f:
        pickle.dump(out, f)
    log.info(f"Saved per-sequence predictions to {save_path}")
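    # To inspect the saved predictions later (hypothetical snippet):
    #   with open(output_dir / "predictions.pkl", "rb") as f:
    #       preds = pickle.load(f)
    #   preds[0]["logits"]  # 1-D numpy array of per-position logits for the first ID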
    # recompute loss, AUPRC, and AUROC per sequence - but only when real labels
    # exist, i.e. the user actually passed scores rather than placeholder fake_scores
    if not datamodule.test_dataset.fake_scores:
        for i, d in enumerate(out):
            logits_t = torch.tensor(d["logits"])
            labels_t = torch.tensor(d["labels"])
            bool_mask = torch.zeros(d["labels"].shape, dtype=torch.bool)  # all-False boolean tensor, as in the original calls
            loss = calculate_loss(
                logits_t, labels_t, None, None, alpha=cfg.model.alpha, gamma=cfg.model.gamma
            )
            # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
            ap, n_pos, n_neg, precision, recall, ap_thresholds = auprc_zeros_vs_ones_from_logits(
                logits_t, labels_t, bool_mask, pos_thresh=0.99
            )
            auc, n_pos, n_neg, tpr, fpr, auc_thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
                logits_t, labels_t, bool_mask, pos_thresh=0.99
            )
out[i]["loss"] = loss.item() if loss.numel()>0 else None
out[i]["auprc"] = ap.item() if ap.numel()>0 else None
out[i]["auroc"] = auc.item() if auc.numel()>0 else None
out[i]["n_pos"] = n_pos
out[i]["n_neg"] = n_neg
out[i]["precision"] = precision.numpy() if precision.numel()>0 else None
out[i]["recall"] = recall.numpy() if recall.numel()>0 else None
out[i]["auprc_thresholds"] = ap_thresholds.numpy() if ap_thresholds.numel()>0 else None
out[i]["auc_thresholds"] = auc_thresolds.numpy() if auc_thresolds.numel()>0 else None
out[i]["tpr"] = tpr
out[i]["fpr"] = fpr
# Summary CSV (no big arrays inside)
summary_rows = []
for d in out:
summary_rows.append({
"ID": d["ID"],
"loss": d.get("loss"),
"auprc": d.get("auprc"),
"auroc": d.get("auroc"),
"n_pos": d.get("n_pos"),
"n_neg": d.get("n_neg"),
})
        summary_path = output_dir / "summary.csv"
        pd.DataFrame(summary_rows).to_csv(summary_path, index=False)
        log.info(f"Saved per-sequence metric summary to {summary_path}")
    test_metrics = trainer.callback_metrics
    # collect the metrics computed during trainer.test
    metric_dict = {**test_metrics}
return metric_dict, object_dict
@hydra.main(
version_base="1.3", config_path=str(root / "configs"), config_name="eval.yaml"
)
def main(cfg: DictConfig) -> None:
"""Main entry point for evaluation.
:param cfg: DictConfig configuration composed by Hydra.
"""
# apply extra utilities
# (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
extras(cfg)
    h100_settings()  # enable TF32 settings for faster inference on H100-class GPUs
    # run evaluation / prediction
    metric_dict, _ = predict(cfg)
# safely retrieve metric value for hydra-based hyperparameter optimization
metric_value = get_metric_value(
metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
)
# return optimized metric
return metric_value
if __name__ == "__main__":
main()
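# Example multirun for Hydra-based hyperparameter optimization (override values
# illustrative; `optimized_metric` must name an entry in trainer.callback_metrics):
#   python scripts/eval.py --multirun ckpt_path=/path/to/model.ckpt optimized_metric=<metric name>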