import torch
import os
import numpy as np
import math
import decord
from tqdm import tqdm
import pathlib
from PIL import Image
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.memory import DynamicSwapInstaller
from diffusers_helper.utils import resize_and_center_crop
from diffusers_helper.bucket_tools import find_nearest_bucket
from diffusers_helper.hunyuan import vae_encode, vae_decode
from .video_base_generator import VideoBaseModelGenerator


class VideoModelGenerator(VideoBaseModelGenerator):
    """
    Generator for the Video (backward) extension of the Original HunyuanVideo model.
    These generators accept video input instead of a single image.
    """

    def __init__(self, **kwargs):
        """
        Initialize the Video model generator.
        """
        super().__init__(**kwargs)
        self.model_name = "Video"
        self.model_path = 'lllyasviel/FramePackI2V_HY'  # Same as Original
        self.model_repo_id_for_cache = "models--lllyasviel--FramePackI2V_HY"

    def get_latent_paddings(self, total_latent_sections):
        """
        Get the latent paddings for the Video model.

        Args:
            total_latent_sections: The total number of latent sections

        Returns:
            A list of latent paddings
        """
        # Video model uses reversed latent paddings like Original
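        # For illustration (values follow directly from the branches below):
        #   total_latent_sections=6 -> [3, 2, 2, 2, 1, 0]
        #   total_latent_sections=3 -> [2, 1, 0]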
        if total_latent_sections > 4:
            return [3] + [2] * (total_latent_sections - 3) + [1, 0]
        else:
            return list(reversed(range(total_latent_sections)))

    def video_prepare_clean_latents_and_indices(self, end_frame_output_dimensions_latent, end_frame_weight, end_clip_embedding,
                                                end_of_input_video_embedding, latent_paddings, latent_padding, latent_padding_size,
                                                latent_window_size, video_latents, history_latents, num_cleaned_frames=5):
        """
        Combined method to prepare clean latents and indices for the Video model.

        Note:
            Work in progress - it would be better not to pass in latent_paddings and latent_padding.

        Args:
            num_cleaned_frames: Number of context frames to use from the video (adherence to video)

        Returns:
            A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x)
        """
        # Get num_cleaned_frames from job_params if available, otherwise use the default value of 5
        num_clean_frames = num_cleaned_frames if num_cleaned_frames is not None else 5
        # HACK SOME STUFF IN THAT SHOULD NOT BE HERE
        # Placeholders for end frame processing
        # Colin, I'm only leaving them for the moment in case you want separate models for
        # Video-backward and Video-backward-Endframe.
        # end_latent = None
        # end_of_input_video_embedding = None  # Placeholder for end frame's CLIP embedding. SEE: 20250507 pftq: Process end frame if provided
        # end_clip_embedding = None  # Placeholders for end frame processing. SEE: 20250507 pftq: Process end frame if provided
        # end_frame_weight = 0.0  # Placeholders for end frame processing. SEE: 20250507 pftq: Process end frame if provided

        # HACK MORE STUFF IN THAT PROBABLY SHOULD BE ARGUMENTS OR OTHERWISE MADE AVAILABLE
        end_of_input_video_latent = video_latents[:, :, -1:]  # Last frame of the input video (produced by video_encode in the PR)
        is_start_of_video = latent_padding == 0  # This refers to the start of the *generated* video part
        is_end_of_video = latent_padding == latent_paddings[0]  # This refers to the end of the *generated* video part (closest to the input video) (better not to pass in latent_paddings[])
        # End of HACK STUFF

        # Dynamic frame allocation for context frames (clean latents)
        # This determines which frames from history_latents are used as input for the transformer.
        available_frames = video_latents.shape[2] if is_start_of_video else history_latents.shape[2]  # Use input video frames for the first segment, else previously generated history
        effective_clean_frames = max(0, num_clean_frames - 1) if num_clean_frames > 1 else 1
        if is_start_of_video:
            effective_clean_frames = 1  # Avoid jumpcuts if the input video is too different
        clean_latent_pre_frames = effective_clean_frames
        num_2x_frames = min(2, max(1, available_frames - clean_latent_pre_frames - 1)) if available_frames > clean_latent_pre_frames + 1 else 1
        num_4x_frames = min(16, max(1, available_frames - clean_latent_pre_frames - num_2x_frames)) if available_frames > clean_latent_pre_frames + num_2x_frames else 1
        total_context_frames = num_2x_frames + num_4x_frames
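        # Clamp so the 2x/4x context never requests more frames than remain after
        # reserving clean_latent_pre_frames from the available frames.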
        total_context_frames = min(total_context_frames, available_frames - clean_latent_pre_frames)

        # Prepare indices for the transformer's input (these define the *relative positions* of different frame types in the input tensor)
        # The total length is the sum of various frame types:
        #   clean_latent_pre_frames: frames before the blank/generated section
        #   latent_padding_size: blank frames before the generated section (for backward generation)
        #   latent_window_size: the new frames to be generated
        #   post_frames: frames after the generated section
        #   num_2x_frames, num_4x_frames: frames for lower resolution context
        # 20250511 pftq: Dynamically adjust post_frames based on clean_latents_post
        post_frames = 1 if is_end_of_video and end_frame_output_dimensions_latent is not None else effective_clean_frames  # 20250511 pftq: Single frame for end_latent, otherwise padding causes still image
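        # Illustrative sizing (assumed values, not taken from a specific run): with latent_window_size=9,
        # latent_padding=3 (latent_padding_size=27), num_clean_frames=5 (clean_latent_pre_frames=post_frames=4)
        # and enough history for num_2x_frames=2 and num_4x_frames=16, the index tensor below has
        # 4 + 27 + 9 + 4 + 2 + 16 = 62 positions, split in that order.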
        indices = torch.arange(0, clean_latent_pre_frames + latent_padding_size + latent_window_size + post_frames + num_2x_frames + num_4x_frames).unsqueeze(0)
        clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split(
            [clean_latent_pre_frames, latent_padding_size, latent_window_size, post_frames, num_2x_frames, num_4x_frames], dim=1
        )
        clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)  # Combined indices for 1x clean latents

        # Prepare the *actual latent data* for the transformer's context inputs
        # These are extracted from history_latents (or video_latents for the first segment)
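        # The slice below takes the total_context_frames frames that sit immediately before
        # the final clean_latent_pre_frames frames of history_latents.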
        context_frames = history_latents[:, :, -(total_context_frames + clean_latent_pre_frames):-clean_latent_pre_frames, :, :] if total_context_frames > 0 else history_latents[:, :, :1, :, :]
        # clean_latents_4x: 4x downsampled context frames. From history_latents (or video_latents).
        # clean_latents_2x: 2x downsampled context frames. From history_latents (or video_latents).
        split_sizes = [num_4x_frames, num_2x_frames]
        split_sizes = [s for s in split_sizes if s > 0]
        if split_sizes and context_frames.shape[2] >= sum(split_sizes):
            splits = context_frames.split(split_sizes, dim=2)
            split_idx = 0
            clean_latents_4x = splits[split_idx] if num_4x_frames > 0 else history_latents[:, :, :1, :, :]
            split_idx += 1 if num_4x_frames > 0 else 0
            clean_latents_2x = splits[split_idx] if num_2x_frames > 0 and split_idx < len(splits) else history_latents[:, :, :1, :, :]
        else:
            clean_latents_4x = clean_latents_2x = history_latents[:, :, :1, :, :]

        # clean_latents_pre: Latents from the *end* of the input video (if is_start_of_video), or previously generated history.
        # Its purpose is to provide a smooth transition *from* the input video.
        clean_latents_pre = video_latents[:, :, -min(effective_clean_frames, video_latents.shape[2]):].to(history_latents)
        # clean_latents_post: Latents from the *beginning* of the previously generated video segments.
        # Its purpose is to provide a smooth transition *to* the existing generated content.
        clean_latents_post = history_latents[:, :, :min(effective_clean_frames, history_latents.shape[2]), :, :]

        # Special handling for the end frame:
        # If it's the very first segment being generated (is_end_of_video in terms of generation order),
        # and an end_latent was provided, force clean_latents_post to be that end_latent.
        if is_end_of_video:
            clean_latents_post = torch.zeros_like(end_of_input_video_latent).to(history_latents)  # Initialize to zero

        # RT_BORG: end_of_input_video_embedding and end_clip_embedding shouldn't need to be checked, since they should
        # always be provided if end_latent is provided. But bulletproofing before the release since test time will be short.
        if end_frame_output_dimensions_latent is not None and end_of_input_video_embedding is not None and end_clip_embedding is not None:
            # image_encoder_last_hidden_state: Weighted average of CLIP embedding of first input frame and end frame's CLIP embedding
            # This guides the overall content to transition towards the end frame.
            image_encoder_last_hidden_state = (1 - end_frame_weight) * end_of_input_video_embedding + end_clip_embedding * end_frame_weight
            image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(self.transformer.dtype)

            if is_end_of_video:
                # For the very first generated segment, the "post" part is the end_latent itself.
                clean_latents_post = end_frame_output_dimensions_latent.to(history_latents)[:, :, :1, :, :]  # Ensure single frame

        # Pad clean_latents_pre/post if they have fewer frames than specified by clean_latent_pre_frames/post_frames
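        # Padding is done by tiling (repeating) the frames that are available and then trimming to the exact requested count.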
        if clean_latents_pre.shape[2] < clean_latent_pre_frames:
            clean_latents_pre = clean_latents_pre.repeat(1, 1, math.ceil(clean_latent_pre_frames / clean_latents_pre.shape[2]), 1, 1)[:, :, :clean_latent_pre_frames]
        if clean_latents_post.shape[2] < post_frames:
            clean_latents_post = clean_latents_post.repeat(1, 1, math.ceil(post_frames / clean_latents_post.shape[2]), 1, 1)[:, :, :post_frames]

        # clean_latents: Concatenation of pre and post clean latents. These are the 1x resolution context frames.
        clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)

        return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x

    def update_history_latents(self, history_latents, generated_latents):
        """
        Backward Generation: Update the history latents with the generated latents for the Video model.

        Args:
            history_latents: The history latents
            generated_latents: The generated latents

        Returns:
            The updated history latents
        """
        # For Video model, we prepend the generated latents to the front of history latents
        # This matches the original implementation in video-example.py
        # It generates new sections backwards in time, chunk by chunk
        return torch.cat([generated_latents.to(history_latents), history_latents], dim=2)

    def get_real_history_latents(self, history_latents, total_generated_latent_frames):
        """
        Get the real history latents for the backward Video model. For Video, this is the first
        `total_generated_latent_frames` frames of the history latents.

        Args:
            history_latents: The history latents
            total_generated_latent_frames: The total number of generated latent frames

        Returns:
            The real history latents
        """
        # Generated frames at the front.
        return history_latents[:, :, :total_generated_latent_frames, :, :]

    def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames):
        """
        Update the history pixels with the current pixels for the Video model.

        Args:
            history_pixels: The history pixels
            current_pixels: The current pixels
            overlapped_frames: The number of overlapped frames

        Returns:
            The updated history pixels
        """
        from diffusers_helper.utils import soft_append_bcthw
        # For Video model, we prepend the current pixels to the history pixels
        # This matches the original implementation in video-example.py
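        # soft_append_bcthw blends the overlapped_frames at the seam between the two clips rather
        # than hard-cutting; current_pixels is passed first because new content goes in front.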
        return soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)

    def get_current_pixels(self, real_history_latents, section_latent_frames, vae):
        """
        Get the current pixels for the Video model.

        Args:
            real_history_latents: The real history latents
            section_latent_frames: The number of section latent frames
            vae: The VAE model

        Returns:
            The current pixels
        """
        # For backward Video mode, current pixels are at the front of history.
        return vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()

    def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt):
        """
        Format the position description for the Video model.

        Args:
            total_generated_latent_frames: The total number of generated latent frames
            current_pos: The current position in seconds (includes input video time)
            original_pos: The original position in seconds
            current_prompt: The current prompt

        Returns:
            The formatted position description
        """
        # For Video model, current_pos already includes the input video time
        # We just need to display the total generated frames and the current position
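        # N latent frames decode to 4 * N - 3 pixel frames (the causal VAE uses 4x temporal
        # compression, with the first latent frame mapping to a single pixel frame).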
        return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, '
                f'Generated video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). '
                f'Current position: {current_pos:.2f}s (remaining: {original_pos:.2f}s). '
                f'Using prompt: {current_prompt[:256]}...')