# PyTorch 2.8 (temporary hack)
import os
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
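# NOTE: installing at import time like this is a stopgap that only makes sense on
# ephemeral hosts such as HF Spaces; once a stable torch release covers what the
# demo needs, pinning it in requirements.txt is the cleaner fix.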

# Actual demo code
import spaces
import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
import gradio as gr
import numpy as np
import random
import gc
from optimization import optimize_pipeline_


MODEL_ID = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"

LANDSCAPE_WIDTH = 1024
LANDSCAPE_HEIGHT = 1024
MAX_SEED = np.iinfo(np.int32).max

FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81

# Duration bounds (seconds) derived from the frame limits; only the
# commented-out duration slider below still uses them.
MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)

# The VAE stays in float32 (decoding quality degrades at lower precision).
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
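
# Wan 2.2 A14B splits denoising across two experts: `transformer` handles the
# high-noise (early) timesteps and `transformer_2` the low-noise (late) ones,
# matching the two guidance scales exposed in the UI below. Both are loaded in
# bfloat16 from a BF16 variant of the checkpoint.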
pipe = WanPipeline.from_pretrained(MODEL_ID,
    transformer=WanTransformer3DModel.from_pretrained('linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
        subfolder='transformer',
        torch_dtype=torch.bfloat16,
        device_map='cuda',
    ),
    transformer_2=WanTransformer3DModel.from_pretrained('linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
        subfolder='transformer_2',
        torch_dtype=torch.bfloat16,
        device_map='cuda',
    ),
    vae=vae,
    torch_dtype=torch.bfloat16,
).to('cuda')


# Aggressively reclaim host and GPU memory before warm-up: repeated gc passes
# release objects kept alive by reference cycles, then the CUDA cache is flushed.
for _ in range(3):
    gc.collect()
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

# Warm up the pipeline before serving requests (optimize_pipeline_ lives in the
# local optimization module).
optimize_pipeline_(pipe,
    prompt='prompt',
    height=LANDSCAPE_HEIGHT,
    width=LANDSCAPE_WIDTH,
    num_frames=MAX_FRAMES_MODEL,
)


default_prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
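# The negative prompt above is the stock Chinese one recommended for Wan models;
# roughly: "garish colors, overexposed, static, blurry details, subtitles, style,
# artwork, painting, still frame, overall gray, worst quality, low quality, JPEG
# artifacts, ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn
# face, deformed, disfigured, malformed limbs, fused fingers, motionless frame,
# cluttered background, three legs, crowded background, walking backwards".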


# @spaces.GPU accepts a callable for its duration: size the ZeroGPU allocation
# from the requested step count (roughly 15 s per denoising step).
def get_duration(
    prompt,
    negative_prompt,
    guidance_scale,
    guidance_scale_2,
    steps,
    seed,
    randomize_seed,
    progress,
):
    return steps * 15

@spaces.GPU(duration=get_duration)
def generate_image(
    prompt,
    negative_prompt=default_negative_prompt,
    guidance_scale=3.5,
    guidance_scale_2=4.0,
    steps=27,
    seed=42,
    randomize_seed=False,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generate a video from an input image using the Wan 2.2 14B I2V model with Phantom LoRA.
    
    This function takes an input image and generates a video animation based on the provided
    prompt and parameters. It uses an FP8 qunatized Wan 2.2 14B Image-to-Video model in with Phantom LoRA
    for fast generation in 6-8 steps.
    
    Args:
        prompt (str): Text prompt describing the desired animation or motion.
        negative_prompt (str, optional): Negative prompt to avoid unwanted elements. 
            Defaults to default_negative_prompt (contains unwanted visual artifacts).
        guidance_scale (float, optional): Controls adherence to the prompt. Higher values = more adherence.
            Defaults to 1.0. Range: 0.0-20.0.
        guidance_scale_2 (float, optional): Controls adherence to the prompt. Higher values = more adherence.
            Defaults to 1.0. Range: 0.0-20.0.
        steps (int, optional): Number of inference steps. More steps = higher quality but slower.
            Defaults to 4. Range: 1-30.
        seed (int, optional): Random seed for reproducible results. Defaults to 42.
            Range: 0 to MAX_SEED (2147483647).
        randomize_seed (bool, optional): Whether to use a random seed instead of the provided seed.
            Defaults to False.
        progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).
    
    Returns:
        tuple: A tuple containing:
            - video_path (str): Path to the generated video file (.mp4)
            - current_seed (int): The seed used for generation (useful when randomize_seed=True)
    
    Raises:
        gr.Error: If input_image is None (no image uploaded).
    
    Note:
        - The function automatically resizes the input image to the target dimensions
        - Frame count is calculated as duration_seconds * FIXED_FPS (24)
        - Output dimensions are adjusted to be multiples of MOD_VALUE (32)
        - The function uses GPU acceleration via the @spaces.GPU decorator
        - Generation time varies based on steps and duration (see get_duration function)
    """
    
   
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    out_img = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=LANDSCAPE_HEIGHT,
        width=LANDSCAPE_WIDTH,
        num_frames=1,
        guidance_scale=float(guidance_scale),
        guidance_scale_2=float(guidance_scale_2),
        num_inference_steps=int(steps),
        output_type="pil",
        generator=torch.Generator(device="cuda").manual_seed(current_seed),
    ).frames[0][0]

    return out_img, current_seed
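
# A minimal sketch of exercising the model without the UI (hypothetical local
# smoke test; assumes the decorated function can be called directly):
#
#     image, used_seed = generate_image("a watercolor fox in a snowy forest")
#     image.save("out.png")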

with gr.Blocks() as demo:
    gr.Markdown("# Wan 2.2 T2I (14B)")
    #gr.Markdown("run Wan 2.2 in just 6-8 steps, with [FusionX Phantom LoRA by DeeJayT](https://huggingface.co/vrgamedevgirl84/Wan14BT2VFusioniX/tree/main/FusionX_LoRa), compatible with 🧨 diffusers")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value=default_prompt)
            #duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=MAX_DURATION, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
            
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
                randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
                steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=27, label="Inference Steps") 
                guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=3.5, label="Guidance Scale - high noise stage")
                guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=4, label="Guidance Scale 2 - low noise stage")

            generate_button = gr.Button("Generate Image", variant="primary")
        with gr.Column():
            img_output = gr.Image(label="Generated Image", interactive=False)
    
    ui_inputs = [
        prompt_input,
        negative_prompt_input,
        guidance_scale_input,
        guidance_scale_2_input,
        steps_slider,
        seed_input,
        randomize_seed_checkbox,
    ]
    generate_button.click(fn=generate_image, inputs=ui_inputs, outputs=[img_output, seed_input])

    gr.Examples(
        examples=[ 
            [
                "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
            ],
        ],
        inputs=[prompt_input], outputs=[img_output, seed_input], fn=generate_image, cache_examples="lazy"
    )

if __name__ == "__main__":
    demo.queue().launch(mcp_server=True)