File size: 1,201 Bytes
96257b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""Container for latent space posterior."""

import torch


class LatentDistribution:
    """Diagonal Gaussian posterior over latents, parameterized by mean and log-variance."""

    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
        """Initialize latent distribution.

        Args:
            mean: Mean of the distribution. Shape: [B, C, T, H, W].
            logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].

        Raises:
            ValueError: If ``mean`` and ``logvar`` do not have the same shape.
        """
        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        if mean.shape != logvar.shape:
            raise ValueError(
                f"mean and logvar must have the same shape, got "
                f"{tuple(mean.shape)} and {tuple(logvar.shape)}."
            )
        self.mean = mean
        self.logvar = logvar

    def sample(
        self, temperature=1.0, generator: torch.Generator = None, noise=None
    ) -> torch.Tensor:
        """Draw a sample from the distribution.

        Args:
            temperature: Only ``0.0`` (return the mean) and ``1.0`` (plain
                Gaussian sample) are supported.
            generator: Optional random generator used when ``noise`` is not given.
            noise: Optional pre-drawn standard-normal noise. It must live on the
                same device as ``mean``; it is cast to the dtype of ``mean``.

        Returns:
            ``self.mean`` itself when ``temperature == 0.0``; otherwise
            ``noise * exp(0.5 * logvar) + mean``.

        Raises:
            NotImplementedError: If ``temperature`` is neither ``0.0`` nor ``1.0``.
            ValueError: If ``noise`` is on a different device than ``mean``.
        """
        if temperature == 0.0:
            return self.mean

        # Reject unsupported temperatures before drawing noise, so a failed call
        # neither allocates a tensor nor advances the caller's generator state.
        if temperature != 1.0:
            raise NotImplementedError(f"Temperature {temperature} is not supported.")

        if noise is None:
            noise = torch.randn(
                self.mean.shape,
                device=self.mean.device,
                dtype=self.mean.dtype,
                generator=generator,
            )
        else:
            if noise.device != self.mean.device:
                raise ValueError(
                    f"noise is on device {noise.device}, "
                    f"but mean is on device {self.mean.device}."
                )
            noise = noise.to(self.mean.dtype)

        # Just Gaussian sample with no scaling of variance.
        # exp(0.5 * logvar) is the standard deviation.
        return noise * torch.exp(self.logvar * 0.5) + self.mean

    def mode(self) -> torch.Tensor:
        """Return the mode of the distribution, which is stored as ``self.mean``."""
        return self.mean