import torch
import os
import numpy as np
import math
import decord
from tqdm import tqdm
import pathlib
from PIL import Image

from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.memory import DynamicSwapInstaller
from diffusers_helper.utils import resize_and_center_crop
from diffusers_helper.bucket_tools import find_nearest_bucket
from diffusers_helper.hunyuan import vae_encode, vae_decode
from .video_base_generator import VideoBaseModelGenerator


class VideoModelGenerator(VideoBaseModelGenerator):
    """
    Generator for the Video (backward) extension of the Original HunyuanVideo model.
    These generators accept video input instead of a single image.
    """

    def __init__(self, **kwargs):
        """
        Initialize the Video model generator.
        """
        super().__init__(**kwargs)
        self.model_name = "Video"
        self.model_path = 'lllyasviel/FramePackI2V_HY'
        self.model_repo_id_for_cache = "models--lllyasviel--FramePackI2V_HY"

    def get_latent_paddings(self, total_latent_sections):
        """
        Get the latent paddings for the Video model.

        Args:
            total_latent_sections: The total number of latent sections

        Returns:
            A list of latent paddings
        """
        if total_latent_sections > 4:
            return [3] + [2] * (total_latent_sections - 3) + [1, 0]
        else:
            return list(reversed(range(total_latent_sections)))

    def video_prepare_clean_latents_and_indices(self, end_frame_output_dimensions_latent, end_frame_weight, end_clip_embedding, end_of_input_video_embedding, latent_paddings, latent_padding, latent_padding_size, latent_window_size, video_latents, history_latents, num_cleaned_frames=5):
        """
        Combined method to prepare clean latents and indices for the Video model.

        Note: work in progress - ideally latent_paddings and latent_padding would not
        need to be passed in.

        Args:
            num_cleaned_frames: Number of context frames to use from the input video
                (controls adherence to the video).

        Returns:
            A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices,
            clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x)
        """
        num_clean_frames = num_cleaned_frames if num_cleaned_frames is not None else 5
        end_of_input_video_latent = video_latents[:, :, -1:]
        is_start_of_video = latent_padding == 0
        is_end_of_video = latent_padding == latent_paddings[0]
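        # Decide how many 1x clean frames and how many 2x/4x context frames can be
        # used for this section, capped by what is actually available (the input
        # video latents at the start of the video, the generated history otherwise).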
        available_frames = video_latents.shape[2] if is_start_of_video else history_latents.shape[2]
        effective_clean_frames = max(0, num_clean_frames - 1) if num_clean_frames > 1 else 1
        if is_start_of_video:
            effective_clean_frames = 1

        clean_latent_pre_frames = effective_clean_frames
        num_2x_frames = min(2, max(1, available_frames - clean_latent_pre_frames - 1)) if available_frames > clean_latent_pre_frames + 1 else 1
        num_4x_frames = min(16, max(1, available_frames - clean_latent_pre_frames - num_2x_frames)) if available_frames > clean_latent_pre_frames + num_2x_frames else 1
        total_context_frames = num_2x_frames + num_4x_frames
        total_context_frames = min(total_context_frames, available_frames - clean_latent_pre_frames)
        post_frames = 1 if is_end_of_video and end_frame_output_dimensions_latent is not None else effective_clean_frames
        indices = torch.arange(0, clean_latent_pre_frames + latent_padding_size + latent_window_size + post_frames + num_2x_frames + num_4x_frames).unsqueeze(0)
        clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split(
            [clean_latent_pre_frames, latent_padding_size, latent_window_size, post_frames, num_2x_frames, num_4x_frames], dim=1
        )
        clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
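        # Temporal index layout for the transformer (illustrative):
        #   [clean pre | blank padding | latent window | clean post | 2x context | 4x context]
        # clean_latent_indices covers the pre and post frames that directly anchor
        # the section being generated.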
        context_frames = history_latents[:, :, -(total_context_frames + clean_latent_pre_frames):-clean_latent_pre_frames, :, :] if total_context_frames > 0 else history_latents[:, :, :1, :, :]

        split_sizes = [num_4x_frames, num_2x_frames]
        split_sizes = [s for s in split_sizes if s > 0]
        if split_sizes and context_frames.shape[2] >= sum(split_sizes):
            splits = context_frames.split(split_sizes, dim=2)
            split_idx = 0
            clean_latents_4x = splits[split_idx] if num_4x_frames > 0 else history_latents[:, :, :1, :, :]
            split_idx += 1 if num_4x_frames > 0 else 0
            clean_latents_2x = splits[split_idx] if num_2x_frames > 0 and split_idx < len(splits) else history_latents[:, :, :1, :, :]
        else:
            clean_latents_4x = clean_latents_2x = history_latents[:, :, :1, :, :]
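        # 1x clean latents that directly anchor the current section: the 'pre' frames
        # come from the tail of the input video, the 'post' frames from the head of the
        # generated history (which, with backward generation, sits immediately after
        # the current section in time).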
        clean_latents_pre = video_latents[:, :, -min(effective_clean_frames, video_latents.shape[2]):].to(history_latents)
        clean_latents_post = history_latents[:, :, :min(effective_clean_frames, history_latents.shape[2]), :, :]

        if is_end_of_video:
            clean_latents_post = torch.zeros_like(end_of_input_video_latent).to(history_latents)

        if end_frame_output_dimensions_latent is not None and end_of_input_video_embedding is not None and end_clip_embedding is not None:
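            # Blend the image-encoder embedding of the input video's final frame with
            # the end frame embedding according to end_frame_weight.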
            image_encoder_last_hidden_state = (1 - end_frame_weight) * end_of_input_video_embedding + end_clip_embedding * end_frame_weight
            image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(self.transformer.dtype)

            if is_end_of_video:
                clean_latents_post = end_frame_output_dimensions_latent.to(history_latents)[:, :, :1, :, :]
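        # If fewer real frames are available than expected, tile what is there so the
        # pre/post clean latents reach their required temporal length.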
        if clean_latents_pre.shape[2] < clean_latent_pre_frames:
            clean_latents_pre = clean_latents_pre.repeat(1, 1, math.ceil(clean_latent_pre_frames / clean_latents_pre.shape[2]), 1, 1)[:, :, :clean_latent_pre_frames]
        if clean_latents_post.shape[2] < post_frames:
            clean_latents_post = clean_latents_post.repeat(1, 1, math.ceil(post_frames / clean_latents_post.shape[2]), 1, 1)[:, :, :post_frames]

        clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)

        return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x

    def update_history_latents(self, history_latents, generated_latents):
        """
        Backward Generation: Update the history latents with the generated latents
        for the Video model.

        Args:
            history_latents: The history latents
            generated_latents: The generated latents

        Returns:
            The updated history latents
        """
        return torch.cat([generated_latents.to(history_latents), history_latents], dim=2)

    def get_real_history_latents(self, history_latents, total_generated_latent_frames):
        """
        Get the real history latents for the backward Video model. For Video, this is
        the first `total_generated_latent_frames` frames of the history latents.

        Args:
            history_latents: The history latents
            total_generated_latent_frames: The total number of generated latent frames

        Returns:
            The real history latents
        """
        return history_latents[:, :, :total_generated_latent_frames, :, :]

    def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames):
        """
        Update the history pixels with the current pixels for the Video model.

        Args:
            history_pixels: The history pixels
            current_pixels: The current pixels
            overlapped_frames: The number of overlapped frames

        Returns:
            The updated history pixels
        """
        from diffusers_helper.utils import soft_append_bcthw
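        # soft_append_bcthw blends the overlapped_frames overlapping frames between the
        # two clips to avoid a visible seam; the current pixels are passed first
        # because they precede the existing history in time.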
        return soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)

    def get_current_pixels(self, real_history_latents, section_latent_frames, vae):
        """
        Get the current pixels for the Video model.

        Args:
            real_history_latents: The real history latents
            section_latent_frames: The number of section latent frames
            vae: The VAE model

        Returns:
            The current pixels
        """
        return vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()

    def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt):
        """
        Format the position description for the Video model.

        Args:
            total_generated_latent_frames: The total number of generated latent frames
            current_pos: The current position in seconds (includes input video time)
            original_pos: The original position in seconds
            current_prompt: The current prompt

        Returns:
            The formatted position description
        """
        return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, '
                f'Generated video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). '
                f'Current position: {current_pos:.2f}s (remaining: {original_pos:.2f}s). '
                f'Using prompt: {current_prompt[:256]}...')