import os
# Temporary hack: install a PyTorch 2.8 pre-release from the cu126 nightly index (plus the latest `spaces`) before torch is imported.
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
# --- 1. Model Download and Setup (Diffusers Backend) ---
import spaces
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
import gradio as gr
import tempfile
import numpy as np
from PIL import Image
import random
import gc
from gradio_client import Client, handle_file # used to call the external end-frame generation Space
# Import the optimization function from the separate file
from optimization import optimize_pipeline_
# --- Constants and Model Loading ---
MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
# --- Flexible Dimension Constants ---
MAX_DIMENSION = 832
MIN_DIMENSION = 480
DIMENSION_MULTIPLE = 16
SQUARE_SIZE = 480
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
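# At FIXED_FPS these frame limits correspond to clip lengths of roughly 0.5 s to 5.1 s.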
MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS, 1)
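# The default negative prompt is kept in Chinese, as in the Wan reference defaults. Rough translation:
# garish colors, overexposed, static, blurry details, subtitles, style, artwork, painting, still image,
# overall gray, worst quality, low quality, JPEG artifacts, ugly, mutilated, extra fingers, poorly drawn
# hands/face, deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background,
# three legs, crowded background, walking backwards, overexposed.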
default_negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"
print("Loading models into memory. This may take a few minutes...")
pipe = WanImageToVideoPipeline.from_pretrained(
MODEL_ID,
transformer=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
subfolder='transformer',
torch_dtype=torch.bfloat16,
device_map='cuda',
),
transformer_2=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
subfolder='transformer_2',
torch_dtype=torch.bfloat16,
device_map='cuda',
),
torch_dtype=torch.bfloat16,
)
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=8.0)
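# shift=8.0 skews the flow-matching sigma schedule toward the high-noise end, which pairs with the
# few-step Lightning-style sampling this Space targets.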
pipe.to('cuda')
print("Optimizing pipeline...")
for _ in range(3):
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
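# Warm up / optimize the pipeline once with a representative worst-case input (832x480, 81 frames);
# see optimization.py for what optimize_pipeline_ actually does.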
optimize_pipeline_(pipe,
image=Image.new('RGB', (MAX_DIMENSION, MIN_DIMENSION)),
prompt='prompt',
height=MIN_DIMENSION,
width=MAX_DIMENSION,
num_frames=MAX_FRAMES_MODEL,
)
print("All models loaded and optimized. Gradio app is ready.")
# --- 2. Image Processing and Application Logic ---
def generate_end_frame(start_img, gen_prompt, progress=gr.Progress(track_tqdm=True)):
"""Calls an external Gradio API to generate an image."""
if start_img is None:
raise gr.Error("Please provide a Start Frame first.")
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
raise gr.Error("HF_TOKEN not found in environment variables. Please set it in your Space secrets.")
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
start_img.save(tmpfile.name)
tmp_path = tmpfile.name
progress(0.1, desc="Connecting to image generation API...")
client = Client("multimodalart/nano-banana")
progress(0.5, desc=f"Generating with prompt: '{gen_prompt}'...")
try:
result = client.predict(
prompt=gen_prompt,
images=[
{"image": handle_file(tmp_path)}
],
manual_token=hf_token,
api_name="/unified_image_generator"
)
finally:
os.remove(tmp_path)
progress(1.0, desc="Done!")
print(result)
return result
def switch_to_upload_tab():
"""Returns a gr.Tabs update to switch to the first tab."""
return gr.Tabs(selected="upload_tab")
def process_image_for_video(image: Image.Image) -> Image.Image:
"""
Resizes an image based on the following rules for video generation:
1. The longest side will be scaled down to MAX_DIMENSION if it's larger.
2. The shortest side will be scaled up to MIN_DIMENSION if it's smaller.
3. The final dimensions will be rounded to the nearest multiple of DIMENSION_MULTIPLE.
4. Square images are resized to a fixed SQUARE_SIZE.
The aspect ratio is preserved as closely as possible.
"""
width, height = image.size
# Rule 4: Handle square images
if width == height:
return image.resize((SQUARE_SIZE, SQUARE_SIZE), Image.Resampling.LANCZOS)
# Determine target dimensions while preserving aspect ratio
aspect_ratio = width / height
new_width, new_height = width, height
# Rule 1: Scale down if too large
if new_width > MAX_DIMENSION or new_height > MAX_DIMENSION:
if aspect_ratio > 1: # Landscape
scale = MAX_DIMENSION / new_width
else: # Portrait
scale = MAX_DIMENSION / new_height
new_width *= scale
new_height *= scale
# Rule 2: Scale up if too small
if new_width < MIN_DIMENSION or new_height < MIN_DIMENSION:
if aspect_ratio > 1: # Landscape
scale = MIN_DIMENSION / new_height
else: # Portrait
scale = MIN_DIMENSION / new_width
new_width *= scale
new_height *= scale
# Rule 3: Round to the nearest multiple of DIMENSION_MULTIPLE
final_width = int(round(new_width / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
final_height = int(round(new_height / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
# Ensure final dimensions are at least the minimum
final_width = max(final_width, MIN_DIMENSION if aspect_ratio < 1 else SQUARE_SIZE)
final_height = max(final_height, MIN_DIMENSION if aspect_ratio > 1 else SQUARE_SIZE)
return image.resize((final_width, final_height), Image.Resampling.LANCZOS)
def resize_and_crop_to_match(target_image, reference_image):
"""Resizes and center-crops the target image to match the reference image's dimensions."""
ref_width, ref_height = reference_image.size
target_width, target_height = target_image.size
scale = max(ref_width / target_width, ref_height / target_height)
new_width, new_height = int(target_width * scale), int(target_height * scale)
resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2
return resized.crop((left, top, left + ref_width, top + ref_height))
@spaces.GPU(duration=120)
def generate_video(
start_image_pil,
end_image_pil,
prompt,
negative_prompt=default_negative_prompt,
duration_seconds=2.1,
steps=8,
guidance_scale=1,
guidance_scale_2=1,
seed=42,
randomize_seed=False,
progress=gr.Progress(track_tqdm=True)
):
"""
Generates a video by interpolating between a start and end image, guided by a text prompt,
using the diffusers Wan2.2 pipeline.
"""
if start_image_pil is None or end_image_pil is None:
raise gr.Error("Please upload both a start and an end image.")
progress(0.1, desc="Preprocessing images...")
    # Step 1: Process the start image to get the target dimensions from the sizing rules above.
processed_start_image = process_image_for_video(start_image_pil)
# Step 2: Make the end image match the *exact* dimensions of the processed start image.
processed_end_image = resize_and_crop_to_match(end_image_pil, processed_start_image)
target_height, target_width = processed_start_image.height, processed_start_image.width
# Handle seed and frame count
current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
progress(0.2, desc=f"Generating {num_frames} frames at {target_width}x{target_height} (seed: {current_seed})...")
output_frames_list = pipe(
image=processed_start_image,
last_image=processed_end_image,
prompt=prompt,
negative_prompt=negative_prompt,
height=target_height,
width=target_width,
num_frames=num_frames,
guidance_scale=float(guidance_scale),
guidance_scale_2=float(guidance_scale_2),
num_inference_steps=int(steps),
generator=torch.Generator(device="cuda").manual_seed(current_seed),
).frames[0]
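    # The pipeline output is batched; frames[0] is the frame sequence of the single generated clip.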
progress(0.9, desc="Encoding and saving video...")
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
video_path = tmpfile.name
export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
progress(1.0, desc="Done!")
return video_path, current_seed
# --- 3. Gradio User Interface ---
css = '''
.fillable{max-width: 1100px !important}
.dark .progress-text {color: white}
#general_items{margin-top: 2em}
#group_all{overflow:visible}
#group_all .styler{overflow:visible}
#group_tabs .tabitem{padding: 0}
.tab-wrapper{margin-top: -33px;z-index: 999;position: absolute;width: 100%;background-color: var(--block-background-fill);padding: 0;}
#component-9-button{width: 50%;justify-content: center}
#component-11-button{width: 50%;justify-content: center}
#or_item{text-align: center; padding-top: 1em; padding-bottom: 1em; font-size: 1.1em;margin-left: .5em;margin-right: .5em;width: calc(100% - 1em)}
#fivesec{margin-top: 5em;margin-left: .5em;margin-right: .5em;width: calc(100% - 1em)}
'''
with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
gr.Markdown("Based on the [Wan 2.2 First/Last Frame workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/), applied to 🧨 Diffusers + [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) 8-step LoRA")
with gr.Row(elem_id="general_items"):
with gr.Column():
with gr.Group(elem_id="group_all"):
with gr.Row():
start_image = gr.Image(type="pil", label="Start Frame", sources=["upload", "clipboard"])
# Capture the Tabs component in a variable and assign IDs to tabs
with gr.Tabs(elem_id="group_tabs") as tabs:
with gr.TabItem("Upload", id="upload_tab"):
end_image = gr.Image(type="pil", label="End Frame", sources=["upload", "clipboard"])
with gr.TabItem("Generate", id="generate_tab"):
generate_5seconds = gr.Button("Generate scene 5 seconds in the future", elem_id="fivesec")
gr.Markdown("Generate a custom end-frame with an edit model like [Nano Banana](https://huggingface.co/spaces/multimodalart/nano-banana) or [Qwen Image Edit](https://huggingface.co/spaces/multimodalart/Qwen-Image-Edit-Fast)", elem_id="or_item")
prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")
with gr.Accordion("Advanced Settings", open=False):
duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=2.1, label="Video Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=8, label="Inference Steps")
guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - high noise")
guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - low noise")
with gr.Row():
seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
generate_button = gr.Button("Generate Video", variant="primary")
with gr.Column():
output_video = gr.Video(label="Generated Video", autoplay=True)
    # Inputs/outputs shared by the main generate button and the "Generate scene 5 seconds in the future" flow.
ui_inputs = [
start_image,
end_image,
prompt,
negative_prompt_input,
duration_seconds_input,
steps_slider,
guidance_scale_input,
guidance_scale_2_input,
seed_input,
randomize_seed_checkbox
]
ui_outputs = [output_video, seed_input]
generate_button.click(
fn=generate_video,
inputs=ui_inputs,
outputs=ui_outputs
)
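    # "Generate scene 5 seconds in the future": switch back to the Upload tab, ask the external Space
    # for an end frame, then (on success) run the normal video generation with the same UI inputs.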
generate_5seconds.click(
fn=switch_to_upload_tab,
inputs=None,
outputs=[tabs]
).then(
fn=lambda img: generate_end_frame(img, "this image is a still frame from a movie. generate a new frame with what happens on this scene 5 seconds in the future"),
inputs=[start_image],
outputs=[end_image]
).success(
fn=generate_video,
inputs=ui_inputs,
outputs=ui_outputs
)
gr.Examples(
examples=[
["poli_tower.png", "tower_takes_off.png", "the man turns around"],
["ugly_sonic.jpeg", "squatting_sonic.png", "the character dodges the missiles"],
["capyabara_zoomed.png", "capyabara.webp", "a dramatic dolly zoom"],
],
inputs=[start_image, end_image, prompt],
outputs=ui_outputs,
fn=generate_video,
cache_examples="lazy",
)
if __name__ == "__main__":
app.launch(share=True)