import torch
import torch.nn as nn
from einops import rearrange

from rendering import RenderingClass
from .lpips_loss import LPIPS


class LPIPSithTVLoss(nn.Module):
    """VAE training loss combining a pixel-space L1 term, LPIPS perceptual and
    depth terms on rendered outputs, a total-variation penalty on the latent
    mean, and a KL divergence term."""
    def __init__(self, lpips_ckpt, vgg_ckpt, device, render_network_pkl, std_v, mean_v, ws_avg_pkl,
                 logvar_init=0.0, kl_weight=1e-5, faceloss_weight=1.0,
                 pixelloss_weight=1.0, depthloss_weight=0, face_feature_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, percep_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, imgloss_weight=0.,
                 disc_loss="hinge", max_bs=None, latent_tv_weight=2e-3, use_double_norm=False,
                 use_max_min=False, use_std_mean=True):
        super().__init__()
        # Weights for the individual loss terms.
        self.kl_weight = kl_weight
        self.face_feature_weight = face_feature_weight
        self.pixel_weight = pixelloss_weight
        self.face_weight = faceloss_weight
        self.depthloss_weight = depthloss_weight
        self.latent_tv_weight = latent_tv_weight
        self.perceptual_weight = perceptual_weight
        self.imgloss_weight = imgloss_weight
        self.percep_factor = percep_factor
        # Pretrained rendering network and frozen LPIPS network.
        self.Rendering = RenderingClass(device, render_network_pkl, ws_avg_pkl)
        self.perceptual_loss = LPIPS(lpips_ckpt, vgg_ckpt).eval().to(device)
        # Statistics used to undo input normalization before rendering,
        # reshaped so they broadcast over the 5D feature tensors.
        self.use_std_mean = use_std_mean
        self.std_v = torch.load(std_v).reshape(1, -1, 1, 1, 1).to(device)
        self.mean_v = torch.load(mean_v).reshape(1, -1, 1, 1, 1).to(device)
        # self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
    def denormalize(self, inputs_rendering, reconstructions_rendering):
        # Undo normalization: x_original = x * std + mean.
        inputs_rendering = inputs_rendering * self.std_v + self.mean_v
        reconstructions_rendering = reconstructions_rendering * self.std_v + self.mean_v
        return inputs_rendering, reconstructions_rendering
    def forward(self, inputs, reconstructions, posteriors, vert_values,
                split="train"):
        # The renderer expects (batch, time, channel, height, width).
        inputs_rendering = rearrange(inputs, 'b c t h w -> b t c h w')
        reconstructions_rendering = rearrange(reconstructions, 'b c t h w -> b t c h w')

        # Pixel-space L1 reconstruction loss, summed and averaged over the batch.
        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        rec_loss = torch.sum(rec_loss) / rec_loss.shape[0]
        log = {"{}/rec_loss".format(split): rec_loss.detach().mean()}
        loss = self.pixel_weight * rec_loss
        # Denormalize both sequences and render them to images, depth maps,
        # and face features.
        inputs_rendering_original, reconstructions_rendering_original = self.denormalize(
            inputs_rendering, reconstructions_rendering)
        (inputs_img, recon_img, inputs_depth, recon_depth,
         inputs_face_feature, recon_face_feature) = self.Rendering.rendering_for_training(
            reconstructions_rendering_original.contiguous(),
            inputs_rendering_original.contiguous(), vert_values)
        # LPIPS perceptual loss on the rendered images.
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(recon_img.contiguous(), inputs_img.contiguous())
            p_loss = torch.sum(p_loss) / p_loss.shape[0]
            loss += self.perceptual_weight * p_loss
            log["{}/percep_loss".format(split)] = p_loss.detach().mean()
        # L1 loss on the rendered depth maps.
        if self.depthloss_weight > 0:
            rec_loss_depth = torch.abs(recon_depth.contiguous() - inputs_depth.contiguous())
            rec_loss_depth = torch.sum(rec_loss_depth) / rec_loss_depth.shape[0]
            loss += self.depthloss_weight * rec_loss_depth
            log["{}/depth_loss".format(split)] = rec_loss_depth.detach().mean()
        # Total-variation penalty on the posterior mean: L1 differences between
        # adjacent slices along dims 2 and 3, encouraging smooth latents.
        if self.latent_tv_weight > 0:
            latent = posteriors.mean
            latent_tv_y = torch.abs(latent[:, :, :-1] - latent[:, :, 1:]).sum() / latent.shape[0]
            latent_tv_x = torch.abs(latent[:, :, :, :-1] - latent[:, :, :, 1:]).sum() / latent.shape[0]
            latent_tv_loss = latent_tv_y + latent_tv_x
            loss += latent_tv_loss * self.latent_tv_weight
            log["{}/tv_loss".format(split)] = latent_tv_loss.detach().mean()
        # KL divergence of the posterior, averaged over the batch.
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        loss += kl_loss * self.kl_weight
        log["{}/kl_loss".format(split)] = kl_loss.detach().mean()
        log["{}/total_loss".format(split)] = loss.detach().mean()
        return loss, log
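
# A minimal usage sketch (all paths below are hypothetical placeholders, and
# `posteriors` is assumed to be a distribution object exposing `.mean` and
# `.kl()`, as the forward pass above requires):
#
#   loss_fn = LPIPSithTVLoss("lpips.ckpt", "vgg.ckpt", "cuda",
#                            "render_network.pkl", "std_v.pt", "mean_v.pt",
#                            "ws_avg.pkl")
#   # inputs / reconstructions: (batch, channel, time, height, width) tensors
#   loss, log = loss_fn(inputs, reconstructions, posteriors, vert_values)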