File size: 1,992 Bytes
96257b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch

# Per-channel statistics of the latents produced by the VAE encoder: one
# mean and one standard deviation per latent channel (12 channels). The
# conversion helpers in this module use them to move between the VAE's latent
# scale and the normalized scale used by the DiT:
#     dit = (vae - mean) / std        vae = dit * std + mean
STATS = dict(
    mean=torch.tensor(
        [
            -0.06730895953510081,
            -0.038011381506090416,
            -0.07477820912866141,
            -0.05565264470995561,
            0.012767231469026969,
            -0.04703542746246419,
            0.043896967884726704,
            -0.09346305707025976,
            -0.09918314763016893,
            -0.008729793427399178,
            -0.011931556316503654,
            -0.0321993391887285,
        ]
    ),
    std=torch.tensor(
        [
            0.9263795028493863,
            0.9248894543193766,
            0.9393059390890617,
            0.959253732819592,
            0.8244560132752793,
            0.917259975397747,
            0.9294154431013696,
            1.3720942357788521,
            0.881393668867029,
            0.9168315692124348,
            0.9185249279345552,
            0.9274757570805041,
        ]
    ),
)

def dit_latents_to_vae_latents(dit_outputs: torch.Tensor) -> torch.Tensor:
    """Unnormalize latents output by Mochi's DiT to be compatible with the VAE.

    Run this on sampled latents before calling the VAE decoder. Each channel
    ``c`` is mapped to ``x * STATS["std"][c] + STATS["mean"][c]``, which undoes
    the normalization applied by ``vae_latents_to_dit_latents``.

    Args:
        dit_outputs (torch.Tensor): [B, C_z, T_z, H_z, W_z], float. ``C_z`` must
            equal the number of per-channel entries in ``STATS``.

    Returns:
        torch.Tensor: [B, C_z, T_z, H_z, W_z], with the same dtype and device
            as ``dit_outputs``.

    Raises:
        ValueError: If ``dit_outputs`` is not 5-D, or its channel count does not
            match the lengths of ``STATS["mean"]`` and ``STATS["std"]``.
    """
    # Explicit checks instead of `assert`: asserts are stripped under
    # `python -O`, and without them a 5-D input with a single channel would
    # silently broadcast against the [C_z, 1, 1, 1] stats and come back with
    # C_z channels.
    if dit_outputs.ndim != 5:
        raise ValueError(
            "Expected a 5-D tensor [B, C_z, T_z, H_z, W_z], "
            f"got shape {tuple(dit_outputs.shape)}"
        )

    # [C_z] -> [C_z, 1, 1, 1] so the stats broadcast over T_z, H_z, W_z (and B).
    mean = STATS["mean"][:, None, None, None]
    std = STATS["std"][:, None, None, None]

    if not dit_outputs.size(1) == mean.size(0) == std.size(0):
        raise ValueError(
            f"Channel mismatch: input has {dit_outputs.size(1)} channels, but "
            f"STATS has {mean.size(0)} means and {std.size(0)} stds"
        )

    # `.to(tensor)` moves the stats to the input's device and casts them to its
    # dtype, so the result stays in the input's dtype/device.
    return dit_outputs * std.to(dit_outputs) + mean.to(dit_outputs)


def vae_latents_to_dit_latents(vae_latents: torch.Tensor) -> torch.Tensor:
    """Normalize latents output by the VAE encoder to be compatible with Mochi's DiT.

    E.g., for fine-tuning or video-to-video. Each channel ``c`` is mapped to
    ``(x - STATS["mean"][c]) / STATS["std"][c]``; ``dit_latents_to_vae_latents``
    performs the inverse mapping.

    Args:
        vae_latents (torch.Tensor): [B, C_z, T_z, H_z, W_z], float. ``C_z`` must
            equal the number of per-channel entries in ``STATS``.

    Returns:
        torch.Tensor: [B, C_z, T_z, H_z, W_z], with the same dtype and device
            as ``vae_latents``.

    Raises:
        ValueError: If ``vae_latents`` is not 5-D, or its channel count does not
            match the lengths of ``STATS["mean"]`` and ``STATS["std"]``.
    """
    # Explicit checks instead of `assert`: asserts are stripped under
    # `python -O`, and without them a 5-D input with a single channel would
    # silently broadcast against the [C_z, 1, 1, 1] stats and come back with
    # C_z channels.
    if vae_latents.ndim != 5:
        raise ValueError(
            "Expected a 5-D tensor [B, C_z, T_z, H_z, W_z], "
            f"got shape {tuple(vae_latents.shape)}"
        )

    # [C_z] -> [C_z, 1, 1, 1] so the stats broadcast over T_z, H_z, W_z (and B).
    mean = STATS["mean"][:, None, None, None]
    std = STATS["std"][:, None, None, None]

    if not vae_latents.size(1) == mean.size(0) == std.size(0):
        raise ValueError(
            f"Channel mismatch: input has {vae_latents.size(1)} channels, but "
            f"STATS has {mean.size(0)} means and {std.size(0)} stds"
        )

    # `.to(tensor)` moves the stats to the input's device and casts them to its
    # dtype, so the result stays in the input's dtype/device.
    return (vae_latents - mean.to(vae_latents)) / std.to(vae_latents)