"""
Script for using the model just for inference.
"""

from pathlib import Path
import pickle
from typing import Any, Dict, List, Optional, Tuple

import hydra
from hydra.core.hydra_config import HydraConfig
import lightning as L
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
import pandas as pd
import rootutils
import torch

root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# Project imports are placed after ``setup_root`` (called with
# ``pythonpath=True``) so that they resolve even when the project root is not
# already on ``sys.path``.
from dpacman.classifier.loss import (
    calculate_loss,
    auprc_zeros_vs_ones_from_logits,
    auroc_zeros_vs_ones_from_logits,
)
from dpacman.utils import (
    RankedLogger,
    extras,
    get_metric_value,
    instantiate_callbacks,
    instantiate_loggers,
    log_hyperparameters,
    task_wrapper,
)

log = RankedLogger(__name__, rank_zero_only=True)


def h100_settings():
    """Enable TensorFloat-32 for float32 matmuls and cuDNN operations."""
    # Use TensorFloat-32 for float32 matmuls -> big speedup with tiny accuracy
    # tradeoff ("medium" would trade more accuracy for even more speed).
    torch.set_float32_matmul_precision("high")

    # (optional; older PyTorch toggles)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True


def flatten_preds(pred_batches):
    """Flatten batched predictions into one record per example.

    Each element of ``pred_batches`` is indexed with the keys ``"ids"``,
    ``"logits"``, ``"valid"`` and ``"labels"``. For example ``i`` the number of
    kept positions is ``valid[i].sum()``, and the first that many entries of
    ``logits[i]`` and ``labels[i]`` are kept.

    NOTE(review): this assumes the valid positions form a contiguous prefix of
    each row (i.e. right-padding) and that the tensors are already on CPU
    (``.numpy()`` is called directly) -- confirm against the model's
    ``predict_step``.

    :param pred_batches: Iterable of per-batch prediction dicts, or ``None``.
    :return: List of dicts with keys ``"ID"``, ``"logits"`` and ``"labels"``;
        empty when ``pred_batches`` is ``None`` or empty.
    """
    if not pred_batches:
        return []

    records = []
    for batch in pred_batches:
        ids = batch["ids"]
        logits = batch["logits"]
        valid = batch["valid"]
        labels = batch["labels"]
        for i, id_ in enumerate(ids):
            n_valid = int(valid[i].sum().item())  # strip padding
            records.append(
                {
                    "ID": id_,
                    "logits": logits[i, :n_valid].numpy(),
                    "labels": labels[i, :n_valid].numpy(),
                }
            )
    return records


def _attach_metrics(record, alpha, gamma):
    """Compute loss, AUPRC and AUROC for one prediction record, in place.

    Reads ``record["logits"]`` and ``record["labels"]`` (NumPy arrays) and adds
    the keys ``loss``, ``auprc``, ``auroc``, ``n_pos``, ``n_neg``,
    ``precision``, ``recall``, ``auprc_thresholds``, ``auc_thresholds``,
    ``tpr`` and ``fpr``. Tensor-valued results with zero elements are stored
    as ``None``.

    :param record: Prediction record as produced by ``flatten_preds``.
    :param alpha: Forwarded to ``calculate_loss`` as ``alpha``.
    :param gamma: Forwarded to ``calculate_loss`` as ``gamma``.
    """
    logits = torch.tensor(record["logits"])
    labels = torch.tensor(record["labels"])
    mask_shape = record["labels"].shape

    # The two ``None`` positionals are passed through unchanged; their meaning
    # is defined by ``dpacman.classifier.loss.calculate_loss``.
    loss = calculate_loss(logits, labels, None, None, alpha=alpha, gamma=gamma)

    # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
    # An all-False boolean tensor is passed as the third argument; a fresh one
    # is built per call, as in the original code.
    ap, _, _, precision, recall, ap_thresholds = auprc_zeros_vs_ones_from_logits(
        logits,
        labels,
        torch.zeros(mask_shape, dtype=torch.bool),
        pos_thresh=0.99,
    )
    auc, n_pos, n_neg, tpr, fpr, auc_thresholds, _tp, _fp = (
        auroc_zeros_vs_ones_from_logits(
            logits,
            labels,
            torch.zeros(mask_shape, dtype=torch.bool),
            pos_thresh=0.99,
        )
    )

    record["loss"] = loss.item() if loss.numel() > 0 else None
    record["auprc"] = ap.item() if ap.numel() > 0 else None
    record["auroc"] = auc.item() if auc.numel() > 0 else None
    # The positive/negative counts stored are the ones from the AUROC call.
    record["n_pos"] = n_pos
    record["n_neg"] = n_neg
    record["precision"] = precision.numpy() if precision.numel() > 0 else None
    record["recall"] = recall.numpy() if recall.numel() > 0 else None
    record["auprc_thresholds"] = (
        ap_thresholds.numpy() if ap_thresholds.numel() > 0 else None
    )
    record["auc_thresholds"] = (
        auc_thresholds.numpy() if auc_thresholds.numel() > 0 else None
    )
    record["tpr"] = tpr
    record["fpr"] = fpr


