"""
Script for using the model just for inference. 
"""
from typing import Any, Dict, List, Tuple

import hydra
from hydra.core.hydra_config import HydraConfig
import torch
import rootutils
import lightning as L
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
from pathlib import Path
import pandas as pd
from dpacman.classifier.loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits
import pickle

root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from dpacman.utils import (
    RankedLogger,
    extras,
    get_metric_value,
    instantiate_callbacks,
    instantiate_loggers,
    log_hyperparameters,
    task_wrapper,
)

log = RankedLogger(__name__, rank_zero_only=True)


def h100_settings():
    """Speed settings for recent NVIDIA GPUs such as the H100."""
    # Use TensorFloat-32 for float32 matmuls: a big speedup for a tiny accuracy tradeoff.
    torch.set_float32_matmul_precision("high")  # "medium" trades more accuracy for speed

    # Lower-level toggles used by older PyTorch releases; the cuDNN flag also
    # enables TF32 for convolutions.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
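
# Note: TF32 only takes effect on Ampere-or-newer GPUs; on older hardware these
# toggles are harmless no-ops.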

def flatten_preds(pred_batches):
    """Flatten padded per-batch predictions into one record per sequence.

    Each prediction batch is expected to be a dict with:
        "ids":    list of sequence IDs (from batch["ID"])
        "logits": (B, Lmax) padded logits
        "valid":  (B, Lmax) boolean mask marking non-padding positions
        "labels": (B, Lmax) padded labels
    """
    out = []
    for b in pred_batches:
        ids, logits, valid, labels = b["ids"], b["logits"], b["valid"], b["labels"]
        for i, id_ in enumerate(ids):
            seq_len = int(valid[i].sum().item())  # number of non-padding positions
            out.append({
                "ID": id_,
                "logits": logits[i, :seq_len].numpy(),  # strip padding
                "labels": labels[i, :seq_len].numpy(),
            })
    return out
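
# For reference: flatten_preds assumes the LightningModule's predict_step returns
# a dict of this shape (a sketch; field names mirror the docstring above, the
# batch key names are assumptions):
#
#   def predict_step(self, batch, batch_idx):
#       logits = self(batch)                          # (B, Lmax)
#       return {
#           "ids": batch["ID"],
#           "logits": logits.detach().cpu(),
#           "valid": batch["valid"].detach().cpu(),   # (B, Lmax) bool mask
#           "labels": batch["labels"].detach().cpu(),
#       }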

@task_wrapper
def predict(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """trains model given checkpoint on a datamodule train set.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data_module._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data_module)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    if cfg.get("test"):
        log.info("Starting testing!")
        ckpt_path = cfg.ckpt_path
        if not ckpt_path:
            log.warning("No ckpt path was passed! Cannot continue.")
            return
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        
        pred_batches = trainer.predict(model, datamodule=datamodule, ckpt_path=ckpt_path, return_predictions=True)
        out = flatten_preds(pred_batches)
        
        # resolve the Hydra run directory and save the raw predictions there
        output_dir = Path(HydraConfig.get().run.dir)
        save_path = output_dir / "predictions.pkl"
        with open(save_path, "wb") as f:
            pickle.dump(out, f)
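        # The raw predictions can be loaded back downstream with, e.g.:
        #   with open(save_path, "rb") as f:
        #       preds = pickle.load(f)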
        
        # Recompute loss, AUPRC, and AUROC per sequence, but only when real labels
        # were provided (skip when the datamodule filled in fake scores).
        if not datamodule.test_dataset.fake_scores:
            for i, d in enumerate(out):
                logits_t = torch.tensor(d["logits"])
                labels_t = torch.tensor(d["labels"])
                # no positions are masked out: padding was already stripped in flatten_preds
                pad_mask = torch.zeros(labels_t.shape, dtype=torch.bool)
                loss = calculate_loss(
                    logits_t, labels_t, None, None, alpha=cfg.model.alpha, gamma=cfg.model.gamma
                )
                # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
                ap, n_pos, n_neg, precision, recall, ap_thresholds = auprc_zeros_vs_ones_from_logits(
                    logits_t, labels_t, pad_mask, pos_thresh=0.99
                )
                auc, n_pos, n_neg, tpr, fpr, auc_thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
                    logits_t, labels_t, pad_mask, pos_thresh=0.99
                )
                out[i]["loss"] = loss.item() if loss.numel() > 0 else None
                out[i]["auprc"] = ap.item() if ap.numel() > 0 else None
                out[i]["auroc"] = auc.item() if auc.numel() > 0 else None
                out[i]["n_pos"] = n_pos
                out[i]["n_neg"] = n_neg
                out[i]["precision"] = precision.numpy() if precision.numel() > 0 else None
                out[i]["recall"] = recall.numpy() if recall.numel() > 0 else None
                out[i]["auprc_thresholds"] = ap_thresholds.numpy() if ap_thresholds.numel() > 0 else None
                out[i]["auc_thresholds"] = auc_thresholds.numpy() if auc_thresholds.numel() > 0 else None
                out[i]["tpr"] = tpr
                out[i]["fpr"] = fpr
            
            # Per-sequence summary CSV (omits the large per-position arrays)
            summary_rows = [
                {
                    "ID": d["ID"],
                    "loss": d.get("loss"),
                    "auprc": d.get("auprc"),
                    "auroc": d.get("auroc"),
                    "n_pos": d.get("n_pos"),
                    "n_neg": d.get("n_neg"),
                }
                for d in out
            ]
            summary_path = output_dir / "summary.csv"
            pd.DataFrame(summary_rows).to_csv(summary_path, index=False)
            log.info(f"Saved eval/predict results to {summary_path}")
            
    test_metrics = trainer.callback_metrics

    # collect the test metrics reported by the trainer
    metric_dict = {**test_metrics}

    return metric_dict, object_dict


@hydra.main(
    version_base="1.3", config_path=str(root / "configs"), config_name="eval.yaml"
)
def main(cfg: DictConfig) -> None:
    """Main entry point for evaluation.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    h100_settings()  # enable TF32 settings for faster inference on recent GPUs (e.g. H100)

    # run evaluation / prediction
    metric_dict, _ = predict(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    metric_value = get_metric_value(
        metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
    )

    # return optimized metric
    return metric_value


if __name__ == "__main__":
    main()