from collections import defaultdict, deque
import datetime

import torch
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)  # last `window_size` values for windowed stats
        self.series = []  # full history of values
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.deque.append(value)
        self.series.append(value)
        self.count += 1
        self.total += value

    # Exposed as properties because Stats.get_summary/write_tensorboard read
    # them as plain attributes (e.g. `.median`, `.global_avg`).
    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        # Cast to float so the mean is defined even when integer values are tracked.
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count
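

# Illustrative usage (a sketch, not part of the module): track a per-step loss
# and read back the windowed median/average or the global average.
#
#   loss_meter = SmoothedValue(window_size=20)
#   for loss in (0.93, 0.87, 0.81):
#       loss_meter.update(loss)
#   loss_meter.median, loss_meter.avg, loss_meter.global_avg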
class Stats:
    def __init__(self, num_steps=None, num_epochs=None, steps_per_epoch=None, stats_to_print=None):
        self.step = self.epoch = 0
        # Either take the total number of steps directly, or derive it from
        # the number of epochs and the steps per epoch.
        if num_steps is not None:
            self.num_steps = num_steps
        else:
            self.num_steps = num_epochs * steps_per_epoch
        # One SmoothedValue per tracked metric, per split. A "val" entry is
        # included as well, since update/get_summary also handle a validation split.
        self.stats = {
            "train": defaultdict(SmoothedValue),
            "val": defaultdict(SmoothedValue),
        }
        self.stats_to_print = {k: set(v) for k, v in stats_to_print.items()}

    def to_dict(self):
        return self.__dict__

    def load_dict(self, state_dict):
        # Restore attributes previously exported with to_dict (e.g. from a checkpoint).
        for key, val in state_dict.items():
            setattr(self, key, val)
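
    # Illustrative round trip (a sketch): to_dict / load_dict let the counters
    # be saved and restored alongside a model checkpoint, e.g.
    #
    #   payload = {"stats": stats.to_dict()}                      # save
    #   restored = Stats(num_steps=1, stats_to_print={"train": []})
    #   restored.load_dict(payload["stats"])                      # restore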
    def update(self, split, step, epoch, stats_dict):
        self.step = step
        self.epoch = epoch
        for k, v in stats_dict.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.stats[split][k].update(v)

    def update_stats_to_print(self, split, stats_to_print):
        self.stats_to_print[split].update(stats_to_print)

    def get_summary(self, split):
        if split == "train":
            completion_pct = self.step / self.num_steps * 100
            # ETA = global average time per step * number of remaining steps.
            eta_seconds = self.stats[split].get("time").global_avg * (self.num_steps - self.step)
            eta_string = datetime.timedelta(seconds=int(eta_seconds))
            s = "[{}/{}, {:.1f}%] eta: {}, ".format(self.step, self.num_steps, completion_pct, eta_string)
        else:
            s = f"[Validation, epoch {self.epoch + 1}] "
        return s + ", ".join(f"{stat}: {self.stats[split].get(stat).median:.5f}" for stat in self.stats_to_print[split])
    def write_tensorboard(self, summary_writer, split):
        summary_writer.add_scalar(f"{split}/epoch", self.epoch + 1, self.step)
        for stat in self.stats_to_print[split]:
            summary_writer.add_scalar(f"{split}/{stat}", self.stats[split].get(stat).median, self.step)

    def is_best(self):
        # Placeholder: every call reports the current state as the best so far.
        return True
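

# The demo below is an illustrative sketch, not part of the original module:
# it wires SmoothedValue and Stats into a toy training loop. The metric names
# ("loss", "time") and the tensorboard SummaryWriter are assumptions made for
# the example only.
if __name__ == "__main__":
    import time

    from torch.utils.tensorboard import SummaryWriter

    stats = Stats(
        num_epochs=2,
        steps_per_epoch=3,
        stats_to_print={"train": ["loss", "time"], "val": ["loss"]},
    )
    writer = SummaryWriter(log_dir="runs/demo")

    step = 0
    for epoch in range(2):
        for _ in range(3):
            step += 1
            start = time.time()
            loss = torch.rand(1)  # stand-in for a real training step
            stats.update("train", step, epoch, {"loss": loss, "time": time.time() - start})
            print(stats.get_summary("train"))
            stats.write_tensorboard(writer, "train")
        stats.update("val", step, epoch, {"loss": torch.rand(1)})
        print(stats.get_summary("val"))
        stats.write_tensorboard(writer, "val")

    writer.close()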