# Wan2.1 / simple_app.py
import gradio as gr
import os
import torch
from huggingface_hub import snapshot_download
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS
from wan.utils.utils import cache_video
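# Added sanity check (not in the upstream app): WanT2V below is pinned to
# device_id=0, so fail fast with a readable error if no CUDA GPU is visible.
if not torch.cuda.is_available():
    raise RuntimeError("Wan2.1-T2V requires a CUDA GPU, but none was detected.")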
# Download the model weights from the Hugging Face Hub (files already present
# and up to date in local_dir are reused rather than re-downloaded).
snapshot_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B", local_dir="./Wan2.1-T2V-1.3B"
)
# Load the 1.3B text-to-video model.
cfg = WAN_CONFIGS['t2v-1.3B']
wan_t2v = wan.WanT2V(
    config=cfg,
    checkpoint_dir="./Wan2.1-T2V-1.3B",
    device_id=0,   # first CUDA device
    t5_cpu=True,   # keep the T5 text encoder on CPU to save VRAM
)
SIZE_OPTIONS = ["480*832", "832*480"]  # resolution options ("width*height")
FRAME_NUM_OPTIONS = [10, 20, 30, 40, 50, 60, 81]  # frame-count options
SAMPLING_STEPS_OPTIONS = [5, 10, 15, 20, 25, 30, 40, 50]  # sampling-step options
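# Note: in the upstream wan.configs, SIZE_CONFIGS (imported above) maps these
# same "width*height" strings to (width, height) tuples; infer() below parses
# the string manually to the same effect.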
def infer(prompt, video_size, frame_num, sampling_steps, progress=gr.Progress()):
    width, height = map(int, video_size.split('*'))
    # Wan2.1's generate() documents that frame_num should be of the form 4n+1
    # (default 81), so snap the slider value down to the nearest valid count.
    frame_num = max(5, 4 * ((int(frame_num) - 1) // 4) + 1)
    if progress:
        progress(0, desc="Preparing...")
    video = wan_t2v.generate(
        prompt,
        size=(width, height),
        frame_num=frame_num,
        sampling_steps=sampling_steps,
        guide_scale=6.0,
        shift=8.0,
        offload_model=True,  # offload the main model to save VRAM for VAE decoding
        progress=progress,
    )
    if progress:
        progress(0.95, desc="Saving video to file")
    save_path = "generated_video.mp4"
    cache_video(
        tensor=video[None],  # add a batch dimension: (C, F, H, W) -> (1, C, F, H, W)
        save_file=save_path,
        fps=cfg.sample_fps,
        nrow=1,
        normalize=True,  # map values from value_range below into [0, 1] for encoding
        value_range=(-1, 1)
    )
return save_path
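# Example (a sketch, kept commented out so it never runs alongside the UI):
# headless generation from a Python shell, mirroring the widget defaults below.
#
#   infer("A cat surfing a small wave at sunset", "480*832",
#         frame_num=30, sampling_steps=10, progress=None)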
with gr.Blocks() as demo:
with gr.Column():
gr.Markdown("# Wan 2.1 1.3B")
        gr.Markdown(
            "Enjoy this simple working UI; duplicate the Space to skip the queue :)"
        )
gr.HTML(
"""
<div style="display:flex;column-gap:4px;">
<a href="https://huggingface.co/spaces/fffiloni/Wan2.1?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
</a>
<a href="https://huggingface.co/fffiloni">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
</a>
</div>
"""
)
prompt = gr.Textbox(label="Prompt")
video_size_dropdown = gr.Dropdown(
choices=SIZE_OPTIONS,
value=SIZE_OPTIONS[0],
label="Video Size (Resolution)",
)
frame_num_slider = gr.Slider(
minimum=FRAME_NUM_OPTIONS[0],
maximum=FRAME_NUM_OPTIONS[-1],
value=FRAME_NUM_OPTIONS[2],
step=1,
label="Frame Number (Video Length)",
)
sampling_steps_slider = gr.Slider(
minimum=SAMPLING_STEPS_OPTIONS[0],
maximum=SAMPLING_STEPS_OPTIONS[-1],
value=SAMPLING_STEPS_OPTIONS[1], # Default to 10 steps
step=1,
label="Sampling Steps (Fewer steps = Faster, Lower quality)",
)
submit_btn = gr.Button("Submit")
video_res = gr.Video(label="Generated Video")
submit_btn.click(
fn=infer,
inputs=[prompt, video_size_dropdown, frame_num_slider, sampling_steps_slider],
outputs=[video_res],
)
demo.queue().launch(share=False, show_error=True, show_api=False)
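# Tuning note (assumption, not from the original app): queue() serializes
# requests by default, which is the right behavior on a single GPU;
# demo.queue(max_size=20) would additionally cap how many jobs can wait.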