# PusaV1 / demos / cli_test_ti2v_release.py
# rahul7star's picture
# Migrated from GitHub
# 96257b2 verified
#! /usr/bin/env python
import json
import os
import time
import click
import numpy as np
import torch
from genmo.lib.progress import progress_bar
from genmo.lib.utils import save_video
from genmo.mochi_preview.pipelines_ti2v_release import (
DecoderModelFactory,
EncoderModelFactory,
DitModelFactory,
MochiMultiGPUPipeline,
MochiSingleGPUPipeline,
T5ModelFactory,
linear_quadratic_schedule,
)
import torch
from torch.utils.data import Dataset, DataLoader
import random
import string
from lightning.pytorch import LightningDataModule
from genmo.mochi_preview.vae.models import Encoder, add_fourier_features
from genmo.mochi_preview.vae.latent_dist import LatentDistribution
import torchvision
from einops import rearrange
from safetensors.torch import load_file
from genmo.mochi_preview.pipelines import DecoderModelFactory, decode_latents_tiled_spatial, decode_latents, decode_latents_tiled_full
from genmo.mochi_preview.vae.vae_stats import dit_latents_to_vae_latents
# Module-level state shared by configure_model(), load_model() and
# generate_video(). The pipeline itself is created lazily by load_model().
pipeline = None
model_dir_path = None  # directory holding encoder.safetensors / decoder.safetensors
dit_path = None  # DiT checkpoint path handed to DitModelFactory
cpu_offload = False
# Counted once at import time; load_model() chooses single- vs multi-GPU from it.
num_gpus = torch.cuda.device_count()
def configure_model(model_dir_path_, dit_path_, cpu_offload_):
    """Record model locations and the offload flag for a later load_model() call.

    Args:
        model_dir_path_: Directory that contains ``encoder.safetensors`` and
            ``decoder.safetensors``.
        dit_path_: Path of the DiT checkpoint given to ``DitModelFactory``.
        cpu_offload_: Offload flag forwarded to the single-GPU pipeline.

    Only module globals are updated; a pipeline that load_model() has already
    built is not rebuilt.
    """
    global model_dir_path, dit_path, cpu_offload
    model_dir_path, dit_path, cpu_offload = model_dir_path_, dit_path_, cpu_offload_
def load_model():
    """Build the global Mochi pipeline on first call; later calls do nothing.

    Reads ``model_dir_path`` (VAE encoder/decoder weights), ``dit_path`` and
    ``cpu_offload`` as set by ``configure_model()``. With more than one visible
    GPU a ``MochiMultiGPUPipeline`` spanning all of them is built; otherwise a
    ``MochiSingleGPUPipeline``.

    Raises:
        ValueError: if CPU offload is requested in multi-GPU mode.
    """
    global pipeline
    if pipeline is not None:
        return

    print(f"Launching with {num_gpus} GPUs. If you want to force single GPU mode use CUDA_VISIBLE_DEVICES=0.")
    # One flag drives both the class choice and its kwargs so they can never
    # disagree (previously 0 GPUs picked the multi-GPU class with single-GPU kwargs).
    multi_gpu = num_gpus > 1
    if multi_gpu and cpu_offload:
        raise ValueError("CPU offload not supported in multi-GPU mode")

    # NOTE(review): DecoderModelFactory is imported twice at the top of this
    # file; the later import from genmo.mochi_preview.pipelines is the one in
    # effect here -- confirm that is intended.
    kwargs = dict(
        text_encoder_factory=T5ModelFactory(),
        dit_factory=DitModelFactory(model_path=dit_path, model_dtype="bf16"),
        decoder_factory=DecoderModelFactory(
            model_path=f"{model_dir_path}/decoder.safetensors",
        ),
        encoder_factory=EncoderModelFactory(
            model_path=f"{model_dir_path}/encoder.safetensors",
        ),
        # "tiled_full" was the previously used alternative.
        decode_type="tiled_spatial",
    )
    if multi_gpu:
        pipeline_cls = MochiMultiGPUPipeline
        kwargs["world_size"] = num_gpus
    else:
        pipeline_cls = MochiSingleGPUPipeline
        kwargs["cpu_offload"] = cpu_offload
    pipeline = pipeline_cls(**kwargs)
# Default root for generated videos; override per call with ``results_base_dir``.
_DEFAULT_RESULTS_BASE_DIR = "/home/dyvm6xra/dyvm6xrauser02/raphael/mochi-1-preview/models/video_test_demos_results"


