# Hugging Face Spaces page header captured by the scrape ("Spaces: Paused" status);
# kept as a comment so the module remains valid Python.
import gradio as gr
import torch
import imageio

from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Checkpoint ID for the Wan2.1 14B text-to-video model.
ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# Configure quantization (bitsandbytes 4-bit, NF4) for the two largest
# components so the 14B pipeline fits in typical GPU memory.
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder"],
)

# Load pipeline with quantization.
# NOTE: do NOT call .to("cuda") here — enable_model_cpu_offload() below
# manages device placement itself (modules are moved to GPU on demand).
# Moving the whole 14B pipeline to CUDA first defeats the offload and can
# OOM; recent diffusers versions also warn/error when .to() is mixed with
# offload hooks.
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

# Optimize memory and performance: stream submodules to GPU only while
# they run, then compile the transformer for faster denoising steps.
pipe.enable_model_cpu_offload()
torch._dynamo.config.recompile_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True
pipe.transformer.compile()
# Duration function: GPU-time budget (seconds) for a generation request.
def get_duration(prompt, height, width,
                 negative_prompt, duration_seconds,
                 guidance_scale, steps,
                 seed, randomize_seed):
    """Return the allotted runtime in seconds for one generation.

    The signature mirrors the inference callback (as expected by
    ``@spaces.GPU(duration=...)``); only ``steps`` and
    ``duration_seconds`` influence the budget.
    """
    heavy_sampling = steps > 4
    long_clip = duration_seconds > 2
    if heavy_sampling and long_clip:
        return 90
    if heavy_sampling or long_clip:
        return 75
    return 60
# Gradio inference function (no @spaces.GPU decorator) to avoid progress ContextVar error
def generate_video(prompt, seed, steps, duration_seconds):
    """Generate a short video for *prompt* and save it as a GIF.

    Parameters
    ----------
    prompt : str
        Text description of the video to generate.
    seed : float | int | None
        RNG seed from the UI. ``gr.Number`` delivers a float, so it is
        cast to int. ``None`` means non-deterministic.
    steps : float | int
        Number of denoising steps (slider value; cast to int).
    duration_seconds : float | int
        Requested clip length; converted to frames at 8 fps.

    Returns
    -------
    str
        Path of the written GIF ("output.gif").
    """
    # `if seed` would silently discard seed == 0; an explicit None check
    # keeps 0 as a valid, reproducible seed. int() because the UI sends floats.
    generator = torch.manual_seed(int(seed)) if seed is not None else None
    fps = 8
    # Fall back to 16 frames (2 s) when no duration is given.
    num_frames = int(duration_seconds * fps) if duration_seconds else 16
    video_frames = pipe(
        prompt=prompt,
        num_frames=num_frames,
        generator=generator,
        num_inference_steps=int(steps),
    ).frames[0]
    out_path = "output.gif"
    imageio.mimsave(out_path, video_frames, fps=fps)
    return out_path
# Build Gradio UI: prompt/controls on the left, rendered clip on the right.
with gr.Blocks() as demo:
    gr.Markdown("## π Wan2.1 T2V - Text to Video Generator (Quantized, Dynamic Duration)")

    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(
                label="Prompt",
                lines=3,
                value="A futuristic cityscape with flying cars and neon lights.",
            )
            seed_box = gr.Number(value=42, label="Seed (optional)")
            steps_slider = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
            duration_slider = gr.Slider(1, 10, value=2, step=1, label="Video Duration (seconds)")
            generate_btn = gr.Button("Generate Video")
        with gr.Column():
            result_video = gr.Video(label="Generated Video")

    generate_btn.click(
        fn=generate_video,
        inputs=[prompt_box, seed_box, steps_slider, duration_slider],
        outputs=result_video,
    )

# Launch demo
demo.launch()