# Hugging Face Spaces page header (scraped): status "Build error"
import gradio as gr | |
import os | |
import torch | |
from huggingface_hub import snapshot_download | |
import wan | |
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS | |
from wan.utils.utils import cache_video | |
# Hugging Face Hub repo holding the Wan2.1 text-to-video 1.3B weights, and the
# local directory it is downloaded into. Both are defined once so the download
# target and the directory the model is loaded from cannot drift apart.
MODEL_REPO_ID = "Wan-AI/Wan2.1-T2V-1.3B"
CHECKPOINT_DIR = "./Wan2.1-T2V-1.3B"

# Download the checkpoint. This runs at import time, before the model below
# is constructed from that directory.
snapshot_download(repo_id=MODEL_REPO_ID, local_dir=CHECKPOINT_DIR)

# Build the text-to-video pipeline from the 't2v-1.3B' configuration.
cfg = WAN_CONFIGS['t2v-1.3B']
wan_t2v = wan.WanT2V(
    config=cfg,
    checkpoint_dir=CHECKPOINT_DIR,
    device_id=0,  # NOTE(review): presumably the first CUDA device — confirm in wan.WanT2V
    t5_cpu=True,  # NOTE(review): presumably keeps the T5 text encoder on CPU to save VRAM — confirm
)
# Resolution choices offered in the UI, each written as "width*height".
SIZE_OPTIONS = ["480*832", "832*480"]
# Frame-count choices: 10 through 60 in steps of 10, then 81.
FRAME_NUM_OPTIONS = list(range(10, 70, 10)) + [81]
# Sampling-step choices: 5 through 30 in steps of 5, then 40 and 50.
SAMPLING_STEPS_OPTIONS = list(range(5, 35, 5)) + [40, 50]
def infer(prompt, video_size, frame_num, sampling_steps, progress=gr.Progress()):
    """Generate a video for *prompt* with the Wan 2.1 T2V model and save it as MP4.

    Args:
        prompt: Text description of the video to generate. Must contain at
            least one non-whitespace character.
        video_size: Resolution string of the form "W*H" (e.g. "480*832");
            it is split on "*" into (width, height).
        frame_num: Requested number of frames. Coerced to int because Gradio
            documents Slider values as floats, while this is used as a count.
        sampling_steps: Number of sampling steps; coerced to int for the same
            reason.
        progress: Gradio progress tracker. Leaving a ``gr.Progress()`` default
            in the signature is Gradio's documented way to receive one.

    Returns:
        The path "generated_video.mp4" (relative to the working directory).
        The same file is overwritten on every call.

    Raises:
        gr.Error: If *prompt* is missing or blank, so the user gets a clear
            message instead of a full generation run on an empty prompt.
    """
    if prompt is None or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    width, height = map(int, video_size.split('*'))
    frame_num = int(frame_num)
    sampling_steps = int(sampling_steps)

    # NOTE(review): gr.Progress is iterable-like, so `if progress:` may fall
    # back to __len__ and fail when nothing is being tracked yet; compare
    # against None rather than relying on the tracker's truthiness.
    if progress is not None:
        progress(0, desc="Preparing...")

    # NOTE(review): assumes the bundled `wan` package's WanT2V.generate accepts
    # a `progress` keyword — confirm against the local wan/text2video.py.
    video = wan_t2v.generate(
        prompt,
        size=(width, height),
        frame_num=frame_num,
        sampling_steps=sampling_steps,
        guide_scale=6.0,
        shift=8.0,
        offload_model=True,  # Offload main model to save VRAM for VAE decoding
        progress=progress,
    )

    if progress is not None:
        progress(0.95, desc="Saving video to file")

    save_path = "generated_video.mp4"
    # video[None] prepends a size-1 leading axis; presumably cache_video
    # expects a batched tensor — confirm against wan.utils.utils.cache_video.
    cache_video(
        tensor=video[None],
        save_file=save_path,
        fps=cfg.sample_fps,
        nrow=1,
        normalize=True,
        value_range=(-1, 1),  # assumes generated pixel values lie in [-1, 1] — confirm
    )
    return save_path
# Build the single-column UI. A combined `with` is equivalent to nesting the
# Column inside the Blocks context; component creation order (which fixes the
# on-page layout) is unchanged.
with gr.Blocks() as demo, gr.Column():
    # Header: title, tagline, and the "duplicate" / "follow" badge links.
    gr.Markdown("# Wan 2.1 1.3B")
    gr.Markdown(
        "Enjoy this simple working UI, duplicate the space to skip the queue :)"
    )
    gr.HTML(
        """
        <div style="display:flex;column-gap:4px;">
            <a href="https://huggingface.co/spaces/fffiloni/Wan2.1?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
            <a href="https://huggingface.co/fffiloni">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
            </a>
        </div>
        """
    )

    # Inputs, created in the same order as infer()'s positional parameters.
    prompt = gr.Textbox(label="Prompt")
    video_size_dropdown = gr.Dropdown(
        label="Video Size (Resolution)",
        choices=SIZE_OPTIONS,
        value=SIZE_OPTIONS[0],
    )
    frame_num_slider = gr.Slider(
        label="Frame Number (Video Length)",
        minimum=FRAME_NUM_OPTIONS[0],
        maximum=FRAME_NUM_OPTIONS[-1],
        step=1,
        value=FRAME_NUM_OPTIONS[2],
    )
    sampling_steps_slider = gr.Slider(
        label="Sampling Steps (Fewer steps = Faster, Lower quality)",
        minimum=SAMPLING_STEPS_OPTIONS[0],
        maximum=SAMPLING_STEPS_OPTIONS[-1],
        step=1,
        value=SAMPLING_STEPS_OPTIONS[1],  # second option is the default step count
    )

    submit_btn = gr.Button("Submit")
    video_res = gr.Video(label="Generated Video")

    # Clicking Submit runs infer() on the four inputs and shows the returned file.
    submit_btn.click(
        fn=infer,
        inputs=[prompt, video_size_dropdown, frame_num_slider, sampling_steps_slider],
        outputs=[video_res],
    )

# Enable Gradio's request queue, then start the server (no public share link,
# errors surfaced in the UI, API docs hidden).
demo.queue()
demo.launch(share=False, show_error=True, show_api=False)