|
import torch
|
|
import torch.nn.functional as F
|
|
from .hessian_penalty import hessian_penalty
|
|
from .mmd import compute_mmd
|
|
|
|
|
|
def compute_rc_loss(model, batch, use_txt_output=False):
    """Masked mean-squared reconstruction error between ``batch["x"]`` and a prediction.

    Args:
        model: Unused; accepted so every loss in this module shares one signature.
        batch: Mapping providing ``"x"``, ``"output"`` and ``"mask"`` (and
            ``"txt_output"`` when ``use_txt_output`` is true). ``x`` and the
            prediction are 4-D tensors; ``mask`` is a boolean tensor whose shape
            matches ``(x.shape[0], x.shape[3])``. Presumably the layout is
            (batch, joints, feats, frames) -- confirm against the caller.
        use_txt_output: Compare against ``batch["txt_output"]`` instead of
            ``batch["output"]``.

    Returns:
        Scalar tensor: mean squared error over all mask-selected entries.
    """
    target, prediction, valid = batch["x"], batch["output"], batch["mask"]
    if use_txt_output:
        prediction = batch["txt_output"]

    def _select(tensor):
        # Bring the last axis next to the batch axis so the 2-D mask can pick
        # whole (axis-1, axis-2) slices per kept position.
        return tensor.permute(0, 3, 1, 2)[valid]

    return F.mse_loss(_select(target), _select(prediction), reduction='mean')
|
|
|
|
|
|
def compute_rcxyz_loss(model, batch, use_txt_output=False):
    """Masked mean-squared reconstruction error on the ``*_xyz`` tensors.

    Args:
        model: Unused; accepted so every loss in this module shares one signature.
        batch: Mapping providing ``"x_xyz"``, ``"output_xyz"`` and ``"mask"``
            (and ``"txt_output_xyz"`` when ``use_txt_output`` is true). The
            tensors are 4-D; ``mask`` is a boolean tensor whose shape matches
            ``(x_xyz.shape[0], x_xyz.shape[3])``. Presumably the layout is
            (batch, joints, 3, frames) -- confirm against the caller.
        use_txt_output: Compare against ``batch["txt_output_xyz"]`` instead of
            ``batch["output_xyz"]``.

    Returns:
        Scalar tensor: mean squared error over all mask-selected entries.
    """
    target, prediction, valid = batch["x_xyz"], batch["output_xyz"], batch["mask"]
    if use_txt_output:
        prediction = batch["txt_output_xyz"]

    # Move the last axis to position 1 so the 2-D mask indexes (axis 0, axis 3).
    target_sel = target.permute(0, 3, 1, 2)[valid]
    prediction_sel = prediction.permute(0, 3, 1, 2)[valid]
    return F.mse_loss(target_sel, prediction_sel, reduction='mean')
|
|
|
|
|
|
def compute_vel_loss(model, batch, use_txt_output=False):
    """Masked MSE between first differences of the target and the prediction.

    Differences are taken along the last axis of ``batch["x"]`` and of the
    prediction, then compared only where both neighbouring positions are valid.

    Args:
        model: Unused; accepted so every loss in this module shares one signature.
        batch: Mapping providing ``"x"``, ``"output"`` and ``"mask"`` (and
            ``"txt_output"`` when ``use_txt_output`` is true). ``x`` and the
            prediction are 4-D tensors; ``mask`` is a boolean tensor whose shape
            matches ``(x.shape[0], x.shape[3])``. Presumably the layout is
            (batch, joints, feats, frames) -- confirm against the caller.
        use_txt_output: Compare against ``batch["txt_output"]`` instead of
            ``batch["output"]``.

    Returns:
        Scalar tensor: mean squared error over all kept differences.
    """
    x = batch["x"]
    output = batch["output"]
    if use_txt_output:
        output = batch["txt_output"]

    gtvel = x[..., 1:] - x[..., :-1]
    outputvel = output[..., 1:] - output[..., :-1]

    # The difference between positions t and t+1 is meaningful only when BOTH
    # positions are valid. For prefix masks (valid entries then padding) this
    # equals mask[..., 1:], but it also stays correct for masks with holes.
    mask = batch["mask"]
    velmask = mask[..., 1:] & mask[..., :-1]

    gtvelmasked = gtvel.permute(0, 3, 1, 2)[velmask]
    outvelmasked = outputvel.permute(0, 3, 1, 2)[velmask]

    return F.mse_loss(gtvelmasked, outvelmasked, reduction='mean')
