Spaces:
Paused
Paused
"""Container for latent space posterior.""" | |
from typing import Optional

import torch
class LatentDistribution:
    """Diagonal Gaussian over latent tensors, parameterised by mean and log-variance."""

    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
        """Initialize latent distribution.

        Args:
            mean: Mean of the distribution. Shape: [B, C, T, H, W].
            logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].

        Raises:
            ValueError: If ``mean`` and ``logvar`` have different shapes.
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if mean.shape != logvar.shape:
            raise ValueError(
                f"mean and logvar must have the same shape, got {tuple(mean.shape)} "
                f"and {tuple(logvar.shape)}."
            )
        self.mean = mean
        self.logvar = logvar

    def sample(
        self,
        temperature: float = 1.0,
        generator: Optional[torch.Generator] = None,
        noise: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Draw a sample computed as ``noise * exp(0.5 * logvar) + mean``.

        Args:
            temperature: ``0.0`` returns the mean deterministically; ``1.0`` draws an
                ordinary sample. Any other value raises ``NotImplementedError``.
            generator: Optional RNG used to draw standard-normal noise when
                ``noise`` is not supplied.
            noise: Optional pre-drawn noise. It must be on the same device as
                ``mean`` and is cast to ``mean``'s dtype. It is expected to be
                standard-normal like the internally drawn noise; this is not checked.

        Returns:
            ``self.mean`` itself when ``temperature == 0.0``; otherwise
            ``noise * exp(0.5 * logvar) + mean``.

        Raises:
            NotImplementedError: If ``temperature`` is neither ``0.0`` nor ``1.0``.
            ValueError: If ``noise`` is on a different device than ``mean``.
        """
        if temperature == 0.0:
            return self.mean
        # Reject unsupported temperatures before touching the RNG, so a failed call
        # neither allocates noise nor advances the caller's ``generator`` state.
        if temperature != 1.0:
            raise NotImplementedError(f"Temperature {temperature} is not supported.")
        if noise is None:
            noise = torch.randn(
                self.mean.shape,
                device=self.mean.device,
                dtype=self.mean.dtype,
                generator=generator,
            )
        else:
            if noise.device != self.mean.device:
                raise ValueError(
                    f"noise is on device {noise.device}, but mean is on {self.mean.device}."
                )
            noise = noise.to(self.mean.dtype)
        # Just Gaussian sample with no scaling of variance.
        return noise * torch.exp(self.logvar * 0.5) + self.mean

    def mode(self) -> torch.Tensor:
        """Return the mode of the distribution, which for a Gaussian is its mean."""
        return self.mean