"""Miscellaneous utilities: fsspec filesystem helpers, NaN debugging,
a step-driven cosine LR scheduler wrapper, and a rank-zero console logger.

Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py
Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
"""
|
|
| import logging |
|
|
| import fsspec |
| import lightning |
| import torch |
| from timm.scheduler import CosineLRScheduler |
|
|
|
|
def fsspec_exists(filename):
    """Report whether ``filename`` exists, using fsspec to pick the filesystem.

    The filesystem object is resolved from ``filename`` with
    ``fsspec.core.url_to_fs``; the original, unmodified ``filename`` is then
    passed to that filesystem's ``exists`` method.
    """
    filesystem = fsspec.core.url_to_fs(filename)[0]
    return filesystem.exists(filename)
|
|
|
|
def fsspec_listdir(dirname):
    """List the contents of ``dirname`` through fsspec.

    The filesystem is resolved from ``dirname`` with
    ``fsspec.core.url_to_fs``. The result is whatever that filesystem's
    ``ls(dirname)`` returns when called with its default arguments.
    """
    filesystem = fsspec.core.url_to_fs(dirname)[0]
    entries = filesystem.ls(dirname)
    return entries
|
|
|
|
def fsspec_mkdirs(dirname, exist_ok=True):
    """Create directory ``dirname`` through fsspec's ``makedirs``.

    Args:
        dirname: Directory path/URL. The filesystem is resolved from it with
            ``fsspec.core.url_to_fs``.
        exist_ok: Forwarded unchanged to the filesystem's ``makedirs``.
    """
    filesystem = fsspec.core.url_to_fs(dirname)[0]
    filesystem.makedirs(dirname, exist_ok=exist_ok)
|
|
|
|
def print_nans(tensor, name):
    """Debug helper: print ``name`` and ``tensor`` if ``tensor`` has any NaN.

    Nothing is printed when the tensor contains no NaN values.
    """
    has_nan = torch.isnan(tensor).any()
    if not has_nan:
        return
    print(name, tensor)
|
|
|
|
class CosineDecayWarmupLRScheduler(
        CosineLRScheduler,
        torch.optim.lr_scheduler._LRScheduler):
    """Thin wrapper around ``timm.scheduler.CosineLRScheduler``.

    It allows ``scheduler.step()`` to be called without an ``epoch``
    argument: an internal counter (``_last_epoch``) is advanced instead.
    Passing an explicit ``epoch`` sets the counter to that value, which is
    what makes resuming possible.

    Adapted from:
    https://github.com/HazyResearch/hyena-dna/blob/main/src/utils/optim/schedulers.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Start one position before the first step, then apply step 0
        # right away.
        self._last_epoch = -1
        self.step(epoch=0)

    def step(self, epoch=None):
        """Advance (or set) the internal counter and forward it to timm.

        Args:
            epoch: ``None`` increments the counter by one; any other value
                replaces the counter.
        """
        self._last_epoch = self._last_epoch + 1 if epoch is None else epoch
        # timm exposes a per-epoch hook (``step``) and a per-update hook
        # (``step_update``). ``t_in_epochs`` selects which one receives the
        # counter, so a bare ``step()`` drives either kind of schedule.
        if not self.t_in_epochs:
            super().step_update(num_updates=self._last_epoch)
            return
        super().step(epoch=self._last_epoch)
|
|
|
|
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    Fetches the standard-library logger called ``name``, sets its level, and
    replaces its ``debug``/``info``/``warning``/``error``/``exception``/
    ``fatal``/``critical`` methods with versions wrapped by Lightning's
    ``rank_zero_only``. Per Lightning's docs, that wrapper lets the call run
    only on global rank 0.

    The wrapping is applied at most once per logger object, so calling this
    repeatedly for the same ``name`` does not stack extra wrapper layers.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        level: Level applied with ``Logger.setLevel``, on every call.

    Returns:
        The ``logging.Logger`` returned by ``logging.getLogger(name)``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # ``logging.getLogger`` returns the same object for the same name, and
    # the wrappers are stored as instance attributes that shadow the class
    # methods. Without this guard, each call would wrap the already-wrapped
    # methods again.
    if getattr(logger, '_rank_zero_only_wrapped', False):
        return logger

    # Use a distinct loop variable so the ``level`` parameter is not
    # shadowed.
    for method_name in ('debug', 'info', 'warning', 'error',
                        'exception', 'fatal', 'critical'):
        setattr(logger,
                method_name,
                lightning.pytorch.utilities.rank_zero_only(
                    getattr(logger, method_name)))
    logger._rank_zero_only_wrapped = True

    return logger
|
|