def _results_dir_for(base_dir, checkpoint_path, input_image, num_inference_steps, noise_multiplier):
    """Return the per-run output directory under ``base_dir``.

    The name combines the checkpoint's parent directory, the checkpoint file
    name truncated before any ``"train_loss"`` substring, the conditioning
    frame position (empty for text-only runs), the step count and the noise
    multiplier.
    """
    # os.path instead of split('/')[-2]: a bare file name no longer raises IndexError.
    model_name = os.path.basename(os.path.dirname(checkpoint_path))
    checkpoint_name = os.path.basename(checkpoint_path).split("train_loss")[0]
    cond_position = input_image["cond_position"] if input_image is not None else ""
    return os.path.join(
        base_dir,
        f"{model_name}_{checkpoint_name}_github_user_demo_{cond_position}pos"
        f"_{num_inference_steps}steps_crop_{noise_multiplier}sigma",
    )


def _json_safe_args(args, input_image):
    """Return a copy of the pipeline ``args`` that ``json.dump`` can serialize.

    The conditioning image tensor is replaced by a placeholder string, the
    schedules are converted with ``.tolist()`` when they are array-like (torch
    tensors or numpy arrays), and non-string prompts are replaced by a
    placeholder or their ``str()``.
    """
    json_args = dict(args)
    if "input_image" in json_args:
        json_args["input_image"] = None
    if "condition_image" in json_args:
        json_args["condition_image"] = "Image tensor (removed for JSON)"
    if isinstance(input_image, dict):
        json_args["input_filename"] = input_image.get("filename", None)
        if "cond_position" in input_image:
            json_args["condition_frame_idx"] = input_image["cond_position"]
    for key in ("sigma_schedule", "cfg_schedule"):
        if hasattr(json_args[key], "tolist"):
            json_args[key] = json_args[key].tolist()
    prompt_placeholders = (
        ("prompt", "Tensor prompt (converted to string for JSON)"),
        ("negative_prompt", "Tensor negative prompt (converted to string for JSON)"),
    )
    for key, placeholder in prompt_placeholders:
        value = json_args[key]
        if value is None or isinstance(value, str):
            continue
        json_args[key] = placeholder if hasattr(value, "tolist") else str(value)
    return json_args


def generate_video(
    prompt,
    negative_prompt,
    width,
    height,
    num_frames,
    seed,
    cfg_scale,
    num_inference_steps,
    data_path,
    input_image=None,
    noise_multiplier=0,
    results_base_dir=None,
):
    """Run the pipeline once, save the video plus a JSON record, return the video path.

    Args:
        prompt, negative_prompt: Text prompts forwarded to the pipeline.
        width, height, num_frames: Output video size.
        seed: Sampling seed.
        cfg_scale: Classifier-free guidance scale, used for every step.
        num_inference_steps: Number of denoising steps.
        data_path: Forwarded to the pipeline unchanged.
        input_image: Optional dict with ``"tensor"`` (conditioning latent),
            ``"cond_position"`` (frame index) and optionally ``"filename"``.
        noise_multiplier: Forwarded to the pipeline; also part of the output
            directory name.
        results_base_dir: Root directory for results; defaults to
            ``_DEFAULT_RESULTS_BASE_DIR``.

    Returns:
        Path of the written ``.mp4``. A ``.json`` with the run arguments is
        written next to it.
    """
    from datetime import datetime

    load_model()

    # sigma_schedule: num_inference_steps + 1 floats, decreasing from 1.0 to 0.0.
    sigma_schedule = linear_quadratic_schedule(num_inference_steps, 0.025)
    # Constant guidance at every step; varying schedules such as
    # [5.0] * (n // 2) + [4.5] * (n // 2) are also possible.
    cfg_schedule = [cfg_scale] * num_inference_steps

    args = {
        "height": height,
        "width": width,
        "num_frames": num_frames,
        "sigma_schedule": sigma_schedule,
        "cfg_schedule": cfg_schedule,
        "num_inference_steps": num_inference_steps,
        # Batched CFG needs flash attention and is only worth it with lots of
        # GPU memory, so it stays off here.
        "batch_cfg": False,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "seed": seed,
        "data_path": data_path,
        "noise_multiplier": noise_multiplier,
    }
    if input_image is not None:
        args["condition_image"] = input_image["tensor"]
        args["condition_frame_idx"] = input_image["cond_position"]

    with progress_bar(type="tqdm"):
        final_frames = pipeline(**args)
    final_frames = final_frames[0]
    assert isinstance(final_frames, np.ndarray)
    assert final_frames.dtype == np.float32

    base_dir = _DEFAULT_RESULTS_BASE_DIR if results_base_dir is None else results_base_dir
    results_dir = _results_dir_for(base_dir, dit_path, input_image, num_inference_steps, noise_multiplier)
    os.makedirs(results_dir, exist_ok=True)

    filename_prefix = ""
    if isinstance(input_image, dict) and "filename" in input_image:
        # splitext keeps dotted stems ("clip.v2.png" -> "clip.v2") intact.
        stem = os.path.splitext(os.path.basename(input_image["filename"]))[0]
        filename_prefix = f"{stem}_"
    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(results_dir, f"{filename_prefix}{timestamp_str}.mp4")
    save_video(final_frames, output_path)

    json_path = os.path.splitext(output_path)[0] + ".json"
    with open(json_path, "w") as f:  # closed deterministically (was leaked before)
        json.dump(_json_safe_args(args, input_image), f, indent=4)
    return output_path
