File size: 2,581 Bytes
fe64bad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import torch
import torch.nn.functional as F
from .hessian_penalty import hessian_penalty
from .mmd import compute_mmd
def compute_rc_loss(model, batch, use_txt_output=False):
    """Masked mean-squared reconstruction loss on ``batch["x"]``.

    Args:
        model: not used here; the parameter keeps the (model, batch)
            signature shared by the other loss functions in this module.
        batch: dict providing "x", "output" (and "txt_output" when
            ``use_txt_output`` is set) as 4-D tensors, plus a boolean
            "mask" used to select entries after the last axis has been
            moved to position 1 — presumably a (batch, frames) mask over
            the last axis; TODO confirm against the data pipeline.
        use_txt_output: compare against batch["txt_output"] instead of
            batch["output"].

    Returns:
        Scalar tensor: mean squared error over the selected entries.
    """
    target, prediction, valid = batch["x"], batch["output"], batch["mask"]
    if use_txt_output:
        prediction = batch["txt_output"]

    def _select(tensor):
        # Bring axis 3 next to the batch axis so the mask indexes dims 0 and 1.
        return tensor.permute(0, 3, 1, 2)[valid]

    return F.mse_loss(_select(target), _select(prediction), reduction='mean')
def compute_rcxyz_loss(model, batch, use_txt_output=False):
    """Masked mean-squared reconstruction loss on ``batch["x_xyz"]``.

    Same computation as the "x"/"output" reconstruction loss, applied to
    the "_xyz" entries of the batch.

    Args:
        model: not used here; kept for the shared (model, batch) signature.
        batch: dict providing "x_xyz", "output_xyz" (and "txt_output_xyz"
            when ``use_txt_output`` is set) as 4-D tensors, plus a boolean
            "mask" — presumably (batch, frames) over the last axis; TODO
            confirm.
        use_txt_output: compare against batch["txt_output_xyz"] instead of
            batch["output_xyz"].

    Returns:
        Scalar tensor: mean squared error over the selected entries.
    """
    reference = batch["x_xyz"]
    predicted = batch["output_xyz"]
    keep = batch["mask"]
    if use_txt_output:
        predicted = batch["txt_output_xyz"]
    # Move axis 3 beside the batch axis, then keep only masked positions.
    reference_sel, predicted_sel = (
        t.permute(0, 3, 1, 2)[keep] for t in (reference, predicted)
    )
    return F.mse_loss(reference_sel, predicted_sel, reduction='mean')
def compute_vel_loss(model, batch, use_txt_output=False):
    """Masked MSE between step-to-step differences of ``batch["x"]`` and the output.

    "Velocities" are consecutive differences along the last axis.

    Args:
        model: not used here; kept for the shared (model, batch) signature.
        batch: dict providing "x", "output" (and "txt_output" when
            ``use_txt_output`` is set) as 4-D tensors, plus a boolean
            "mask" whose last axis is aligned with the tensors' last axis —
            presumably (batch, frames); TODO confirm.
        use_txt_output: compare against batch["txt_output"] instead of
            batch["output"].

    Returns:
        Scalar tensor: mean squared velocity error over valid step pairs.
    """
    x = batch["x"]
    output = batch["output"]
    if use_txt_output:
        output = batch["txt_output"]
    gtvel = x[..., 1:] - x[..., :-1]
    outputvel = output[..., 1:] - output[..., :-1]
    # The difference between steps t and t+1 is only meaningful when BOTH
    # steps are valid. Masking on step t+1 alone is equivalent for trailing
    # padding, but lets a padded->valid jump into the loss when invalid
    # steps precede valid ones.
    mask = batch["mask"]
    velmask = mask[..., 1:] & mask[..., :-1]
    gtvelmasked = gtvel.permute(0, 3, 1, 2)[velmask]
    outvelmasked = outputvel.permute(0, 3, 1, 2)[velmask]
    return F.mse_loss(gtvelmasked, outvelmasked, reduction='mean')
