| """ |
| Script for using the model just for inference. |
| """ |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import hydra |
| from hydra.core.hydra_config import HydraConfig |
| import torch |
| import rootutils |
| import lightning as L |
| from lightning import Callback, LightningDataModule, LightningModule, Trainer |
| from lightning.pytorch.loggers import Logger |
| from omegaconf import DictConfig |
| from pathlib import Path |
| import pandas as pd |
import pickle
|
|
| root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) |
|
|
from dpacman.classifier.loss import (
    auprc_zeros_vs_ones_from_logits,
    auroc_zeros_vs_ones_from_logits,
    calculate_loss,
)
from dpacman.utils import (
    RankedLogger,
    extras,
    get_metric_value,
    instantiate_callbacks,
    instantiate_loggers,
    log_hyperparameters,
    task_wrapper,
)
|
|
| log = RankedLogger(__name__, rank_zero_only=True) |
|
|
|
|
def h100_settings():
    """Enable performance settings suited to recent NVIDIA GPUs (e.g. H100/A100)."""
    # Trade a little float32 matmul precision for speed on tensor cores.
    torch.set_float32_matmul_precision("high")

    # Allow TF32 for CUDA matmuls and cuDNN ops (faster, slightly lower precision).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
|
|
def flatten_preds(pred_batches):
    """Flatten padded prediction batches into one record per sample.

    Each batch produced by the model's predict step is expected to contain:
        "ids":    list of sample identifiers
        "logits": (B, Lmax) padded logits
        "valid":  (B, Lmax) boolean mask of valid positions
        "labels": (B, Lmax) padded labels
    Padding is trimmed per sample using the valid mask.
    """
    out = []
    for b in pred_batches:
        ids, logits, valid, labels = b["ids"], b["logits"], b["valid"], b["labels"]
        for i, id_ in enumerate(ids):
            L = int(valid[i].sum().item())
            out.append(
                {
                    "ID": id_,
                    "logits": logits[i, :L].numpy(),
                    "labels": labels[i, :L].numpy(),
                }
            )
    return out
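

# Illustrative sketch of the structure flatten_preds expects and produces. The IDs,
# shapes, and values below are assumptions for demonstration only, not outputs of this
# repository:
#
#   batch = {
#       "ids": ["seq_a", "seq_b"],
#       "logits": torch.zeros(2, 8),
#       "valid": torch.tensor([[True] * 5 + [False] * 3, [True] * 8]),
#       "labels": torch.zeros(2, 8),
#   }
#   records = flatten_preds([batch])
#   # records[0]["logits"].shape == (5,); records[1]["logits"].shape == (8,)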
|
|
| @task_wrapper |
| def predict(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: |
| """trains model given checkpoint on a datamodule train set. |
| |
    This method is wrapped in the optional ``@task_wrapper`` decorator, which controls
    the behavior during failure. Useful for multiruns, saving info about the crash, etc.
| |
| :param cfg: DictConfig configuration composed by Hydra. |
| :return: Tuple[dict, dict] with metrics and dict with all instantiated objects. |
| """ |
| |
| if cfg.get("seed"): |
| L.seed_everything(cfg.seed, workers=True) |
|
|
| log.info(f"Instantiating datamodule <{cfg.data_module._target_}>") |
| datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data_module) |
|
|
| log.info(f"Instantiating model <{cfg.model._target_}>") |
| model: LightningModule = hydra.utils.instantiate(cfg.model) |
|
|
| log.info("Instantiating callbacks...") |
| callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) |
|
|
| log.info("Instantiating loggers...") |
| logger: List[Logger] = instantiate_loggers(cfg.get("logger")) |
|
|
| log.info(f"Instantiating trainer <{cfg.trainer._target_}>") |
| trainer: Trainer = hydra.utils.instantiate( |
| cfg.trainer, callbacks=callbacks, logger=logger |
| ) |
|
|
| object_dict = { |
| "cfg": cfg, |
| "datamodule": datamodule, |
| "model": model, |
| "callbacks": callbacks, |
| "logger": logger, |
| "trainer": trainer, |
| } |
|
|
| if logger: |
| log.info("Logging hyperparameters!") |
| log_hyperparameters(object_dict) |
|
|
| if cfg.get("test"): |
| log.info("Starting testing!") |
| ckpt_path = cfg.ckpt_path |
| if ckpt_path == "": |
| log.warning("No ckpt path was passed! Cannot continue") |
| return |
| trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) |
| |
| pred_batches = trainer.predict(model, datamodule=datamodule, ckpt_path=ckpt_path, return_predictions=True) |
| out = flatten_preds(pred_batches) |
| |
| |
| output_dir = Path(HydraConfig.get().run.dir) |
| save_path = output_dir / "predictions.pkl" |
| with open(save_path, "wb") as f: |
| pickle.dump(out, f) |
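
        # Downstream, the saved predictions can be reloaded with the standard pickle
        # API, e.g. (hypothetical usage, not part of this script):
        #   with open(save_path, "rb") as f:
        #       records = pickle.load(f)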
| |
| |
| |
        # Per-sample metrics are only meaningful when the test set carries real labels.
        if not datamodule.test_dataset.fake_scores:
            for d in out:
                logits_t = torch.tensor(d["logits"])
                labels_t = torch.tensor(d["labels"])
                mask = torch.zeros(labels_t.shape, dtype=torch.bool)

                loss = calculate_loss(
                    logits_t, labels_t, None, None, alpha=cfg.model.alpha, gamma=cfg.model.gamma
                )
                ap, n_pos, n_neg, precision, recall, ap_thresholds = auprc_zeros_vs_ones_from_logits(
                    logits_t, labels_t, mask, pos_thresh=0.99
                )
                auc, n_pos, n_neg, tpr, fpr, auc_thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
                    logits_t, labels_t, mask, pos_thresh=0.99
                )

                d["loss"] = loss.item() if loss.numel() > 0 else None
                d["auprc"] = ap.item() if ap.numel() > 0 else None
                d["auroc"] = auc.item() if auc.numel() > 0 else None
                d["n_pos"] = n_pos
                d["n_neg"] = n_neg
                d["precision"] = precision.numpy() if precision.numel() > 0 else None
                d["recall"] = recall.numpy() if recall.numel() > 0 else None
                d["auprc_thresholds"] = ap_thresholds.numpy() if ap_thresholds.numel() > 0 else None
                d["auc_thresholds"] = auc_thresholds.numpy() if auc_thresholds.numel() > 0 else None
                d["tpr"] = tpr
                d["fpr"] = fpr

        summary_rows = []
        for d in out:
            summary_rows.append(
                {
                    "ID": d["ID"],
                    "loss": d.get("loss"),
                    "auprc": d.get("auprc"),
                    "auroc": d.get("auroc"),
                    "n_pos": d.get("n_pos"),
                    "n_neg": d.get("n_neg"),
                }
            )
        summary_path = output_dir / "summary.csv"
        pd.DataFrame(summary_rows).to_csv(summary_path, index=False)

        log.info(f"Saved predictions to {save_path} and per-sample summary to {summary_path}")
| |
| test_metrics = trainer.callback_metrics |
|
|
| |
| metric_dict = {**test_metrics} |
|
|
| return metric_dict, object_dict |
|
|
|
|
| @hydra.main( |
| version_base="1.3", config_path=str(root / "configs"), config_name="eval.yaml" |
| ) |
def main(cfg: DictConfig) -> Optional[float]:
    """Main entry point for evaluation.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Optional[float] with the optimized metric value.
    """
| |
| |
| extras(cfg) |
|
|
| h100_settings() |
|
|
| |
| metric_dict, _ = predict(cfg) |
|
|
| |
    # safely retrieve the metric value used for hydra-based hyperparameter optimization
    metric_value = get_metric_value(
        metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
    )
|
|
| |
| return metric_value |
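

# Example invocation (hypothetical; the script filename and override keys are
# assumptions based on the config defaults above, adjust to the actual setup):
#   python eval.py ckpt_path=/path/to/checkpoint.ckpt test=True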
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|