from textwrap import dedent
# File extensions accepted as conditioning images (matched case-insensitively).
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def _build_encoder(model_dir, device):
    """Build the VAE encoder on ``device``, load ``{model_dir}/encoder.safetensors``, set eval mode."""
    config = dict(
        prune_bottlenecks=[False, False, False, False, False],
        has_attentions=[False, True, True, True, True],
        affine=True,
        bias=True,
        input_is_conv_1x1=True,
        padding_mode="replicate",
    )
    encoder = Encoder(
        in_channels=15,
        base_channels=64,
        channel_multipliers=[1, 2, 4, 6],
        num_res_blocks=[3, 3, 4, 6, 3],
        latent_dim=12,
        temporal_reductions=[1, 2, 3],
        spatial_reductions=[2, 2, 2],
        **config,
    )
    encoder = encoder.to(device, memory_format=torch.channels_last_3d)
    encoder.load_state_dict(load_file(f"{model_dir}/encoder.safetensors"))
    encoder.eval()
    return encoder


def _center_crop_to_ratio(image, width, height):
    """Center-crop a PIL image to the ``width / height`` aspect ratio (no resize)."""
    target_ratio = width / height
    if image.width / image.height > target_ratio:
        # Wider than the target: trim equal amounts from left and right.
        new_width = int(image.height * target_ratio)
        x1 = (image.width - new_width) // 2
        return image.crop((x1, 0, x1 + new_width, image.height))
    # Taller than (or equal to) the target: trim top and bottom.
    new_height = int(image.width / target_ratio)
    y1 = (image.height - new_height) // 2
    return image.crop((0, y1, image.width, y1 + new_height))


def _encode_image(encoder, image_path, width, height, device):
    """Crop, resize and VAE-encode one image file; return the sampled latent.

    The encoder is moved to ``device`` for encoding and parked on the CPU
    afterwards, so the generation pipeline gets the GPU memory back.
    """
    from PIL import Image
    import torchvision.transforms as transforms

    with Image.open(image_path) as img:
        # Force 3 channels: RGBA / grayscale / palette files would otherwise
        # change the channel count. The encoder is built with in_channels=15,
        # presumably 3 RGB channels plus Fourier features -- confirm.
        image = _center_crop_to_ratio(img.convert("RGB"), width, height)
    transform = transforms.Compose([
        transforms.Resize((height, width)),
        transforms.ToTensor(),
    ])
    # ToTensor yields [C, H, W] in [0, 1]; rescale to [-1, 1] and add a
    # singleton axis at dim 2 (presumably time) plus a batch axis -> [1, C, 1, H, W].
    image_tensor = (transform(image) * 2 - 1).unsqueeze(1).unsqueeze(0)
    print("image_tensor.shape", image_tensor.shape)
    image_tensor = add_fourier_features(image_tensor.to(device))
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        encoder.to(device)
        ldist = encoder(image_tensor)
        latent = ldist.sample()
        del ldist
    encoder.to("cpu")
    torch.cuda.empty_cache()
    return latent


def _prompt_for_image(image_path, prompt_dir, default_prompt):
    """Return the text of ``<prompt_dir>/<image stem>.txt`` if it exists, else ``default_prompt``."""
    if prompt_dir is not None:
        stem = os.path.splitext(os.path.basename(image_path))[0]
        prompt_file = os.path.join(prompt_dir, f"{stem}.txt")
        if os.path.exists(prompt_file):
            with open(prompt_file, "r", encoding="utf-8") as f:
                file_prompt = f.read().strip()
            click.echo(f"Using prompt from file: {file_prompt}")
            return file_prompt
    click.echo(f"Warning: Prompt file not found for {image_path}. Using default prompt.")
    return default_prompt


