# Wan22-Light / app.py
import spaces
import gradio as gr
import torch
import numpy as np
import os
import tempfile
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.export_utils import export_to_video
# Constants
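# Nominal 480P landscape resolution for Wan2.1 T2V; defined for reference but not used by the UI below.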
LANDSCAPE_WIDTH = 832
LANDSCAPE_HEIGHT = 480
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
T2V_FIXED_FPS = 16
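# Supported clip length in seconds, derived from the frame limits above at 16 fps (about 0.5 s to 5.1 s).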
MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
# Checkpoint ID
ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
# Quantization config
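# Load the transformer and text encoder as 4-bit NF4 weights (bitsandbytes) with bfloat16 compute
# to shrink the 14B model's memory footprint.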
quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16,
},
components_to_quantize=["transformer", "text_encoder"],
)
# Load pipeline
pipe = DiffusionPipeline.from_pretrained(
ckpt_id,
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
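# Keep idle sub-models on the CPU and move each one to the GPU only while it runs,
# trading some speed for a lower peak VRAM requirement.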
pipe.enable_model_cpu_offload()
# Duration estimator passed to @spaces.GPU: returns the GPU time to request (in seconds) for a call,
# scaling a rough per-step cost by the number of inference steps, with a larger budget for longer clips.
def get_duration(prompt, height, width, negative_prompt, duration_seconds, guidance_scale, steps, seed, randomize_seed, progress):
return steps * 18 if duration_seconds <= 2.5 else steps * 25
# Inference function
@spaces.GPU(duration=get_duration)
def generate_video(prompt, height, width, negative_prompt, duration_seconds,
guidance_scale, steps, seed, randomize_seed,
progress=gr.Progress(track_tqdm=True)):
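    # Convert the requested duration to a frame count, clamped to the model's supported range.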
num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)),
MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
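    # Draw a fresh seed when randomization is enabled; otherwise honor the user-supplied value.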
current_seed = np.random.randint(0, MAX_SEED) if randomize_seed else int(seed)
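    # Run the text-to-video pipeline; .frames[0] holds the frame list for the single generated video.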
output_frames_list = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=int(height),
width=int(width),
num_frames=num_frames,
guidance_scale=float(guidance_scale),
num_inference_steps=int(steps),
generator=torch.manual_seed(current_seed),
).frames[0]
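    # Write the frames to an MP4 in a fresh temporary directory and return its path to the Video component.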
temp_dir = tempfile.mkdtemp()
video_path = os.path.join(temp_dir, "t2v_output.mp4")
export_to_video(output_frames_list, video_path, fps=T2V_FIXED_FPS)
print(f"βœ… Video saved to: {video_path}")
return video_path # Only return video
# Gradio UI
with gr.Blocks(css="body { max-width: 100vw; overflow-x: hidden; }") as demo:
gr.Markdown("## πŸš€ Wan2.1 T2V - Text to Video Generator (Quantized, Smart Duration)")
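    # Two-column layout: generation controls on the left, the rendered video on the right.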
with gr.Row():
with gr.Column():
prompt_input = gr.Textbox(label="Prompt", lines=3, value="A futuristic cityscape with flying cars and neon lights.")
negative_prompt_input = gr.Textbox(label="Negative Prompt", lines=3, value="")
height_input = gr.Slider(256, 1024, step=8, value=512, label="Height")
width_input = gr.Slider(256, 1024, step=8, value=512, label="Width")
            duration_input = gr.Slider(MIN_DURATION, MAX_DURATION, value=2, step=0.1, label="Duration (seconds)")
steps_input = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
guidance_scale_input = gr.Slider(0.0, 20.0, step=0.5, value=7.5, label="Guidance Scale")
seed_input = gr.Number(value=42, label="Seed (optional)")
randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
run_btn = gr.Button("Generate Video")
with gr.Column():
output_video = gr.Video(label="Generated Video")
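    # Input order must match the positional parameters of generate_video.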
ui_inputs = [
prompt_input, height_input, width_input, negative_prompt_input,
duration_input, guidance_scale_input, steps_input, seed_input,
randomize_seed_checkbox
]
run_btn.click(fn=generate_video, inputs=ui_inputs, outputs=output_video)
# Launch
demo.launch()
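# Note: under Hugging Face Spaces ZeroGPU, @spaces.GPU requests a GPU slot per call for roughly
# the number of seconds returned by get_duration; running this script elsewhere assumes the
# `spaces` package (or a stub of it) is importable.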