import os
import shutil
import random
import sys
import tempfile
from typing import Sequence, Mapping, Any, Union

import spaces
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
from comfy import model_management
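# Helper: download a file into the Hugging Face cache and symlink it into the given
# ComfyUI model folder, so loader nodes can find it by filename without duplicating
# multi-gigabyte weights on disk.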
def hf_hub_download_local(repo_id, filename, local_dir, **kwargs):
    downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
    os.makedirs(local_dir, exist_ok=True)
    base_filename = os.path.basename(filename)
    target_path = os.path.join(local_dir, base_filename)

    if os.path.exists(target_path) or os.path.islink(target_path):
        os.remove(target_path)

    os.symlink(downloaded_path, target_path)
    return target_path
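# Fetch everything the workflow needs: the UMT5-XXL text encoder, the Wan 2.2 i2v
# high- and low-noise 14B fp8 diffusion models, the Wan 2.1 VAE, CLIP Vision H, and
# the Wan2.2-Lightning 4-step LoRAs for both experts.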
print("Downloading models from Hugging Face Hub...")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", local_dir="models/text_encoders")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/vae/wan_2.1_vae.safetensors", local_dir="models/vae")
hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/clip_vision/clip_vision_h.safetensors", local_dir="models/clip_vision")
hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors", local_dir="models/loras")
hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", local_dir="models/loras")
print("Downloads complete.")
def calculate_video_dimensions(width, height, max_size=832, min_size=480):
    """
    Calculate video dimensions based on input image size.
    Larger dimension becomes max_size, smaller becomes proportional.
    If square, use min_size x min_size.
    Results are rounded to nearest multiple of 16.
    """
    if width == height:
        video_width = min_size
        video_height = min_size
    else:
        aspect_ratio = width / height

        if width > height:
            video_width = max_size
            video_height = int(max_size / aspect_ratio)
        else:
            video_height = max_size
            video_width = int(max_size * aspect_ratio)

    video_width = round(video_width / 16) * 16
    video_height = round(video_height / 16) * 16

    video_width = max(video_width, 16)
    video_height = max(video_height, 16)

    return video_width, video_height
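# The end frame is scaled just enough to cover the start frame's canvas (the larger of the
# two width/height ratios) and then center-cropped, so both frames end up the same size.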
def resize_and_crop_to_match(target_image, reference_image):
    """
    Resize and center crop target_image to match reference_image dimensions.
    """
    ref_width, ref_height = reference_image.size
    target_width, target_height = target_image.size

    scale = max(ref_width / target_width, ref_height / target_height)

    new_width = int(target_width * scale)
    new_height = int(target_height * scale)
    resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    left = (new_width - ref_width) // 2
    top = (new_height - ref_height) // 2
    right = left + ref_width
    bottom = top + ref_height

    cropped = resized.crop((left, top, right, bottom))
    return cropped
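# ComfyUI node methods return their outputs as tuples (or occasionally dicts holding a
# "result" tuple); this helper pulls out a single output slot by index.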
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Returns the value at the given index of a sequence or mapping.

    If the object is a sequence (like a list or string), returns the value at the given index.
    If the object is a mapping (like a dictionary), returns the value at the index-th key.

    Some nodes return a dictionary; in those cases, we look for the "result" key.

    Args:
        obj (Union[Sequence, Mapping]): The object to retrieve the value from.
        index (int): The index of the value to retrieve.

    Returns:
        Any: The value at the given index.

    Raises:
        IndexError: If the index is out of bounds for the object and the object is not a mapping.
    """
    try:
        return obj[index]
    except KeyError:
        if isinstance(obj, Mapping) and "result" in obj:
            return obj["result"][index]
        raise
def find_path(name: str, path: str = None) -> str:
    """
    Recursively looks at parent folders starting from the given path until it finds the given name.
    Returns the path as a string if found, or None otherwise.
    """
    if path is None:
        path = os.getcwd()

    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        print(f"'{name}' found: {path_name}")
        return path_name

    parent_directory = os.path.dirname(path)
    if parent_directory == path:
        return None

    return find_path(name, parent_directory)
def add_comfyui_directory_to_sys_path() -> None:
    """
    Add 'ComfyUI' to the sys.path
    """
    comfyui_path = find_path("ComfyUI")
    if comfyui_path is not None and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")
    else:
        print("Could not find ComfyUI directory. Please run from a parent folder of ComfyUI.")
def add_extra_model_paths() -> None:
    """
    Parse the optional extra_model_paths.yaml file and register the extra model search paths it defines.
    """
    try:
        from main import load_extra_path_config
    except ImportError:
        print(
            "Could not import load_extra_path_config from main.py. This might be okay if you don't use it."
        )
        return

    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths is not None:
        load_extra_path_config(extra_model_paths)
    else:
        print("Could not find an optional 'extra_model_paths.yaml' config file.")
def import_custom_nodes() -> None:
    """Find all custom nodes in the custom_nodes folder and add those node objects to NODE_CLASS_MAPPINGS.

    This function sets up a new asyncio event loop, initializes the PromptServer,
    creates a PromptQueue, and initializes the custom nodes.
    """
    import asyncio
    import execution
    from nodes import init_extra_nodes
    import server

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)
    loop.run_until_complete(init_extra_nodes(init_custom_nodes=True))
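# ComfyUI paths, custom nodes, and all model weights are set up once at import time,
# so every generate_video call reuses them instead of reloading.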
MODELS_AND_NODES = {}

print("Setting up ComfyUI paths...")
add_comfyui_directory_to_sys_path()
add_extra_model_paths()

print("Importing custom nodes...")
import_custom_nodes()

from nodes import NODE_CLASS_MAPPINGS
global folder_paths
import folder_paths

print("Loading models into memory. This may take a few minutes...")
cliploader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
MODELS_AND_NODES["clip"] = cliploader.load_clip(
    clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan", device="cpu"
)

unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
unet_low_noise = unetloader.load_unet(
    unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors",
    weight_dtype="default",
)
unet_high_noise = unetloader.load_unet(
    unet_name="wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors",
    weight_dtype="default",
)

vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
MODELS_AND_NODES["vae"] = vaeloader.load_vae(vae_name="wan_2.1_vae.safetensors")
loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
MODELS_AND_NODES["model_low_noise"] = loraloadermodelonly.load_lora_model_only(
    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors",
    strength_model=0.8,
    model=get_value_at_index(unet_low_noise, 0),
)
MODELS_AND_NODES["model_high_noise"] = loraloadermodelonly.load_lora_model_only(
    lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors",
    strength_model=0.8,
    model=get_value_at_index(unet_high_noise, 0),
)
clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
MODELS_AND_NODES["clip_vision"] = clipvisionloader.load_clip(
    clip_name="clip_vision_h.safetensors"
)

MODELS_AND_NODES["CLIPTextEncode"] = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
MODELS_AND_NODES["LoadImage"] = NODE_CLASS_MAPPINGS["LoadImage"]()
MODELS_AND_NODES["CLIPVisionEncode"] = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
MODELS_AND_NODES["ModelSamplingSD3"] = NODE_CLASS_MAPPINGS["ModelSamplingSD3"]()
MODELS_AND_NODES["PathchSageAttentionKJ"] = NODE_CLASS_MAPPINGS["PathchSageAttentionKJ"]()
MODELS_AND_NODES["WanFirstLastFrameToVideo"] = NODE_CLASS_MAPPINGS["WanFirstLastFrameToVideo"]()
MODELS_AND_NODES["KSamplerAdvanced"] = NODE_CLASS_MAPPINGS["KSamplerAdvanced"]()
MODELS_AND_NODES["VAEDecode"] = NODE_CLASS_MAPPINGS["VAEDecode"]()
MODELS_AND_NODES["CreateVideo"] = NODE_CLASS_MAPPINGS["CreateVideo"]()
MODELS_AND_NODES["SaveVideo"] = NODE_CLASS_MAPPINGS["SaveVideo"]()
print("Pre-loading main models onto GPU...")
model_loaders = [
    MODELS_AND_NODES["clip"],
    MODELS_AND_NODES["vae"],
    MODELS_AND_NODES["model_low_noise"],
    MODELS_AND_NODES["model_high_noise"],
    MODELS_AND_NODES["clip_vision"],
]
model_management.load_models_gpu([
    loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
])
print("All models loaded successfully!")
@spaces.GPU(duration=120)
def generate_video(
    start_image_pil,
    end_image_pil,
    prompt,
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
    duration=33,
    progress=gr.Progress(track_tqdm=True)
):
    """
    The main function to generate a video based on user inputs.
    This function is called every time the user clicks the 'Generate' button.
    """
    FPS = 16

    processed_start_image = start_image_pil.copy()
    processed_end_image = resize_and_crop_to_match(end_image_pil, start_image_pil)

    video_width, video_height = calculate_video_dimensions(
        processed_start_image.width,
        processed_start_image.height
    )

    print(f"Input image size: {processed_start_image.width}x{processed_start_image.height}")
    print(f"Video dimensions: {video_width}x{video_height}")

    clip = MODELS_AND_NODES["clip"]
    vae = MODELS_AND_NODES["vae"]
    model_low_noise = MODELS_AND_NODES["model_low_noise"]
    model_high_noise = MODELS_AND_NODES["model_high_noise"]
    clip_vision = MODELS_AND_NODES["clip_vision"]

    cliptextencode = MODELS_AND_NODES["CLIPTextEncode"]
    loadimage = MODELS_AND_NODES["LoadImage"]
    clipvisionencode = MODELS_AND_NODES["CLIPVisionEncode"]
    modelsamplingsd3 = MODELS_AND_NODES["ModelSamplingSD3"]
    pathchsageattentionkj = MODELS_AND_NODES["PathchSageAttentionKJ"]
    wanfirstlastframetovideo = MODELS_AND_NODES["WanFirstLastFrameToVideo"]
    ksampleradvanced = MODELS_AND_NODES["KSamplerAdvanced"]
    vaedecode = MODELS_AND_NODES["VAEDecode"]
    createvideo = MODELS_AND_NODES["CreateVideo"]
    savevideo = MODELS_AND_NODES["SaveVideo"]
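    # LoadImage reads from disk, so both frames are written to temporary PNG files
    # (delete=False keeps them around after the with-block so the loader can open them).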
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as start_file, \
         tempfile.NamedTemporaryFile(suffix=".png", delete=False) as end_file:
        processed_start_image.save(start_file.name)
        processed_end_image.save(end_file.name)
        start_image_path = start_file.name
        end_image_path = end_file.name
    with torch.inference_mode():
        progress(0.1, desc="Encoding text and images...")

        positive_conditioning = cliptextencode.encode(text=prompt, clip=get_value_at_index(clip, 0))
        negative_conditioning = cliptextencode.encode(text=negative_prompt, clip=get_value_at_index(clip, 0))

        start_image_loaded = loadimage.load_image(image=start_image_path)
        end_image_loaded = loadimage.load_image(image=end_image_path)

        clip_vision_encoded_start = clipvisionencode.encode(
            crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(start_image_loaded, 0)
        )
        clip_vision_encoded_end = clipvisionencode.encode(
            crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(end_image_loaded, 0)
        )
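        # WanFirstLastFrameToVideo combines the text conditioning, the CLIP-Vision embeddings,
        # and the two frames into sampler inputs; it returns (positive, negative, latent),
        # which are read back by index below.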
        progress(0.2, desc="Preparing initial latents...")
        initial_latents = wanfirstlastframetovideo.EXECUTE_NORMALIZED(
            width=video_width, height=video_height, length=duration, batch_size=1,
            positive=get_value_at_index(positive_conditioning, 0),
            negative=get_value_at_index(negative_conditioning, 0),
            vae=get_value_at_index(vae, 0),
            clip_vision_start_image=get_value_at_index(clip_vision_encoded_start, 0),
            clip_vision_end_image=get_value_at_index(clip_vision_encoded_end, 0),
            start_image=get_value_at_index(start_image_loaded, 0),
            end_image=get_value_at_index(end_image_loaded, 0),
        )
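        # ModelSamplingSD3 applies a sigma shift of 8 to each expert, and the KJNodes patcher
        # ("PathchSageAttentionKJ" is its registered node name) switches attention to
        # SageAttention when it is available (sage_attention="auto").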
        progress(0.3, desc="Patching models...")
        model_low_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_low_noise, 0))
        model_low_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_low_patched, 0))

        model_high_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_high_noise, 0))
        model_high_final = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_high_patched, 0))
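        # Two-pass sampling over 8 total steps: the high-noise expert denoises steps 0-4 and
        # returns its leftover noise, then the low-noise expert finishes steps 4-8.
        # cfg=1 disables classifier-free guidance, as expected for the distilled Lightning LoRAs.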
        progress(0.5, desc="Running KSampler (Step 1/2)...")
        latent_step1 = ksampleradvanced.sample(
            add_noise="enable", noise_seed=random.randint(1, 2**64 - 1), steps=8, cfg=1,
            sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4,
            return_with_leftover_noise="enable", model=get_value_at_index(model_high_final, 0),
            positive=get_value_at_index(initial_latents, 0),
            negative=get_value_at_index(initial_latents, 1),
            latent_image=get_value_at_index(initial_latents, 2),
        )

        progress(0.7, desc="Running KSampler (Step 2/2)...")
        latent_step2 = ksampleradvanced.sample(
            add_noise="disable", noise_seed=random.randint(1, 2**64 - 1), steps=8, cfg=1,
            sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000,
            return_with_leftover_noise="disable", model=get_value_at_index(model_low_final, 0),
            positive=get_value_at_index(initial_latents, 0),
            negative=get_value_at_index(initial_latents, 1),
            latent_image=get_value_at_index(latent_step1, 0),
        )
        progress(0.8, desc="Decoding VAE...")
        decoded_images = vaedecode.decode(samples=get_value_at_index(latent_step2, 0), vae=get_value_at_index(vae, 0))

        progress(0.9, desc="Creating and saving video...")
        video_data = createvideo.create_video(fps=FPS, images=get_value_at_index(decoded_images, 0))

        save_result = savevideo.save_video(
            filename_prefix="GradioVideo", format="mp4", codec="h264",
            video=get_value_at_index(video_data, 0),
        )

        progress(1.0, desc="Done!")
        return f"output/{save_result['ui']['images'][0]['filename']}"
css = '''
.fillable{max-width: 1100px !important}
.dark .progress-text {color: white}
'''

with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
    gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
    gr.Markdown("Running the [Wan 2.2 First/Last Frame ComfyUI workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/) and the [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) 4-step LoRAs (8 sampling steps in total) on ZeroGPU")
    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    start_image = gr.Image(type="pil", label="Start Frame")
                    end_image = gr.Image(type="pil", label="End Frame")

            prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")

            with gr.Accordion("Advanced Settings", open=False, visible=True):
                duration = gr.Radio(
                    [("Short (2s)", 33), ("Mid (4s)", 66)],
                    value=33,
                    label="Video Duration",
                    visible=False
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
                    visible=False
                )

            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)

    generate_button.click(
        fn=generate_video,
        inputs=[start_image, end_image, prompt, negative_prompt, duration],
        outputs=output_video
    )
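    # Examples are cached lazily: each example video is generated the first time a visitor selects it.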
    gr.Examples(
        examples=[
            ["poli_tower.png", "tower_takes_off.png", "the man turns around"],
            ["ugly_sonic.jpeg", "squatting_sonic.png", "the character dodges the missiles"],
            ["capyabara_zoomed.png", "capybara.webp", "a dramatic dolly zoom"],
        ],
        inputs=[start_image, end_image, prompt],
        outputs=output_video,
        fn=generate_video,
        cache_examples="lazy",
    )

if __name__ == "__main__":
    app.launch(share=True)