def compute_velxyz_loss(model, batch, use_txt_output=False):
    """Masked MSE between step-to-step differences of ``batch["x_xyz"]`` and the output.

    "Velocities" are consecutive differences along the last axis.

    Args:
        model: not used here; kept for the shared (model, batch) signature.
        batch: dict providing "x_xyz", "output_xyz" (and "txt_output_xyz"
            when ``use_txt_output`` is set) as 4-D tensors, plus a boolean
            "mask" whose last axis is aligned with the tensors' last axis —
            presumably (batch, frames); TODO confirm.
        use_txt_output: compare against batch["txt_output_xyz"] instead of
            batch["output_xyz"].

    Returns:
        Scalar tensor: mean squared velocity error over valid step pairs.
    """
    x = batch["x_xyz"]
    output = batch["output_xyz"]
    if use_txt_output:
        output = batch["txt_output_xyz"]
    gtvel = x[..., 1:] - x[..., :-1]
    outputvel = output[..., 1:] - output[..., :-1]
    # A velocity spans steps t and t+1, so both must be valid; checking only
    # t+1 admits a padded->valid jump when invalid steps precede valid ones.
    mask = batch["mask"]
    velmask = mask[..., 1:] & mask[..., :-1]
    gtvelmasked = gtvel.permute(0, 3, 1, 2)[velmask]
    outvelmasked = outputvel.permute(0, 3, 1, 2)[velmask]
    return F.mse_loss(gtvelmasked, outvelmasked, reduction='mean')
def compute_hp_loss(model, batch):
    """Hessian-penalty loss of ``model.return_latent`` evaluated on ``batch``.

    A fresh random seed is forwarded to ``hessian_penalty`` on each call.
    The seed is drawn from torch's global generator rather than obtained via
    ``torch.random.seed()``. The latter re-seeds the global default
    generator (and the CUDA generators) with a non-deterministic value as a
    side effect, which breaks reproducibility of all later sampling even
    when the run was seeded with ``torch.manual_seed``.

    Returns:
        Whatever ``hessian_penalty`` returns (defined in .hessian_penalty).
    """
    # NOTE(review): how ``seed`` is consumed is defined in .hessian_penalty /
    # model.return_latent (not visible here); a 31-bit value is accepted by
    # torch.Generator.manual_seed and common RNG seeding APIs.
    seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
    return hessian_penalty(model.return_latent, batch, seed=seed)
def compute_mmd_loss(model, batch):
    """MMD between ``batch["z"]`` and a standard-normal sample of the same shape.

    The reference sample is built with ``torch.randn_like`` so it matches
    ``z``'s dtype and device. A plain ``torch.randn(z.shape, ...)`` uses
    torch's global default dtype, which mismatches ``z`` under half/double
    precision.

    Args:
        model: not used here; kept for the shared (model, batch) signature.
        batch: dict providing the tensor "z".

    Returns:
        Whatever ``compute_mmd`` returns (defined in .mmd).
    """
    z = batch["z"]
    # randn_like does not track gradients by default, just like the
    # explicit requires_grad=False it replaces.
    true_samples = torch.randn_like(z)
    return compute_mmd(true_samples, z)
# Registry from loss name to loss function. Insertion order is preserved and
# is the order in which get_loss_names() lists the names.
_matching_ = {
    "rc": compute_rc_loss,
    "hp": compute_hp_loss,
    "mmd": compute_mmd_loss,
    "rcxyz": compute_rcxyz_loss,
    "vel": compute_vel_loss,
    "velxyz": compute_velxyz_loss,
}
def get_loss_function(ltype):
    """Return the loss function registered under the name ``ltype``.

    Raises:
        KeyError: if ``ltype`` is not a registered name. The exception type
            is unchanged, but the message now lists the valid names.
    """
    try:
        return _matching_[ltype]
    except KeyError:
        raise KeyError(
            f"Unknown loss type {ltype!r}; expected one of {list(_matching_)}"
        ) from None
def get_loss_names():
    """Return a new list of all registered loss names, in registration order."""
    return [*_matching_]
|