import logging
import time
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np
import torch

from se3_transformer.runtime.loggers import Logger
from se3_transformer.runtime.metrics import MeanAbsoluteError


class BaseCallback(ABC):
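    """No-op lifecycle hooks invoked by the training loop; subclasses override only the ones they need."""
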
    def on_fit_start(self, optimizer, args):
        pass

    def on_fit_end(self):
        pass

    def on_epoch_end(self):
        pass

    def on_batch_start(self):
        pass

    def on_validation_step(self, input, target, pred):
        pass

    def on_validation_end(self, epoch=None):
        pass

    def on_checkpoint_load(self, checkpoint):
        pass

    def on_checkpoint_save(self, checkpoint):
        pass


class LRSchedulerCallback(BaseCallback):
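    """Base class for LR scheduler callbacks: builds the scheduler at fit start, steps it
    once per epoch, and saves/restores its state through checkpoints."""
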
    def __init__(self, logger: Optional[Logger] = None):
        self.logger = logger
        self.scheduler = None

    @abstractmethod
    def get_scheduler(self, optimizer, args):
        pass

    def on_fit_start(self, optimizer, args):
        self.scheduler = self.get_scheduler(optimizer, args)

    def on_checkpoint_load(self, checkpoint):
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    def on_checkpoint_save(self, checkpoint):
        checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

    def on_epoch_end(self):
        if self.logger is not None:
            self.logger.log_metrics({'learning rate': self.scheduler.get_last_lr()[0]}, step=self.scheduler.last_epoch)
        self.scheduler.step()


class QM9MetricCallback(BaseCallback):
    """Logs the rescaled mean absolute error for QM9 regression tasks."""
    def __init__(self, logger, targets_std, prefix=''):
        self.mae = MeanAbsoluteError()
        self.logger = logger
        self.targets_std = targets_std
        self.prefix = prefix
        self.best_mae = float('inf')

    def on_validation_step(self, input, target, pred):
        self.mae(pred.detach(), target.detach())

    def on_validation_end(self, epoch=None):
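        # Rescale by the targets' standard deviation so the MAE is reported in the original units.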
        mae = self.mae.compute() * self.targets_std
        logging.info(f'{self.prefix} MAE: {mae}')
        self.logger.log_metrics({f'{self.prefix} MAE': mae}, epoch)
        self.best_mae = min(self.best_mae, mae)

    def on_fit_end(self):
        if self.best_mae != float('inf'):
            self.logger.log_metrics({f'{self.prefix} best MAE': self.best_mae})


class QM9LRSchedulerCallback(LRSchedulerCallback):
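    """Cosine annealing over the whole run: with T_0 equal to the total number of epochs,
    the schedule decays from the base LR to min_lr without a mid-training restart."""
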
    def __init__(self, logger, epochs):
        super().__init__(logger)
        self.epochs = epochs

    def get_scheduler(self, optimizer, args):
        min_lr = args.min_learning_rate if args.min_learning_rate else args.learning_rate / 10.0
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, self.epochs, eta_min=min_lr)


class PerformanceCallback(BaseCallback):
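    """Collects per-batch timestamps after a warmup period and logs throughput and latency statistics."""
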
    def __init__(self, logger, batch_size: int, warmup_epochs: int = 1, mode: str = 'train'):
        self.batch_size = batch_size
        self.warmup_epochs = warmup_epochs
        self.epoch = 0
        self.timestamps = []
        self.mode = mode
        self.logger = logger

    def on_batch_start(self):
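        # Record a millisecond timestamp at each batch start once warmup has passed; the deltas
        # between consecutive timestamps are the per-batch latencies used in the stats below.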
        if self.epoch >= self.warmup_epochs:
            self.timestamps.append(time.time() * 1000.0)

    def _log_perf(self):
        stats = self.process_performance_stats()
        for k, v in stats.items():
            logging.info(f'performance {k}: {v}')
        self.logger.log_metrics(stats)

    def on_epoch_end(self):
        self.epoch += 1

    def on_fit_end(self):
        if self.epoch > self.warmup_epochs:
            self._log_perf()
            self.timestamps = []

    def process_performance_stats(self):
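        # Timestamps are in milliseconds, so the latency figures are ms per batch and
        # throughput is batch_size / latency, i.e. samples per millisecond.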
        timestamps = np.asarray(self.timestamps)
        deltas = np.diff(timestamps)
        throughput = (self.batch_size / deltas).mean()
        stats = {
            f"throughput_{self.mode}": throughput,
            f"latency_{self.mode}_mean": deltas.mean(),
            f"total_time_{self.mode}": timestamps[-1] - timestamps[0],
        }
        for level in [90, 95, 99]:
            stats.update({f"latency_{self.mode}_{level}": np.percentile(deltas, level)})

        return stats
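
# Minimal usage sketch (illustrative only, not part of the library): a hypothetical training
# loop would construct the callbacks once and invoke the matching hooks. `logger`, `optimizer`,
# `args`, `targets_std`, and `train_loader` are assumed to be provided by the caller.
#
#     callbacks = [
#         QM9MetricCallback(logger, targets_std=targets_std, prefix='validation'),
#         QM9LRSchedulerCallback(logger, epochs=args.epochs),
#         PerformanceCallback(logger, batch_size=args.batch_size, warmup_epochs=1),
#     ]
#     for cb in callbacks:
#         cb.on_fit_start(optimizer, args)
#     for epoch in range(args.epochs):
#         for batch in train_loader:
#             for cb in callbacks:
#                 cb.on_batch_start()
#             ...  # forward pass, loss, backward, optimizer step
#         for cb in callbacks:
#             cb.on_epoch_end()
#     for cb in callbacks:
#         cb.on_fit_end()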