|
|
|
|
|
|
def compute_velxyz_loss(model, batch, use_txt_output=False):
    """Masked MSE between first differences of the ``*_xyz`` target and prediction.

    Differences are taken along the last axis, then compared only where both
    neighbouring positions are valid.

    Args:
        model: Unused; accepted so every loss in this module shares one signature.
        batch: Mapping providing ``"x_xyz"``, ``"output_xyz"`` and ``"mask"``
            (and ``"txt_output_xyz"`` when ``use_txt_output`` is true). The
            tensors are 4-D; ``mask`` is a boolean tensor whose shape matches
            ``(x_xyz.shape[0], x_xyz.shape[3])``. Presumably the layout is
            (batch, joints, 3, frames) -- confirm against the caller.
        use_txt_output: Compare against ``batch["txt_output_xyz"]`` instead of
            ``batch["output_xyz"]``.

    Returns:
        Scalar tensor: mean squared error over all kept differences.
    """
    x = batch["x_xyz"]
    output = batch["output_xyz"]
    if use_txt_output:
        output = batch["txt_output_xyz"]

    gtvel = x[..., 1:] - x[..., :-1]
    outputvel = output[..., 1:] - output[..., :-1]

    # The difference between positions t and t+1 is meaningful only when BOTH
    # positions are valid. For prefix masks (valid entries then padding) this
    # equals mask[..., 1:], but it also stays correct for masks with holes.
    mask = batch["mask"]
    velmask = mask[..., 1:] & mask[..., :-1]

    gtvelmasked = gtvel.permute(0, 3, 1, 2)[velmask]
    outvelmasked = outputvel.permute(0, 3, 1, 2)[velmask]

    return F.mse_loss(gtvelmasked, outvelmasked, reduction='mean')
|
|
|
|
|
|
def compute_hp_loss(model, batch):
    """Hessian penalty of ``model.return_latent`` evaluated on ``batch``.

    A new integer seed is drawn on every call and forwarded to
    ``hessian_penalty`` as the ``seed`` keyword. Presumably it is passed on to
    ``model.return_latent`` so that repeated evaluations within one penalty
    computation share randomness -- confirm in ``hessian_penalty``.

    Args:
        model: Object exposing a ``return_latent`` callable.
        batch: Input forwarded unchanged to ``hessian_penalty``.

    Returns:
        Whatever ``hessian_penalty`` returns (presumably a scalar tensor).
    """
    # torch.random.seed() re-seeds the *global* default generator with
    # non-deterministic entropy on every call, silently defeating any earlier
    # torch.manual_seed(). Drawing the seed from the current generator still
    # gives a fresh value per call, but one that is reproducible.
    seed = torch.randint(0, 2**62, (1,)).item()
    return hessian_penalty(model.return_latent, batch, seed=seed)
|
|
|
|
|
|
def compute_mmd_loss(model, batch):
    """MMD between the codes in ``batch["z"]`` and standard-normal samples.

    Args:
        model: Unused; accepted so every loss in this module shares one signature.
        batch: Mapping providing the tensor ``"z"``.

    Returns:
        Result of ``compute_mmd(true_samples, z)``.
    """
    z = batch["z"]
    # Draw the reference samples with z's own shape, dtype and device. This
    # avoids depending on a non-standard ``model.device`` attribute, and avoids
    # a dtype mismatch when z is in reduced precision. randn_like does not
    # require grad by default.
    true_samples = torch.randn_like(z)
    return compute_mmd(true_samples, z)
|
|
|
|
|
|
# Registry of loss name -> loss function. Insertion order is the order that
# get_loss_names() reports.
_matching_ = dict(
    rc=compute_rc_loss,
    hp=compute_hp_loss,
    mmd=compute_mmd_loss,
    rcxyz=compute_rcxyz_loss,
    vel=compute_vel_loss,
    velxyz=compute_velxyz_loss,
)
|
|
|
|
|
|
def get_loss_function(ltype):
    """Return the loss function registered under the name ``ltype``.

    Raises:
        KeyError: if ``ltype`` is not registered. The message lists the valid
            names (same exception type as a plain dict lookup).
    """
    try:
        return _matching_[ltype]
    except KeyError:
        raise KeyError(
            f"Unknown loss type {ltype!r}; expected one of {list(_matching_)}"
        ) from None
|
|
|
|
|
|
def get_loss_names():
    """List every registered loss name, in registry insertion order."""
    return [name for name in _matching_]
|
|
|