@task_wrapper
def predict(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Run prediction (and optionally testing) from a checkpoint.

    Instantiates the datamodule, model, callbacks, loggers and trainer from
    ``cfg``; optionally runs ``trainer.test`` when ``cfg.test`` is truthy; then
    runs ``trainer.predict`` and writes ``predictions.pkl`` and ``summary.csv``
    into the Hydra run directory. Per-example loss/AUPRC/AUROC are computed
    only when ``datamodule.test_dataset.fake_scores`` is falsy.

    This method is wrapped in optional @task_wrapper decorator, that controls
    the behavior during failure. Useful for multiruns, saving info about the
    crash, etc.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Tuple[dict, dict] with metrics and dict with all instantiated
        objects. The metrics dict is empty when ``cfg.ckpt_path`` is ``""``.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    # (compare against None so that a seed of 0 is honoured)
    seed = cfg.get("seed")
    if seed is not None:
        L.seed_everything(seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data_module._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data_module)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    # Resolve the checkpoint once, up front: both the optional test stage and
    # the prediction stage need it, so it must be bound even when testing is
    # disabled.
    ckpt_path = cfg.ckpt_path
    if ckpt_path == "":
        log.warning("No ckpt path was passed! Cannot continue")
        # Keep the declared (metrics, objects) return shape so callers that
        # unpack the result do not fail.
        return {}, object_dict

    if cfg.get("test"):
        log.info("Starting testing!")
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    pred_batches = trainer.predict(
        model, datamodule=datamodule, ckpt_path=ckpt_path, return_predictions=True
    )
    out = flatten_preds(pred_batches)

    # make output dir
    output_dir = Path(HydraConfig.get().run.dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the pickle is written before the metrics below are attached,
    # so it holds only ID/logits/labels; the curve arrays computed afterwards
    # are not persisted anywhere. Confirm whether that is intended.
    save_path = output_dir / "predictions.pkl"
    with open(save_path, "wb") as f:
        pickle.dump(out, f)

    # iterate through out and recalculate AUC, AUPRC, loss - only if there are
    # labels, i.e. only if the user actually passed scores; otherwise don't
    # bother
    if not datamodule.test_dataset.fake_scores:
        for record in out:
            _attach_metrics(record, alpha=cfg.model.alpha, gamma=cfg.model.gamma)

    # Summary CSV (no big arrays inside); ``.get`` yields None for metrics that
    # were not computed.
    summary_columns = ("ID", "loss", "auprc", "auroc", "n_pos", "n_neg")
    summary_rows = [
        {
            "ID": d["ID"],
            **{column: d.get(column) for column in summary_columns[1:]},
        }
        for d in out
    ]
    save_path = output_dir / "summary.csv"
    pd.DataFrame(summary_rows).to_csv(save_path, index=False)  # save it
    log.info(f"Saved eval/predict results to {save_path}")

    test_metrics = trainer.callback_metrics

    # copy the trainer's callback metrics into a plain dict
    metric_dict = {**test_metrics}

    return metric_dict, object_dict


@hydra.main(
    version_base="1.3", config_path=str(root / "configs"), config_name="eval.yaml"
)
def main(cfg: DictConfig) -> None:
    """Main entry point for evaluation.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    h100_settings()  # enable TF32 settings for faster matmuls

    # run testing/prediction
    metric_dict, _ = predict(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    metric_value = get_metric_value(
        metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
    )

    # return optimized metric
    return metric_value


if __name__ == "__main__":
    main()