import os

import torch
import torch.nn.functional as F


def save_checkpoint(model, filelocation, save_parallel=True):
    """Save the model's ``state_dict`` to ``filelocation``.

    Args:
        model: Object exposing ``state_dict()``; may be a wrapper that holds
            the real network in a ``module`` attribute (as
            ``torch.nn.DataParallel`` does).
        filelocation: Destination passed straight to ``torch.save``.
        save_parallel: When true, save the wrapped ``model.module`` weights so
            the stored keys carry no wrapper prefix. When false, save
            ``model.state_dict()`` as-is.
    """
    if save_parallel:
        # Fall back to the model itself when it is not wrapped; previously the
        # default ``save_parallel=True`` raised AttributeError on plain models.
        target = getattr(model, "module", model)
    else:
        target = model
    torch.save(target.state_dict(), filelocation)


def load_Checkpoint(fileLocation, model, load_cpu=False):
    """Load a saved ``state_dict`` from ``fileLocation`` into ``model``.

    Args:
        fileLocation: Source passed straight to ``torch.load``.
        model: Object exposing ``load_state_dict``; it is updated in place.
        load_cpu: When true, all stored tensors are mapped onto the CPU while
            loading; otherwise ``torch.load``'s default placement is used.

    Returns:
        The same ``model`` object that was passed in.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted
    # sources.
    state = torch.load(fileLocation, map_location="cpu" if load_cpu else None)
    model.load_state_dict(state)
    return model


def writeLog(logList, filename):
    """Write the strings in ``logList`` to ``filename``, one per line.

    The file is overwritten. Entries are joined with ``"\\n"``, so no newline
    is written after the last entry.
    """
    with open(filename, 'w') as outfile:
        outfile.write("\n".join(logList))


def kl_loss(mu, logvar):
    """Return the mean of ``-0.5 * (1 + logvar - mu**2 - exp(logvar))``.

    This is the element-wise closed-form KL divergence between a diagonal
    Gaussian with mean ``mu`` and log-variance ``logvar`` and a standard
    normal. The result is averaged over every element of the input (not
    summed over any dimension).
    """
    return -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()