import gradio as gr
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import export_to_video

# Checkpoint ID
ckpt_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
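# Note: the 14B model is heavy even at 4-bit; the smaller
# "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" checkpoint is a drop-in swap for testing.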

# Configure quantization (bitsandbytes 4-bit)
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16
    },
    components_to_quantize=["transformer", "text_encoder"]
)
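
# NF4 stores weights as 4-bit "normal floats" and dequantizes to bfloat16 at
# compute time, cutting weight memory for the quantized components roughly 4x.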

# Load pipeline with quantization
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16
)

# Optimize memory and performance; enable_model_cpu_offload manages device
# placement itself, so the pipeline is not moved to CUDA up front
pipe.enable_model_cpu_offload()
torch._dynamo.config.recompile_limit = 1000
torch._dynamo.config.capture_dynamic_output_shape_ops = True
pipe.transformer.compile()
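
# The first call pays a one-time torch.compile tracing/warmup cost; later
# calls with the same shapes run at compiled speed.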

# Estimates a GPU-time budget (in seconds) from step count and clip length;
# intended for ZeroGPU's @spaces.GPU(duration=...) hook (see sketch below)
def get_duration(prompt, height, width, 
                 negative_prompt, duration_seconds,
                 guidance_scale, steps,
                 seed, randomize_seed):
    if steps > 4 and duration_seconds > 2:
        return 90
    elif steps > 4 or duration_seconds > 2:
        return 75
    else:
        return 60
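
# On a ZeroGPU Space this estimator would be attached via the spaces decorator
# (sketch only; the decorator is deliberately omitted here, and the duration
# callable must accept the same arguments as the function it budgets for):
#
#   import spaces
#
#   @spaces.GPU(duration=get_duration)
#   def generate_video(prompt, height, width, negative_prompt, duration_seconds,
#                      guidance_scale, steps, seed, randomize_seed):
#       ...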

# Gradio inference function (no @spaces.GPU decorator, to avoid the
# progress ContextVar error)
def generate_video(prompt, seed, steps, duration_seconds):
    # gr.Number can return 0, which is a valid seed, so test against None
    generator = torch.Generator().manual_seed(int(seed)) if seed is not None else None
    fps = 8
    # Wan models prefer frame counts of the form 4k + 1 (e.g. 81); the
    # pipeline rounds other values to the nearest valid count with a warning
    num_frames = int(duration_seconds * fps) if duration_seconds else 16

    video_frames = pipe(
        prompt=prompt,
        num_frames=num_frames,
        generator=generator,
        num_inference_steps=steps
    ).frames[0]

    # Export as MP4: gr.Video plays MP4 in the browser, whereas GIF is not
    # a supported playback format for the Video component
    out_path = "output.mp4"
    export_to_video(video_frames, out_path, fps=fps)
    return out_path

# Build Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🚀 Wan2.1 T2V - Text to Video Generator (Quantized, Dynamic Duration)")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", lines=3, value="A futuristic cityscape with flying cars and neon lights.")
            seed_input = gr.Number(value=42, label="Seed (optional)", precision=0)
            steps_input = gr.Slider(1, 50, value=20, step=1, label="Inference Steps")
            duration_input = gr.Slider(1, 10, value=2, step=1, label="Video Duration (seconds)")
            run_btn = gr.Button("Generate Video")
        with gr.Column():
            output_video = gr.Video(label="Generated Video")

    run_btn.click(fn=generate_video, inputs=[prompt_input, seed_input, steps_input, duration_input], outputs=output_video)
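
# For long generations, enabling Gradio's request queue avoids request
# timeouts and serializes jobs (a sketch; tune max_size as needed):
# demo.queue(max_size=10)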

# Launch demo
demo.launch()