"""Gradio Space: text-to-video with Wan2.1 T2V 14B (4-bit quantized) on ZeroGPU.

The Diffusers pipeline is loaded once at import time. ``generate_video`` runs
under ``@spaces.GPU`` with a time budget estimated by ``get_duration``, and a
small Gradio Blocks UI is served at the bottom of the file.
"""

# `spaces` stays the first import, as in the original: ZeroGPU expects to be
# imported before torch initialises CUDA.
import spaces

import os
import tempfile

import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.export_utils import export_to_video

# Constants
LANDSCAPE_WIDTH = 832
LANDSCAPE_HEIGHT = 480
MAX_SEED = np.iinfo(np.int32).max

FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
T2V_FIXED_FPS = 16

# Duration range (seconds) corresponding to [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL]
# at FIXED_FPS; used to bound the duration slider so it never advertises
# lengths that would be silently clamped.
MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)

# Diffusers' WanPipeline.check_inputs raises unless height and width are
# divisible by 16, so UI sliders step by this and inputs are snapped to it.
DIMENSION_MULTIPLE = 16

# Checkpoint ID
ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# Quantize the transformer and text encoder to 4-bit NF4 (bitsandbytes),
# computing in bfloat16.
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder"],
)

# Load pipeline once at import time; model CPU offload moves each component
# to the GPU only while it is being used.
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()


def get_duration(prompt, height, width, negative_prompt, duration_seconds,
                 guidance_scale, steps, seed, randomize_seed, progress=None):
    """Estimate the ZeroGPU time budget (seconds) for one ``generate_video`` call.

    Receives the same arguments as ``generate_video``. Only ``steps`` and
    ``duration_seconds`` matter: short clips (<= 2.5 s) are budgeted 18 s per
    inference step, longer ones 25 s per step.

    ``progress`` defaults to ``None`` so the estimate still works when the
    caller does not forward ``generate_video``'s optional progress argument.
    """
    seconds_per_step = 18 if duration_seconds <= 2.5 else 25
    return steps * seconds_per_step


def _snap_to_multiple(value, multiple=DIMENSION_MULTIPLE):
    """Round ``value`` to the nearest positive multiple of ``multiple`` (as int)."""
    return max(multiple, int(round(float(value) / multiple)) * multiple)


# Inference function
@spaces.GPU(duration=get_duration)
def generate_video(prompt, height, width, negative_prompt, duration_seconds,
                   guidance_scale, steps, seed, randomize_seed,
                   progress=gr.Progress(track_tqdm=True)):
    """Generate a video from ``prompt`` and return the path of the saved MP4.

    Args:
        prompt: Text prompt.
        height, width: Requested frame size in pixels; snapped to the nearest
            multiple of ``DIMENSION_MULTIPLE`` because the pipeline rejects
            other sizes.
        negative_prompt: Text to steer away from.
        duration_seconds: Requested clip length; converted to a frame count at
            ``FIXED_FPS`` and clamped to [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL].
        guidance_scale: Classifier-free guidance scale.
        steps: Number of denoising steps.
        seed: Seed used when ``randomize_seed`` is false; an empty value
            (``None``) falls back to a random seed.
        randomize_seed: Draw a fresh seed in [0, MAX_SEED) when true.
        progress: Gradio progress tracker (tracks tqdm bars inside the pipeline).

    Returns:
        Filesystem path of the exported MP4 inside a fresh temporary directory.
    """
    num_frames = int(np.clip(int(round(duration_seconds * FIXED_FPS)),
                             MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))

    if randomize_seed or seed is None:
        current_seed = int(np.random.randint(0, MAX_SEED))
    else:
        current_seed = int(seed)

    output_frames_list = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=_snap_to_multiple(height),
        width=_snap_to_multiple(width),
        num_frames=num_frames,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        generator=torch.manual_seed(current_seed),
    ).frames[0]

    # A fresh directory per call so concurrent requests never share a file.
    temp_dir = tempfile.mkdtemp()
    video_path = os.path.join(temp_dir, "t2v_output.mp4")
    export_to_video(output_frames_list, video_path, fps=T2V_FIXED_FPS)
    # Log the seed so a randomized result can be reproduced later.
    print(f"🌱 Seed used: {current_seed}")
    print(f"✅ Video saved to: {video_path}")
    return video_path  # Only return video


# Gradio UI
with gr.Blocks(css="body { max-width: 100vw; overflow-x: hidden; }") as demo:
    gr.Markdown("## 🚀 Wan2.1 T2V - Text to Video Generator (Quantized, Smart Duration)")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", lines=3, value="A futuristic cityscape with flying cars and neon lights.")
            negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=3, value="")
            height_input = gr.Slider(256, 1024, step=DIMENSION_MULTIPLE, value=512, label="Height")
            width_input = gr.Slider(256, 1024, step=DIMENSION_MULTIPLE, value=512, label="Width")
            duration_input = gr.Slider(MIN_DURATION, MAX_DURATION, value=2, step=0.1, label="Duration (seconds)")
            steps_input = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
            guidance_scale_input = gr.Slider(0.0, 20.0, step=0.5, value=7.5, label="Guidance Scale")
            seed_input = gr.Number(value=42, label="Seed (optional)")
            randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
            run_btn = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")

    # Order must match generate_video's positional parameters.
    ui_inputs = [
        prompt_input, height_input, width_input, negative_prompt_input,
        duration_input, guidance_scale_input, steps_input,
        seed_input, randomize_seed_checkbox,
    ]
    run_btn.click(fn=generate_video, inputs=ui_inputs, outputs=output_video)

# Launch
demo.launch()