import torch
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.memory import DynamicSwapInstaller
from .base_generator import BaseModelGenerator


class F1ModelGenerator(BaseModelGenerator):
"""
Model generator for the F1 HunyuanVideo model.
"""

    def __init__(self, **kwargs):
"""
Initialize the F1 model generator.
"""
super().__init__(**kwargs)
self.model_name = "F1"
self.model_path = 'lllyasviel/FramePack_F1_I2V_HY_20250503'
self.model_repo_id_for_cache = "models--lllyasviel--FramePack_F1_I2V_HY_20250503"

    def get_model_name(self):
"""
Get the name of the model.
"""
return self.model_name

    def load_model(self):
"""
Load the F1 transformer model.
If offline mode is True, attempts to load from a local snapshot.
"""
print(f"Loading {self.model_name} Transformer...")
path_to_load = self.model_path # Initialize with the default path
if self.offline:
path_to_load = self._get_offline_load_path() # Calls the method in BaseModelGenerator
# Create the transformer model
self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
path_to_load,
torch_dtype=torch.bfloat16
).cpu()
# Configure the model
self.transformer.eval()
self.transformer.to(dtype=torch.bfloat16)
self.transformer.requires_grad_(False)
# Set up dynamic swap if not in high VRAM mode
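        # (DynamicSwapInstaller keeps the transformer resident on CPU and streams
        # weights to the GPU on demand, so the full model never has to fit in VRAM.)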
if not self.high_vram:
DynamicSwapInstaller.install_model(self.transformer, device=self.gpu)
else:
# In high VRAM mode, move the entire model to GPU
self.transformer.to(device=self.gpu)
print(f"{self.model_name} Transformer Loaded from {path_to_load}.")
return self.transformer

    def prepare_history_latents(self, height, width):
"""
Prepare the history latents tensor for the F1 model.
Args:
height: The height of the image
width: The width of the image
Returns:
The initialized history latents tensor
"""
return torch.zeros(
size=(1, 16, 16 + 2 + 1, height // 8, width // 8),
dtype=torch.float32
).cpu()

    def initialize_with_start_latent(self, history_latents, start_latent):
"""
Initialize the history latents with the start latent for the F1 model.
Args:
history_latents: The history latents
start_latent: The start latent
Returns:
The initialized history latents
"""
# Add the start frame to history_latents
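        # F1 generates forward in time, so the start frame sits at the end of
        # the time axis, which is where prepare_clean_latents() reads its context.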
return torch.cat([history_latents, start_latent.to(history_latents)], dim=2)

    def get_latent_paddings(self, total_latent_sections):
"""
Get the latent paddings for the F1 model.
Args:
total_latent_sections: The total number of latent sections
Returns:
A list of latent paddings
"""
        # F1 uses a fixed padding scheme: 1 for every section except the last, which gets 0
        return [1] * (total_latent_sections - 1) + [0]

    def prepare_indices(self, latent_padding_size, latent_window_size):
"""
Prepare the indices for the F1 model.
Args:
latent_padding_size: The size of the latent padding
latent_window_size: The size of the latent window
Returns:
A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices)
"""
        # F1 model uses a different indices approach.
        # If latent_window_size is 4.5, use 5 as a special case.
        effective_window_size = 5 if latent_window_size == 4.5 else int(latent_window_size)
        indices = torch.arange(0, sum([1, 16, 2, 1, effective_window_size])).unsqueeze(0)
        clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split([1, 16, 2, 1, effective_window_size], dim=1)
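        # e.g. with effective_window_size == 9 this builds 29 indices laid out as
        # [start | 16 4x-context | 2 2x-context | 1 1x-context | 9 window frames]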
clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices

    def prepare_clean_latents(self, start_latent, history_latents):
"""
Prepare the clean latents for the F1 model.
Args:
start_latent: The start latent
history_latents: The history latents
Returns:
A tuple of (clean_latents, clean_latents_2x, clean_latents_4x)
"""
# For F1, we take the last frames for clean latents
clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]):, :, :].split([16, 2, 1], dim=2)
# For F1, we prepend the start latent to clean_latents_1x
clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
return clean_latents, clean_latents_2x, clean_latents_4x

    def update_history_latents(self, history_latents, generated_latents):
"""
Update the history latents with the generated latents for the F1 model.
Args:
history_latents: The history latents
generated_latents: The generated latents
Returns:
The updated history latents
"""
# For F1, we append new frames to the end
return torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)

    def get_real_history_latents(self, history_latents, total_generated_latent_frames):
"""
Get the real history latents for the F1 model.
Args:
history_latents: The history latents
total_generated_latent_frames: The total number of generated latent frames
Returns:
The real history latents
"""
# For F1, we take frames from the end
return history_latents[:, :, -total_generated_latent_frames:, :, :]

    def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames):
"""
Update the history pixels with the current pixels for the F1 model.
Args:
history_pixels: The history pixels
current_pixels: The current pixels
overlapped_frames: The number of overlapped frames
Returns:
The updated history pixels
"""
from diffusers_helper.utils import soft_append_bcthw
# For F1 model, history_pixels is first, current_pixels is second
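        # soft_append_bcthw blends the overlapped_frames region between the two
        # clips (a soft crossfade) instead of hard-cutting, then joins along time.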
return soft_append_bcthw(history_pixels, current_pixels, overlapped_frames)

    def get_section_latent_frames(self, latent_window_size, is_last_section):
"""
Get the number of section latent frames for the F1 model.
Args:
latent_window_size: The size of the latent window
is_last_section: Whether this is the last section
Returns:
The number of section latent frames
"""
return latent_window_size * 2

    def get_current_pixels(self, real_history_latents, section_latent_frames, vae):
"""
Get the current pixels for the F1 model.
Args:
real_history_latents: The real history latents
section_latent_frames: The number of section latent frames
vae: The VAE model
Returns:
The current pixels
"""
from diffusers_helper.hunyuan import vae_decode
# For F1, we take frames from the end
return vae_decode(real_history_latents[:, :, -section_latent_frames:], vae).cpu()

    def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt):
"""
Format the position description for the F1 model.
Args:
total_generated_latent_frames: The total number of generated latent frames
current_pos: The current position in seconds
original_pos: The original position in seconds
current_prompt: The current prompt
Returns:
The formatted position description
"""
        return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, '
                f'Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (at 30 FPS). '
                f'Current position: {current_pos:.2f}s. '
                f'Using prompt: {current_prompt[:256]}...')
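

# Minimal usage sketch (constructor arguments shown here are hypothetical;
# BaseModelGenerator in base_generator.py defines the real signature):
#
#     generator = F1ModelGenerator(high_vram=False, offline=False)
#     transformer = generator.load_model()
#     history = generator.prepare_history_latents(height=512, width=512)
#     history = generator.initialize_with_start_latent(history, start_latent)
#     for latent_padding in generator.get_latent_paddings(total_latent_sections=3):
#         ...  # sample a section, then update_history_latents() / update_history_pixels()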