# RT-MPINet / helperFunctions.py
# Committed by 3ZadeSSG in ff00a24 ("initial commit"); 766 bytes.
import torch
import os
import torch.nn.functional as F
def save_checkpoint(model, filelocation, save_parallel=True):
    """Save a model's ``state_dict`` to *filelocation* with ``torch.save``.

    Args:
        model: The module to save, optionally a wrapper that exposes the real
            network as ``.module`` (e.g. ``nn.DataParallel``).
        filelocation: Destination path passed straight to ``torch.save``.
        save_parallel: If True, save the wrapped network's state dict
            (``model.module.state_dict()``) so keys carry no ``"module."``
            prefix. If *model* has no ``.module`` attribute, its own
            ``state_dict()`` is saved instead of raising ``AttributeError``.
            If False, ``model.state_dict()`` is saved as-is.
    """
    # Unwrap only when there is a wrapped module; with the default
    # save_parallel=True a plain, unwrapped model previously crashed here.
    target = getattr(model, "module", model) if save_parallel else model
    torch.save(target.state_dict(), filelocation)
def load_Checkpoint(fileLocation, model, load_cpu=False):
    """Load a state dict from *fileLocation* into *model* and return *model*.

    Args:
        fileLocation: Path of a ``torch.save`` file holding a state dict
            (for example one written by ``save_checkpoint``).
        model: Module whose parameters/buffers are overwritten in place via
            ``load_state_dict`` (default strict key matching).
        load_cpu: If True, every stored tensor is kept on CPU while loading,
            so a checkpoint saved from GPU can be opened on a CPU-only host.
            If False, ``torch.load`` uses its default device placement.

    Returns:
        The same *model* object, now holding the loaded weights.

    Note:
        ``torch.load`` is pickle-based (how strictly depends on the installed
        PyTorch version's ``weights_only`` default); only load trusted files.
    """
    # An identity map_location returns each storage as deserialised, i.e. on CPU.
    loader_options = {"map_location": lambda storage, loc: storage} if load_cpu else {}
    state = torch.load(fileLocation, **loader_options)
    model.load_state_dict(state)
    return model
def writeLog(logList, filename):
    """Write log entries to *filename*, one entry per line.

    The file is truncated first (mode ``"w"``). Entries are joined with
    newlines and no trailing newline is written after the last entry.

    Args:
        logList: Iterable of log entries. Non-string entries are converted
            with ``str()``; previously they made ``str.join`` raise
            ``TypeError``. String entries are written unchanged.
        filename: Destination path.
    """
    # Explicit UTF-8 so non-ASCII messages don't depend on (or crash under)
    # the platform's default locale encoding.
    with open(filename, "w", encoding="utf-8") as outfile:
        outfile.write("\n".join(map(str, logList)))
def kl_loss(mu, logvar):
    """Return the mean of -0.5 * (1 + logvar - mu**2 - exp(logvar)).

    Read with *logvar* as a log-variance, each element of that expression is
    the closed-form KL(N(mu, exp(logvar)) || N(0, 1)). The result is a
    scalar tensor averaged over *all* elements of the broadcast inputs, not
    summed over the latent dimension. NOTE(review): if a per-sample sum over
    latent units is intended, this is smaller by a factor of the latent
    size; confirm against the training code.

    Args:
        mu: Tensor of means.
        logvar: Tensor of log-variances, broadcast-compatible with *mu*.
    """
    variance = logvar.exp()
    per_element = 1 + logvar - mu.pow(2) - variance
    return -0.5 * per_element.mean()