# Mochi VAE reconstruction demo (originally published as a Hugging Face Space
# running on ZeroGPU): encodes a video to latents and decodes it back.
import argparse
import time

import click
import torch
import torchvision
from einops import rearrange
from safetensors.torch import load_file

from genmo.lib.utils import save_video
from genmo.mochi_preview.pipelines import (
    DecoderModelFactory,
    decode_latents,
    decode_latents_tiled_full,
    decode_latents_tiled_spatial,
)
from genmo.mochi_preview.vae.latent_dist import LatentDistribution
from genmo.mochi_preview.vae.models import Encoder, add_fourier_features
from genmo.mochi_preview.vae.vae_stats import dit_latents_to_vae_latents
def reconstruct(mochi_dir, video_path):
    """Round-trip a video through the Mochi VAE and save the result.

    Encodes the input video to latents, decodes them back to frames, and
    writes the reconstruction next to the input as ``<video_path>.recon.mp4``.
    Requires a CUDA device (uses ``cuda:0``).

    Args:
        mochi_dir: Directory containing ``encoder.safetensors`` and
            ``decoder.safetensors`` checkpoints.
        video_path: Path to the input video file.
    """
    # Allow TF32 kernels: faster matmuls/convs on Ampere+ at slightly
    # reduced precision.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Build the decoder from its checkpoint for a single-GPU run.
    decoder = DecoderModelFactory(
        model_path=f"{mochi_dir}/decoder.safetensors",
    ).get_model(world_size=1, device_id=0, local_rank=0)

    # Architecture options shared across the encoder's blocks.
    encoder_opts = dict(
        prune_bottlenecks=[False, False, False, False, False],
        has_attentions=[False, True, True, True, True],
        affine=True,
        bias=True,
        input_is_conv_1x1=True,
        padding_mode="replicate",
    )
    # Create VAE encoder.
    encoder = Encoder(
        in_channels=15,
        base_channels=64,
        channel_multipliers=[1, 2, 4, 6],
        num_res_blocks=[3, 3, 4, 6, 3],
        latent_dim=12,
        temporal_reductions=[1, 2, 3],
        spatial_reductions=[2, 2, 2],
        **encoder_opts,
    )
    gpu = torch.device("cuda:0")
    encoder = encoder.to(gpu, memory_format=torch.channels_last_3d)
    encoder.load_state_dict(load_file(f"{mochi_dir}/encoder.safetensors"))
    encoder.eval()

    # Read frames as (T, H, W, C) uint8 and keep the source frame rate for
    # the output file.
    video, _, metadata = torchvision.io.read_video(video_path, output_format="THWC")
    fps = metadata["video_fps"]

    # Reshape to a batched (1, C, T, H, W) tensor.
    video = rearrange(video, "t h w c -> c t h w").unsqueeze(0)
    assert video.dtype == torch.uint8
    # Convert to float in [-1, 1] range.
    video = (video.float() / 127.5 - 1.0).to(gpu)
    video = add_fourier_features(video)
    torch.cuda.synchronize()

    # Encode to latents and decode back, timing each stage.
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        tic = time.time()
        ldist = encoder(video)
        torch.cuda.synchronize()
        print(f"Time to encode: {time.time() - tic:.2f}s")

        tic = time.time()
        # NOTE: decode_latents / decode_latents_tiled_full are drop-in
        # alternatives to the spatially tiled decode used here.
        frames = decode_latents_tiled_spatial(decoder, ldist.sample(), num_tiles_w=1, num_tiles_h=1)
        torch.cuda.synchronize()
        print(f"Time to decode: {time.time() - tic:.2f}s")

        tic = time.time()
        save_video(frames.cpu().numpy()[0], f"{video_path}.recon.mp4", fps=fps)
        print(f"Time to save: {time.time() - tic:.2f}s")
if __name__ == "__main__": | |
reconstruct() | |