from collections import defaultdict
from collections import deque
import datetime

import torch


class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    Attributes:
        deque: the most recent ``window_size`` values (sliding window).
        series: every value ever passed to :meth:`update` (grows unbounded).
        total: running sum of all values.
        count: number of values seen so far.
    """

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.series = []
        self.total = 0.0
        self.count = 0

    def update(self, value):
        """Record one new scalar ``value``."""
        self.deque.append(value)
        self.series.append(value)
        self.count += 1
        self.total += value

    @property
    def median(self):
        """Median of the current window, or NaN when the window is empty.

        Computed with ``torch.median``, which returns the lower of the two
        middle elements for an even number of values.
        """
        if not self.deque:
            return float("nan")
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the current window, or NaN when the window is empty."""
        if not self.deque:
            return float("nan")
        window = torch.tensor(list(self.deque))
        if not window.is_floating_point():
            # torch.mean rejects integer tensors; a window holding only ints
            # must be promoted to a floating dtype first.
            window = window.to(torch.get_default_dtype())
        return window.mean().item()

    @property
    def global_avg(self):
        """Mean over every value seen so far, or NaN before any update."""
        if self.count == 0:
            return float("nan")
        return self.total / self.count


class Stats:
    """Collect per-split training statistics and render them as text or
    TensorBoard scalars.

    Args:
        num_steps: total number of steps. When ``None`` it is computed as
            ``num_epochs * steps_per_epoch``.
        num_epochs: number of epochs (used only if ``num_steps`` is ``None``).
        steps_per_epoch: steps per epoch (used only if ``num_steps`` is
            ``None``).
        stats_to_print: optional mapping ``split -> iterable of stat names``
            selecting which statistics are reported for each split.

    Raises:
        TypeError: if ``num_steps`` is ``None`` and ``num_epochs`` or
            ``steps_per_epoch`` is missing.
    """

    def __init__(self, num_steps=None, num_epochs=None, steps_per_epoch=None,
                 stats_to_print=None):
        self.step = self.epoch = 0

        if num_steps is not None:
            self.num_steps = num_steps
        else:
            if num_epochs is None or steps_per_epoch is None:
                raise TypeError(
                    "Stats requires either num_steps or both num_epochs "
                    "and steps_per_epoch"
                )
            self.num_steps = num_epochs * steps_per_epoch

        self.stats = {
            "train": defaultdict(SmoothedValue),
        }
        # stats_to_print defaults to None, so treat that as "nothing selected".
        self.stats_to_print = {
            k: set(v) for k, v in (stats_to_print or {}).items()
        }

    def to_dict(self):
        """Return the instance state.

        Note: this is the live ``__dict__``, not a copy; mutating the result
        mutates this object.
        """
        return self.__dict__

    def load_dict(self, dict):
        """Restore state by setting every ``key -> value`` pair of ``dict``
        as an attribute on this instance."""
        for key, val in dict.items():
            setattr(self, key, val)

    def update(self, split, step, epoch, dict):
        """Record the current position and a batch of scalar statistics.

        Args:
            split: name of the data split (e.g. ``"train"``). Splits are
                created on first use.
            step: current global step.
            epoch: current epoch (zero-based; reported as ``epoch + 1``).
            dict: mapping ``stat name -> value``; values are Python numbers
                or single-element tensors.
        """
        self.step = step
        self.epoch = epoch
        # Only "train" exists up front; create other splits lazily so that
        # e.g. validation statistics can be recorded as well.
        split_stats = self.stats.setdefault(split, defaultdict(SmoothedValue))
        for k, v in dict.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int)), (
                "stat {!r} must be a number, got {}".format(k, type(v).__name__)
            )
            split_stats[k].update(v)

    def update_stats_to_print(self, split, stats_to_print):
        """Add the names in ``stats_to_print`` to those reported for ``split``."""
        self.stats_to_print.setdefault(split, set()).update(stats_to_print)

    def _tracker(self, split, stat):
        """Return the SmoothedValue for ``split``/``stat`` if it holds data,
        else ``None`` (never creates an entry)."""
        tracker = self.stats.get(split, {}).get(stat)
        if tracker is None or not tracker.deque:
            return None
        return tracker

    def _selected_stats(self, split):
        """Names selected for ``split``, sorted for a stable output order."""
        return sorted(self.stats_to_print.get(split, ()))

    def get_summary(self, split):
        """Build a one-line, human-readable summary for ``split``.

        For ``"train"`` the line starts with progress and an ETA derived from
        the global average of the ``"time"`` statistic; for any other split
        it starts with a validation header. Each selected statistic is
        rendered as its windowed median, or ``n/a`` if it has no data yet.
        """
        if split == "train":
            if self.num_steps:
                completion_pct = self.step / self.num_steps * 100
            else:
                completion_pct = 0.0  # avoid dividing by a zero step budget

            time_tracker = self._tracker(split, "time")
            if time_tracker is None:
                # No timing recorded yet, so no ETA can be estimated.
                eta_string = "N/A"
            else:
                eta_seconds = time_tracker.global_avg * (self.num_steps - self.step)
                eta_string = datetime.timedelta(seconds=int(eta_seconds))
            s = "[{}/{}, {:.1f}%] eta: {}, ".format(
                self.step, self.num_steps, completion_pct, eta_string)
        else:
            s = f"[Validation, epoch {self.epoch + 1}] "

        parts = []
        for stat in self._selected_stats(split):
            tracker = self._tracker(split, stat)
            if tracker is None:
                parts.append(f"{stat}: n/a")
            else:
                parts.append(f"{stat}: {tracker.median:.5f}")
        return s + ", ".join(parts)

    def write_tensorboard(self, summary_writer, split):
        """Log the epoch and the windowed median of each selected statistic
        for ``split`` via ``summary_writer.add_scalar`` at the current step.

        Statistics without recorded data are skipped.
        """
        summary_writer.add_scalar(f"{split}/epoch", self.epoch + 1, self.step)
        for stat in self._selected_stats(split):
            tracker = self._tracker(split, stat)
            if tracker is None:
                continue
            summary_writer.add_scalar(f"{split}/{stat}", tracker.median, self.step)

    def is_best(self):
        """Always returns ``True``."""
        return True