from typing import Any, Dict, List, Optional, Tuple

import hydra
import lightning as L
import rootutils
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

# Locate the project root (marked by a `.project-root` file) and, via
# `pythonpath=True`, make it importable. This must run before the local
# `dpacman` import below, which is why that import is not at the top.
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from dpacman.utils import (  # noqa: E402  (depends on the path setup above)
    RankedLogger,
    extras,
    get_metric_value,
    instantiate_callbacks,
    instantiate_loggers,
    log_hyperparameters,
    task_wrapper,
)

# `rank_zero_only=True`: presumably emits only from rank 0 in distributed runs
# (behaviour defined in dpacman.utils.RankedLogger — confirm there).
log = RankedLogger(__name__, rank_zero_only=True)


def h100_settings() -> None:
    """Enable TensorFloat-32 (TF32) math for float32 CUDA matmuls/convolutions.

    Trades a small amount of float32 accuracy for a large speedup on GPUs
    with TF32 support (e.g. H100).
    """
    # Older PyTorch toggles. In some PyTorch versions, assigning
    # `matmul.allow_tf32` also resets the float32 matmul precision, so these
    # are set first and the precision call below has the final say.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # "high" enables TF32 for float32 matmuls; "medium" is even faster.
    torch.set_float32_matmul_precision("high")


@task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Train a model on a datamodule and optionally test it.

    Builds the datamodule, model, callbacks, loggers and trainer from ``cfg``.
    It then runs ``trainer.fit`` when ``cfg.train`` is truthy and
    ``trainer.test`` when ``cfg.test`` is truthy. Testing uses the best
    checkpoint recorded during fitting if there is one, and the current
    weights otherwise.

    This method is wrapped in the optional ``@task_wrapper`` decorator, which
    controls the behavior during failure. That is useful for multiruns,
    saving info about the crash, etc.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Tuple of (metrics dict, dict with all instantiated objects).
        The metrics dict merges the train-stage and test-stage
        ``trainer.callback_metrics``. Test values win on key collisions. A
        stage that did not run contributes nothing.
    """
    # Seed the RNGs for pytorch, numpy and python.random (and dataloader
    # workers). Compare against None so that an explicit `seed: 0` is honoured.
    if cfg.get("seed") is not None:
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data_module._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data_module)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    # Pre-initialise both metric dicts so the merge below never hits an
    # unbound name when a stage is disabled in the config.
    train_metrics: Dict[str, Any] = {}
    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
        # Snapshot: `callback_metrics` is the trainer's own dict and may be
        # updated in place by the test stage.
        train_metrics = dict(trainer.callback_metrics)
        log.info("Training completed! Ready for testing.")

    test_metrics: Dict[str, Any] = {}
    if cfg.get("test"):
        log.info("Starting testing!")
        # `checkpoint_callback` is None when no ModelCheckpoint is configured.
        checkpoint_callback = trainer.checkpoint_callback
        ckpt_path = (
            checkpoint_callback.best_model_path
            if checkpoint_callback is not None
            else ""
        )
        if not ckpt_path:
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")
        test_metrics = dict(trainer.callback_metrics)

    # Merge train and test metrics; test values override on key collisions.
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict


@hydra.main(
    version_base="1.3", config_path=str(root / "configs"), config_name="train.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    """Main entry point for training.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: The value of ``cfg.optimized_metric`` as looked up by
        ``get_metric_value``. It is returned so that Hydra-based
        hyperparameter sweepers can optimise it. Presumably this is a float,
        or None when no metric is configured — confirm in dpacman.utils.
    """
    # Apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)
    h100_settings()  # try using settings for faster h100s training

    # Train (and optionally test) the model.
    metric_dict, _ = train(cfg)

    # Safely retrieve the metric value for hydra-based hyperparameter optimization.
    metric_value = get_metric_value(
        metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
    )

    # Return the optimized metric.
    return metric_value


if __name__ == "__main__":
    main()