import os

import torch

from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.memory import DynamicSwapInstaller

from .base_generator import BaseModelGenerator


class F1ModelGenerator(BaseModelGenerator):
    """
    Model generator for the F1 HunyuanVideo model.
    """

    def __init__(self, **kwargs):
        """
        Initialize the F1 model generator.
        """
        super().__init__(**kwargs)
        self.model_name = "F1"
        self.model_path = 'lllyasviel/FramePack_F1_I2V_HY_20250503'
        self.model_repo_id_for_cache = "models--lllyasviel--FramePack_F1_I2V_HY_20250503"

    def get_model_name(self):
        """
        Get the name of the model.
        """
        return self.model_name

    def load_model(self):
        """
        Load the F1 transformer model.
        If offline mode is True, attempts to load from a local snapshot.
        """
        print(f"Loading {self.model_name} Transformer...")

        path_to_load = self.model_path
        if self.offline:
            path_to_load = self._get_offline_load_path()

        # Load the weights to the CPU first; device placement is decided below.
        self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
            path_to_load,
            torch_dtype=torch.bfloat16
        ).cpu()

        self.transformer.eval()
        self.transformer.to(dtype=torch.bfloat16)
        self.transformer.requires_grad_(False)

        if not self.high_vram:
            # Low-VRAM mode: swap submodules between CPU and GPU on demand
            # instead of keeping the whole model resident on the GPU.
            DynamicSwapInstaller.install_model(self.transformer, device=self.gpu)
        else:
            self.transformer.to(device=self.gpu)

        print(f"{self.model_name} Transformer Loaded from {path_to_load}.")
        return self.transformer
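
    # A minimal usage sketch (illustrative; assumes the base class accepts these
    # keyword arguments -- check BaseModelGenerator for the real signature):
    #
    #     generator = F1ModelGenerator(high_vram=False, gpu="cuda:0", offline=False)
    #     transformer = generator.load_model()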

    def prepare_history_latents(self, height, width):
        """
        Prepare the history latents tensor for the F1 model.

        Args:
            height: The height of the image
            width: The width of the image

        Returns:
            The initialized history latents tensor
        """
        return torch.zeros(
            size=(1, 16, 16 + 2 + 1, height // 8, width // 8),
            dtype=torch.float32
        ).cpu()

    def initialize_with_start_latent(self, history_latents, start_latent):
        """
        Initialize the history latents with the start latent for the F1 model.

        Args:
            history_latents: The history latents
            start_latent: The start latent

        Returns:
            The initialized history latents
        """
        # .to(history_latents) matches the start latent's dtype and device
        # before appending it after the zero-initialized context frames.
        return torch.cat([history_latents, start_latent.to(history_latents)], dim=2)

    def get_latent_paddings(self, total_latent_sections):
        """
        Get the latent paddings for the F1 model.

        Args:
            total_latent_sections: The total number of latent sections

        Returns:
            A list of latent paddings
        """
        return [1] * (total_latent_sections - 1) + [0]

    def prepare_indices(self, latent_padding_size, latent_window_size):
        """
        Prepare the indices for the F1 model.

        Args:
            latent_padding_size: The size of the latent padding (not used by
                the F1 indexing scheme)
            latent_window_size: The size of the latent window

        Returns:
            A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices)
        """
        # latent_window_size can arrive as the special value 4.5; round it up
        # to an integer frame count before building the index tensors.
        effective_window_size = 5 if latent_window_size == 4.5 else int(latent_window_size)
        indices = torch.arange(0, sum([1, 16, 2, 1, effective_window_size])).unsqueeze(0)
        clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split(
            [1, 16, 2, 1, effective_window_size], dim=1
        )
        clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)

        return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices

    def prepare_clean_latents(self, start_latent, history_latents):
        """
        Prepare the clean latents for the F1 model.

        Args:
            start_latent: The start latent
            history_latents: The history latents

        Returns:
            A tuple of (clean_latents, clean_latents_2x, clean_latents_4x)
        """
        # Split the most recent 16 + 2 + 1 history frames into the 4x, 2x and
        # 1x clean-latent context groups.
        clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]):, :, :].split([16, 2, 1], dim=2)
        # Prepend the start latent (the encoded input image) to the 1x context.
        clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)

        return clean_latents, clean_latents_2x, clean_latents_4x

    def update_history_latents(self, history_latents, generated_latents):
        """
        Update the history latents with the generated latents for the F1 model.

        Args:
            history_latents: The history latents
            generated_latents: The generated latents

        Returns:
            The updated history latents
        """
        # F1 generates forward in time, so new latents are appended to the end
        # of the history.
        return torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)

    def get_real_history_latents(self, history_latents, total_generated_latent_frames):
        """
        Get the real history latents for the F1 model.

        Args:
            history_latents: The history latents
            total_generated_latent_frames: The total number of generated latent frames

        Returns:
            The real history latents
        """
        # Only the last total_generated_latent_frames frames are real output;
        # the leading frames are the zero-initialized context.
        return history_latents[:, :, -total_generated_latent_frames:, :, :]

    def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames):
        """
        Update the history pixels with the current pixels for the F1 model.

        Args:
            history_pixels: The history pixels
            current_pixels: The current pixels
            overlapped_frames: The number of overlapped frames

        Returns:
            The updated history pixels
        """
        from diffusers_helper.utils import soft_append_bcthw

        # soft_append_bcthw blends the overlapped frames while appending, which
        # avoids visible seams between decoded sections.
        return soft_append_bcthw(history_pixels, current_pixels, overlapped_frames)

    def get_section_latent_frames(self, latent_window_size, is_last_section):
        """
        Get the number of section latent frames for the F1 model.

        Args:
            latent_window_size: The size of the latent window
            is_last_section: Whether this is the last section (not used by F1)

        Returns:
            The number of section latent frames
        """
        return latent_window_size * 2

    def get_current_pixels(self, real_history_latents, section_latent_frames, vae):
        """
        Get the current pixels for the F1 model.

        Args:
            real_history_latents: The real history latents
            section_latent_frames: The number of section latent frames
            vae: The VAE model

        Returns:
            The current pixels
        """
        from diffusers_helper.hunyuan import vae_decode

        # Decode only the newest section of latents and move the result to the CPU.
        return vae_decode(real_history_latents[:, :, -section_latent_frames:], vae).cpu()

    def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt):
        """
        Format the position description for the F1 model.

        Args:
            total_generated_latent_frames: The total number of generated latent frames
            current_pos: The current position in seconds
            original_pos: The original position in seconds (not used by F1)
            current_prompt: The current prompt

        Returns:
            The formatted position description
        """
        return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, '
                f'Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (at 30 FPS). '
                f'Current position: {current_pos:.2f}s. '
                f'Using prompt: {current_prompt[:256]}...')