File size: 3,345 Bytes
4e424ea
0d761fa
156d055
4e424ea
 
156d055
 
 
 
a847318
4e424ea
a847318
4e424ea
 
156d055
 
 
 
 
 
 
 
 
42e8ce7
621b71d
af68d8f
a847318
 
8b802a2
156d055
 
8b802a2
 
 
156d055
a847318
156d055
 
 
 
 
8b802a2
156d055
 
40ee691
d794d5f
 
156d055
 
 
 
 
 
 
 
a847318
156d055
 
4e424ea
a847318
4e424ea
 
11d8d99
a847318
 
 
 
 
e86049c
 
 
 
 
 
 
 
a847318
 
1986a1b
 
 
 
af68d8f
1986a1b
 
 
af68d8f
 
 
 
1986a1b
 
af68d8f
 
 
 
 
 
 
a847318
1986a1b
 
4e424ea
 
a847318
af68d8f
a847318
4e424ea
 
1986a1b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import os
import torch
from huggingface_hub import snapshot_download

import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS
from wan.utils.utils import cache_video

# Local directory for the Wan2.1 text-to-video 1.3B checkpoint. Both the
# download step and the pipeline constructor read it, so they cannot drift.
_CKPT_DIR = "./Wan2.1-T2V-1.3B"

# Fetch the checkpoint from the Hugging Face Hub into _CKPT_DIR.
snapshot_download(
    local_dir=_CKPT_DIR,
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
)

# Build the text-to-video pipeline from the 't2v-1.3B' config.
# device_id=0 presumably selects the first GPU, and t5_cpu=True presumably keeps
# the T5 text encoder on the CPU to save VRAM (confirm against wan.WanT2V).
cfg = WAN_CONFIGS['t2v-1.3B']
wan_t2v = wan.WanT2V(
    checkpoint_dir=_CKPT_DIR,
    config=cfg,
    t5_cpu=True,
    device_id=0,
)

# Resolution choices offered in the UI, as "W*H" strings.
SIZE_OPTIONS = ["480*832", "832*480"]
# Frame-count choices: 10, 20, ..., 60, then 81.
FRAME_NUM_OPTIONS = [*range(10, 70, 10), 81]
# Sampling-step choices: 5, 10, ..., 30, then 40 and 50.
SAMPLING_STEPS_OPTIONS = [*range(5, 35, 5), 40, 50]


def infer(prompt, video_size, frame_num, sampling_steps, progress=gr.Progress()):
    """Generate a video for ``prompt`` and save it as an MP4 file.

    Args:
        prompt: Text description of the desired video.
        video_size: Resolution string of the form "W*H" (e.g. "480*832"). It is
            parsed into the (width, height) pair passed to the pipeline.
        frame_num: Requested number of frames; coerced to int.
        sampling_steps: Number of sampling steps; coerced to int.
        progress: Gradio progress tracker, also forwarded to the pipeline.

    Returns:
        The path of the saved video file ("generated_video.mp4").

    Raises:
        gr.Error: If ``prompt`` is None, empty, or whitespace only.
    """
    # Reject blank prompts up front rather than spending a full generation run.
    if prompt is None or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    # Slider values may arrive as floats; pass integral counts to the pipeline.
    # NOTE(review): Wan checkpoints are commonly run with 4n+1 frames; other
    # values may yield a slightly different frame count. Confirm.
    frame_num = int(frame_num)
    sampling_steps = int(sampling_steps)

    # "W*H" -> (W, H), passed to the pipeline as `size`.
    width, height = map(int, video_size.split('*'))

    # Use an explicit None check: the default is a gr.Progress instance, whose
    # truthiness is not guaranteed to be True.
    if progress is not None:
        progress(0, desc="Preparing...")

    video = wan_t2v.generate(
        prompt,
        size=(width, height),
        frame_num=frame_num,
        sampling_steps=sampling_steps,
        guide_scale=6.0,
        shift=8.0,
        offload_model=True,  # Offload main model to save VRAM for VAE decoding
        progress=progress,
    )

    if progress is not None:
        progress(0.95, desc="Saving video to file")
    save_path = "generated_video.mp4"
    # video[None] prepends a leading axis; assumes cache_video expects a batched
    # tensor rendered as a 1-column grid (nrow=1). Confirm.
    cache_video(
        tensor=video[None],
        save_file=save_path,
        fps=cfg.sample_fps,
        nrow=1,
        normalize=True,
        value_range=(-1, 1),
    )

    return save_path


# Badge row shown under the title ("Duplicate this Space" / "Follow me on HF").
_BADGES_HTML = """
        <div style="display:flex;column-gap:4px;">
            <a href="https://huggingface.co/spaces/fffiloni/Wan2.1?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
            <a href="https://huggingface.co/fffiloni">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
            </a>
        </div>
        """

with gr.Blocks() as demo:
    with gr.Column():
        # Header: title, short note, and the badge row.
        gr.Markdown("# Wan 2.1 1.3B")
        gr.Markdown(
            "Enjoy this simple working UI, duplicate the space to skip the queue :)"
        )
        gr.HTML(_BADGES_HTML)

        # Generation inputs, in display order.
        prompt = gr.Textbox(label="Prompt")

        video_size_dropdown = gr.Dropdown(
            label="Video Size (Resolution)",
            choices=SIZE_OPTIONS,
            value=SIZE_OPTIONS[0],
        )
        frame_num_slider = gr.Slider(
            label="Frame Number (Video Length)",
            minimum=FRAME_NUM_OPTIONS[0],
            maximum=FRAME_NUM_OPTIONS[-1],
            step=1,
            value=FRAME_NUM_OPTIONS[2],
        )
        sampling_steps_slider = gr.Slider(
            label="Sampling Steps (Fewer steps = Faster, Lower quality)",
            minimum=SAMPLING_STEPS_OPTIONS[0],
            maximum=SAMPLING_STEPS_OPTIONS[-1],
            step=1,
            value=SAMPLING_STEPS_OPTIONS[1],  # second option, i.e. 10 steps
        )

        # Trigger button and result display.
        submit_btn = gr.Button("Submit")
        video_res = gr.Video(label="Generated Video")

    # Wire the button to the generator; the inputs order matches infer's parameters.
    submit_btn.click(
        fn=infer,
        outputs=[video_res],
        inputs=[prompt, video_size_dropdown, frame_num_slider, sampling_steps_slider],
    )

# queue() returns the Blocks instance itself, so enabling it separately is equivalent.
demo.queue()
demo.launch(show_api=False, show_error=True, share=False)