@click.command()
@click.option("--prompt", default="A man is playing the basketball", help="Prompt for video generation.")
@click.option("--negative_prompt", default="", help="Negative prompt for video generation.")
@click.option("--width", default=848, type=int, help="Width of the video.")
@click.option("--height", default=480, type=int, help="Height of the video.")
@click.option("--num_frames", default=163, type=int, help="Number of frames.")
@click.option("--seed", default=1710977262, type=int, help="Random seed.")
@click.option("--cfg_scale", default=4.5, type=float, help="CFG Scale.")
@click.option("--num_steps", default=64, type=int, help="Number of inference steps.")
@click.option("--model_dir", required=True, help="Path to the model directory.")
@click.option("--dit_path", required=True, help="Path to the dit model directory.")
@click.option("--cpu_offload", is_flag=True, help="Whether to offload model to CPU")
@click.option("--data_path", required=True, default="/home/dyvm6xra/dyvm6xrauser02/data/vidgen1m", help="Path to the data directory.")
@click.option("--image_dir", default=None, help="Path to image or directory of images for conditioning.")
@click.option("--prompt_dir", default=None, help="Path to directory containing prompt text files.")
@click.option("--cond_position", default=0, type=int, help="Frame position to place the conditioning image, from 0 to 27.")
@click.option("--noise_multiplier", default=0, type=float, help="Noise multiplier for noise on the conditioning image.")
def generate_cli(
    prompt, negative_prompt, width, height, num_frames, seed, cfg_scale, num_steps, model_dir,
    dit_path, cpu_offload, data_path, image_dir, prompt_dir, cond_position, noise_multiplier
):
    """Generate videos from text, or from text plus a conditioning image.

    Modes:
      * no --image_dir: one video for --prompt, or one per ``*.txt`` file in
        --prompt_dir (processed in sorted order);
      * --image_dir is a .jpg/.jpeg/.png file: one video conditioned on it;
      * --image_dir is a directory: one video per image (sorted), each using
        ``<prompt_dir>/<image stem>.txt`` when present, otherwise --prompt.
    """
    configure_model(model_dir, dit_path, cpu_offload)

    def _generate(video_prompt, input_image=None):
        with torch.inference_mode():
            output = generate_video(
                video_prompt,
                negative_prompt,
                width,
                height,
                num_frames,
                seed,
                cfg_scale,
                num_steps,
                data_path,
                input_image,
                noise_multiplier=noise_multiplier,
            )
        click.echo(f"Video generated at: {output}")

    # Case 1: text-to-video.
    if image_dir is None:
        if prompt_dir is not None and os.path.isdir(prompt_dir):
            # Sorted for a deterministic order; every prompt file is processed.
            prompt_files = sorted(f for f in os.listdir(prompt_dir) if f.lower().endswith(".txt"))
            if not prompt_files:
                click.echo(f"No prompt files found in {prompt_dir}")
                return
            click.echo(f"Found {len(prompt_files)} prompt files to process")
            for i, prompt_file in enumerate(prompt_files, start=1):
                file_path = os.path.join(prompt_dir, prompt_file)
                click.echo(f"Processing prompt file {i}/{len(prompt_files)}: {file_path}")
                with open(file_path, "r", encoding="utf-8") as f:
                    file_prompt = f.read().strip()
                click.echo(f"Using prompt: {file_prompt}")
                _generate(file_prompt)
        else:
            click.echo("Running text-to-video generation with provided prompt")
            _generate(prompt)
        return

    # Resolve the image list before paying for the encoder load.
    single_image = os.path.isfile(image_dir) and image_dir.lower().endswith(_IMAGE_EXTENSIONS)
    if single_image:
        # Case 2: a single conditioning image, always paired with --prompt.
        click.echo(f"Processing single image: {image_dir}")
        image_paths = [image_dir]
    elif os.path.isdir(image_dir):
        # Case 3: a directory of conditioning images.
        image_paths = sorted(
            os.path.join(image_dir, f)
            for f in os.listdir(image_dir)
            if f.lower().endswith(_IMAGE_EXTENSIONS)
        )
        if not image_paths:
            click.echo(f"No image files found in {image_dir}")
            return
        click.echo(f"Found {len(image_paths)} image files to process")
    else:
        raise click.BadParameter(
            f"{image_dir} is neither a .jpg/.jpeg/.png file nor a directory",
            param_hint="--image_dir",
        )

    device = torch.device("cuda:0")
    encoder = _build_encoder(model_dir, device)
    for i, image_path in enumerate(image_paths, start=1):
        if not single_image:
            click.echo(f"Processing file {i}/{len(image_paths)}: {image_path}")
        latent = _encode_image(encoder, image_path, width, height, device)
        # Resolved per image so a missing prompt file falls back to --prompt
        # instead of reusing the previous image's prompt.
        video_prompt = prompt if single_image else _prompt_for_image(image_path, prompt_dir, prompt)
        _generate(
            video_prompt,
            {
                "tensor": latent,
                "filename": os.path.basename(image_path),
                "cond_position": cond_position,
            },
        )
if __name__ == "__main__":
    # Script entry point: click's Command.main parses sys.argv and runs the
    # command (calling the command object directly delegates to main()).
    generate_cli.main()