import torch

# Channel-wise mean and standard deviation of VAE encoder latents
STATS = {
    "mean": torch.Tensor([
        -0.06730895953510081,
        -0.038011381506090416,
        -0.07477820912866141,
        -0.05565264470995561,
        0.012767231469026969,
        -0.04703542746246419,
        0.043896967884726704,
        -0.09346305707025976,
        -0.09918314763016893,
        -0.008729793427399178,
        -0.011931556316503654,
        -0.0321993391887285,
    ]),
    "std": torch.Tensor([
        0.9263795028493863,
        0.9248894543193766,
        0.9393059390890617,
        0.959253732819592,
        0.8244560132752793,
        0.917259975397747,
        0.9294154431013696,
        1.3720942357788521,
        0.881393668867029,
        0.9168315692124348,
        0.9185249279345552,
        0.9274757570805041,
    ]),
}


def dit_latents_to_vae_latents(dit_outputs: torch.Tensor) -> torch.Tensor:
    """Unnormalize latents output by Mochi's DiT to be compatible with the VAE.

    Run this on sampled latents before calling the VAE decoder.

    Args:
        dit_outputs (torch.Tensor): [B, C_z, T_z, H_z, W_z], float

    Returns:
        torch.Tensor: [B, C_z, T_z, H_z, W_z], float
    """
    # Reshape the per-channel stats to [C_z, 1, 1, 1] so they broadcast
    # over the temporal and spatial dimensions.
    mean = STATS["mean"][:, None, None, None]
    std = STATS["std"][:, None, None, None]
    assert dit_outputs.ndim == 5
    assert dit_outputs.size(1) == mean.size(0) == std.size(0)
    # .to(dit_outputs) moves the stats to the input's device and dtype.
    return dit_outputs * std.to(dit_outputs) + mean.to(dit_outputs)


def vae_latents_to_dit_latents(vae_latents: torch.Tensor) -> torch.Tensor:
    """Normalize latents output by the VAE encoder to be compatible with Mochi's DiT.

    E.g., for fine-tuning or video-to-video.

    Args:
        vae_latents (torch.Tensor): [B, C_z, T_z, H_z, W_z], float

    Returns:
        torch.Tensor: [B, C_z, T_z, H_z, W_z], float
    """
    mean = STATS["mean"][:, None, None, None]
    std = STATS["std"][:, None, None, None]
    assert vae_latents.ndim == 5
    assert vae_latents.size(1) == mean.size(0) == std.size(0)
    return (vae_latents - mean.to(vae_latents)) / std.to(vae_latents)
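

# A minimal round-trip sanity check of the two helpers above. This is a
# sketch: the tensor shape below is an assumed example (C_z = 12 to match
# STATS), not a shape produced by the actual VAE encoder.
if __name__ == "__main__":
    # [B, C_z, T_z, H_z, W_z]
    vae_latents = torch.randn(1, 12, 4, 8, 8)
    dit_latents = vae_latents_to_dit_latents(vae_latents)
    recovered = dit_latents_to_vae_latents(dit_latents)
    # Normalizing then unnormalizing should recover the input
    # up to floating-point error.
    assert torch.allclose(recovered, vae_latents, atol=1e-5)
    print("Round-trip OK:", tuple(recovered.shape))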