import spaces
import gradio as gr
import torch
import numpy as np
import os
import tempfile
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.export_utils import export_to_video
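
# `spaces` is the Hugging Face package that provides the ZeroGPU
# @spaces.GPU decorator applied to generate_video below.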

# Constants
LANDSCAPE_WIDTH = 832
LANDSCAPE_HEIGHT = 480
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
T2V_FIXED_FPS = 16
MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
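
# Example: a 3.0 s request maps to round(3.0 * 16) = 48 frames, within the
# 8-81 frame range the model supports.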

# Checkpoint ID
ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# Quantization config
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder"],
)
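# NF4 4-bit quantization via bitsandbytes stores the transformer and text
# encoder weights at roughly a quarter of their bf16 size, while computation
# still runs in bf16.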

# Load pipeline
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
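# Model CPU offload keeps each sub-model on CPU and moves it to the GPU only
# while it runs, trading speed for VRAM. With enough VRAM you could instead
# skip offload and call pipe.to("cuda").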

# Duration estimator
def get_duration(prompt, height, width, negative_prompt, duration_seconds,
                 guidance_scale, steps, seed, randomize_seed, progress):
    return steps * 18 if duration_seconds <= 2.5 else steps * 25

# Inference function. @spaces.GPU requests a ZeroGPU slot per call; passing
# get_duration lets the requested time budget scale with steps and clip length.
# (Without this decorator, `import spaces` and get_duration would be unused.)
@spaces.GPU(duration=get_duration)
def generate_video(prompt, height, width, negative_prompt, duration_seconds,
                   guidance_scale, steps, seed, randomize_seed,
                   progress=gr.Progress(track_tqdm=True)):
    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)),
                         MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
    current_seed = np.random.randint(0, MAX_SEED) if randomize_seed else int(seed)
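    # Randomizing draws a fresh 32-bit seed; a fixed seed keeps runs reproducible.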
    output_frames_list = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=int(height),
        width=int(width),
        num_frames=num_frames,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        generator=torch.manual_seed(current_seed),
    ).frames[0]
    temp_dir = tempfile.mkdtemp()
    video_path = os.path.join(temp_dir, "t2v_output.mp4")
    export_to_video(output_frames_list, video_path, fps=T2V_FIXED_FPS)
    print(f"Video saved to: {video_path}")
    return video_path  # Only the video path is returned
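
# Standalone usage sketch (hypothetical values; outside a Space the decorator
# is a no-op and the progress argument falls back to its default):
#   generate_video("A red fox running through snow", 480, 832, "",
#                  3.0, 7.5, 20, 42, False)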

# Gradio UI
with gr.Blocks(css="body { max-width: 100vw; overflow-x: hidden; }") as demo:
    gr.Markdown("## Wan2.1 T2V - Text-to-Video Generator (Quantized, Smart Duration)")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", lines=3, value="A futuristic cityscape with flying cars and neon lights.")
            negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=3, value="")
            height_input = gr.Slider(256, 1024, step=8, value=512, label="Height")
            width_input = gr.Slider(256, 1024, step=8, value=512, label="Width")
            duration_input = gr.Slider(1, 10, value=2, step=0.1, label="Duration (seconds)")
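            # Requests above MAX_DURATION (~5.1 s at 16 fps) are clipped to 81 frames in generate_video.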
            steps_input = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
            guidance_scale_input = gr.Slider(0.0, 20.0, step=0.5, value=7.5, label="Guidance Scale")
            seed_input = gr.Number(value=42, label="Seed (optional)")
            randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
            run_btn = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")
    ui_inputs = [
        prompt_input, height_input, width_input, negative_prompt_input,
        duration_input, guidance_scale_input, steps_input, seed_input,
        randomize_seed_checkbox
    ]
    run_btn.click(fn=generate_video, inputs=ui_inputs, outputs=output_video)
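    # Gradio passes ui_inputs positionally, so their order must match generate_video's signature.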

# Launch
demo.launch()