OmniSVG's picture
Upload 80 files
c1ce505 verified
from collections import defaultdict
from collections import deque
import datetime
import torch
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.

    Args:
        window_size: number of most recent values used by ``median`` and ``avg``.
    """

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)  # the most recent ``window_size`` values
        self.series = []  # every value ever added (grows without bound)
        self.total = 0.0  # running sum of all values added
        self.count = 0  # number of values added

    def update(self, value):
        """Record one new value in the window, the full series and the running totals."""
        self.deque.append(value)
        self.series.append(value)
        self.count += 1
        self.total += value

    @property
    def median(self):
        """Median of the values currently in the window, or NaN if the window is empty.

        For an even number of values ``torch.median`` returns the lower of the
        two middle values rather than their mean.
        """
        if not self.deque:
            return float("nan")
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window, or NaN if the window is empty."""
        if not self.deque:
            return float("nan")
        # Force a floating dtype: when every value is an int, torch.tensor builds an
        # integer tensor and torch.mean refuses it. Using the default float dtype
        # keeps results identical to before whenever the window holds any float.
        d = torch.tensor(list(self.deque), dtype=torch.get_default_dtype())
        return d.mean().item()

    @property
    def global_avg(self):
        """Mean of every value added so far, or NaN before the first update."""
        if self.count == 0:
            return float("nan")
        return self.total / self.count
class Stats:
    """Per-split collection of named ``SmoothedValue`` statistics plus training progress.

    Args:
        num_steps: total number of training steps. When ``None`` it is computed as
            ``num_epochs * steps_per_epoch``.
        num_epochs: number of epochs; only used when ``num_steps`` is ``None``.
        steps_per_epoch: steps per epoch; only used when ``num_steps`` is ``None``.
        stats_to_print: mapping ``split -> iterable of stat names`` selecting what
            ``get_summary`` / ``write_tensorboard`` report. ``None`` selects nothing;
            names can be added later with ``update_stats_to_print``.
    """

    def __init__(self, num_steps=None, num_epochs=None, steps_per_epoch=None, stats_to_print=None):
        self.step = self.epoch = 0

        if num_steps is not None:
            self.num_steps = num_steps
        else:
            # Raises TypeError if num_epochs / steps_per_epoch are not provided either.
            self.num_steps = num_epochs * steps_per_epoch

        # split -> (stat name -> SmoothedValue). Only "train" exists up front;
        # other splits are created on first use by update().
        self.stats = {
            "train": defaultdict(SmoothedValue),
        }
        # Treat a None default as "nothing selected" instead of crashing on None.items().
        self.stats_to_print = {k: set(v) for k, v in (stats_to_print or {}).items()}

    def to_dict(self):
        """Return the instance ``__dict__`` itself (a live reference, not a copy)."""
        return self.__dict__

    def load_dict(self, dict):
        """Set every ``key -> value`` pair of ``dict`` as an attribute on this object."""
        for key, val in dict.items():
            setattr(self, key, val)

    def update(self, split, step, epoch, dict):
        """Record the current step/epoch and add each ``name -> value`` of ``dict`` to ``split``.

        Tensor values are converted with ``.item()``; every value must then be a
        Python ``float`` or ``int``.
        """
        self.step = step
        self.epoch = epoch
        # Create the split on first use so that non-"train" stats (e.g. the
        # validation split handled by get_summary) can be recorded.
        if split not in self.stats:
            self.stats[split] = defaultdict(SmoothedValue)
        split_stats = self.stats[split]
        for k, v in dict.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            split_stats[k].update(v)

    def update_stats_to_print(self, split, stats_to_print):
        """Add the names in ``stats_to_print`` to the reported stats of ``split``."""
        self.stats_to_print.setdefault(split, set()).update(stats_to_print)

    def _printable_stats(self, split):
        """Return ``(name, SmoothedValue)`` pairs to report for ``split``, sorted by name.

        Selected stats that have not been recorded yet are skipped. Sorting makes the
        output order independent of set iteration order, which can vary between runs.
        """
        split_stats = self.stats.get(split, {})
        return [
            (name, split_stats[name])
            for name in sorted(self.stats_to_print.get(split, ()))
            if name in split_stats
        ]

    def get_summary(self, split):
        """Build a one-line progress summary for ``split``.

        The ``"train"`` split shows step progress and an ETA, and requires a
        ``"time"`` stat to have been recorded. Any other split shows a validation
        header with the 1-based epoch. Each reported stat is shown as its window median.
        """
        if split == "train":
            completion_pct = self.step / self.num_steps * 100
            # ETA: global average of the "time" stat (interpreted as seconds per
            # step) multiplied by the number of remaining steps.
            eta_seconds = self.stats[split].get("time").global_avg * (self.num_steps - self.step)
            eta_string = datetime.timedelta(seconds=int(eta_seconds))
            s = "[{}/{}, {:.1f}%] eta: {}, ".format(self.step, self.num_steps, completion_pct, eta_string)
        else:
            s = f"[Validation, epoch {self.epoch + 1}] "
        return s + ", ".join(f"{stat}: {value.median:.5f}" for stat, value in self._printable_stats(split))

    def write_tensorboard(self, summary_writer, split):
        """Write the 1-based epoch and each reported stat's window median at ``self.step``."""
        summary_writer.add_scalar(f"{split}/epoch", self.epoch + 1, self.step)
        for stat, value in self._printable_stats(split):
            summary_writer.add_scalar(f"{split}/{stat}", value.median, self.step)

    def is_best(self):
        """Always returns True here; presumably meant to be overridden (TODO confirm)."""
        return True