Spaces:
Runtime error
Runtime error
File size: 3,540 Bytes
d4a8a38 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import os
import torch
# from diffusers import MochiPipeline
from pipeline_mochi_rgba import MochiPipeline
from diffusers.utils import export_to_video
import argparse
from rgba_utils import *
import numpy as np
def main(args):
    """Generate an RGB + alpha video pair from a text prompt and save it as MP4 files.

    The function loads the Mochi pipeline, patches its transformer for RGBA
    inference, optionally loads extra processor weights, runs the pipeline in
    latent output mode, decodes the RGB and alpha halves separately and writes
    three videos into ``args.output_path``:

    * ``rgb.mp4``          - RGB frames multiplied by the pooled alpha
    * ``alpha.mp4``        - pooled alpha replicated to three channels
    * ``original_rgb.mp4`` - decoded RGB frames without alpha applied

    Args:
        args: Namespace-like object with the attributes ``prompt``,
            ``guidance_scale``, ``num_inference_steps``, ``height``, ``width``,
            ``num_frames``, ``lora_path``, ``seed``, ``output_path`` and
            ``fps``. ``model_path`` is optional; when it is missing or empty
            the default checkpoint ``"genmo/mochi-1-preview"`` is used.
    """
    # 1. load pipeline
    # Honour --model_path (previously parsed but ignored); fall back to the
    # original hard-coded checkpoint so callers without the attribute still work.
    model_path = getattr(args, "model_path", None) or "genmo/mochi-1-preview"
    pipe = MochiPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")
    pipe.enable_vae_tiling()

    # 2. define prompt and arguments
    pipeline_args = {
        "prompt": args.prompt,
        "guidance_scale": args.guidance_scale,
        "num_inference_steps": args.num_inference_steps,
        "height": args.height,
        "width": args.width,
        "num_frames": args.num_frames,
        "max_sequence_length": 256,
        # Ask for latents so RGB and alpha can be split before decoding.
        "output_type": "latent",
    }

    # 3. prepare rgbx utils
    prepare_for_rgba_inference(
        pipe.transformer,
        device="cuda",
        dtype=torch.bfloat16,
    )

    if args.lora_path is not None:
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(args.lora_path, map_location="cpu")
        processor_state_dict = checkpoint["state_dict"]
        load_processor_state_dict(pipe.transformer, processor_state_dict)

    # 4. inference
    # Compare against None explicitly: a seed of 0 is valid and must not be
    # treated as "no seed".
    generator = torch.manual_seed(args.seed) if args.seed is not None else None
    frames_latents = pipe(**pipeline_args, generator=generator).frames

    # The pipeline output holds RGB and alpha latents stacked along dim 2;
    # the first half is RGB, the second half is alpha.
    frames_latents_rgb, frames_latents_alpha = frames_latents.chunk(2, dim=2)

    frames_rgb = decode_latents(pipe, frames_latents_rgb)
    frames_alpha = decode_latents(pipe, frames_latents_alpha)

    # Collapse the last axis of the decoded alpha to its maximum and replicate
    # it three times (presumably the channel axis - TODO confirm against
    # decode_latents) so it can be multiplied with and exported like RGB.
    pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True)
    frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1)
    premultiplied_rgb = frames_rgb * frames_alpha_pooled

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.output_path, exist_ok=True)

    # Only the first item of the batch is exported.
    export_to_video(premultiplied_rgb[0], os.path.join(args.output_path, "rgb.mp4"), fps=args.fps)
    export_to_video(frames_alpha_pooled[0], os.path.join(args.output_path, "alpha.mp4"), fps=args.fps)
    export_to_video(frames_rgb[0], os.path.join(args.output_path, "original_rgb.mp4"), fps=args.fps)


def build_arg_parser():
    """Build the command-line parser for the RGBA video generation script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--prompt`` (required),
        ``--lora_path``, ``--model_path``, ``--output_path``,
        ``--guidance_scale``, ``--num_inference_steps``, ``--num_frames``,
        ``--width``, ``--height``, ``--fps`` and ``--seed``.
    """
    parser = argparse.ArgumentParser(description="Generate a video from a text prompt")
    parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
    parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
    parser.add_argument(
        "--model_path", type=str, default="genmo/mochi-1-preview", help="Path of the pre-trained model to use"
    )
    parser.add_argument(
        "--output_path", type=str, default="./output", help="Directory where the generated videos are saved"
    )
    parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
    parser.add_argument("--num_inference_steps", type=int, default=64, help="Number of inference steps")
    # The four options below previously shared a copy-pasted, incorrect help text.
    parser.add_argument("--num_frames", type=int, default=79, help="Number of frames to generate")
    parser.add_argument("--width", type=int, default=848, help="Width of the generated video in pixels")
    parser.add_argument("--height", type=int, default=480, help="Height of the generated video in pixels")
    parser.add_argument("--fps", type=int, default=30, help="Frames per second of the exported videos")
    parser.add_argument("--seed", type=int, default=None, help="The seed for reproducibility")
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())
|