Spaces:
Running
on
L4
Running
on
L4
class Ema:
    """Keep an exponential moving average (EMA) of a model's parameters.

    The tracked set is every entry of ``model.named_parameters()`` whose
    ``requires_grad`` is true and whose name contains none of the substrings
    in ``_EXCLUDED_NAME_PARTS``.

    Call order expected by the methods below:

    1. ``register()`` once, to snapshot the current parameter values.
    2. ``update()`` whenever the averages should absorb the current values.
    3. ``apply_shadow()`` to swap the averages into the model, followed by
       ``restore()`` to put the saved parameter values back.
    """

    # Parameters whose name contains any of these substrings are never tracked.
    _EXCLUDED_NAME_PARTS = ('argument_fcn', 'argument_decoder')

    def __init__(self, model, decay: float) -> None:
        """Store the model and the decay factor; no parameters are read yet.

        Args:
            model: Object exposing ``named_parameters()`` that yields
                ``(name, param)`` pairs, where ``param`` has ``requires_grad``
                and ``data`` attributes.
            decay: Weight given to the previous average in ``update()``; the
                current parameter value is weighted by ``1.0 - decay``.
        """
        self.model = model
        self.decay = decay
        self.shadow = {}  # parameter name -> averaged tensor
        self.backup = {}  # parameter name -> tensor saved by apply_shadow()

    def _tracked_parameters(self):
        """Yield the ``(name, param)`` pairs that take part in the averaging."""
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if any(part in name for part in self._EXCLUDED_NAME_PARTS):
                continue
            yield name, param

    def register(self) -> None:
        """Initialise the averages with a copy of each tracked parameter."""
        for name, param in self._tracked_parameters():
            # clone() so later changes to the parameter do not leak into the
            # stored average.
            self.shadow[name] = param.data.clone()

    def update(self) -> None:
        """Blend the current parameter values into the stored averages.

        Each average becomes
        ``(1.0 - decay) * param.data + decay * previous_average``.

        Raises:
            AssertionError: If a tracked parameter has no stored average,
                i.e. ``register()`` has not covered it.
        """
        for name, param in self._tracked_parameters():
            assert name in self.shadow, (
                f"no average registered for parameter {name!r}; "
                "call register() first"
            )
            # The arithmetic already produces a new tensor, so no extra
            # clone() is needed before storing it.
            self.shadow[name] = (
                (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
            )

    def apply_shadow(self) -> None:
        """Swap the averaged values into the model, saving the current ones.

        The current ``param.data`` tensors are kept in ``self.backup`` so that
        ``restore()`` can put them back.

        Raises:
            AssertionError: If a tracked parameter has no stored average.
        """
        for name, param in self._tracked_parameters():
            assert name in self.shadow, (
                f"no average registered for parameter {name!r}; "
                "call register() first"
            )
            self.backup[name] = param.data
            # NOTE: the parameter now references the very tensor held in
            # self.shadow (no copy is made), so modifying the parameter in
            # place before restore() would also modify the stored average.
            param.data = self.shadow[name]

    def restore(self) -> None:
        """Put back the values saved by ``apply_shadow()`` and clear the backup.

        Raises:
            AssertionError: If a tracked parameter has no saved value, i.e.
                ``apply_shadow()`` was not called since the last ``restore()``.
        """
        for name, param in self._tracked_parameters():
            assert name in self.backup, (
                f"no backup stored for parameter {name!r}; "
                "call apply_shadow() first"
            )
            param.data = self.backup[name]
        self.backup = {}