import torch
import os
import numpy as np
import math
import decord
from tqdm import tqdm
import pathlib
from PIL import Image

from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.memory import DynamicSwapInstaller
from diffusers_helper.utils import resize_and_center_crop
from diffusers_helper.bucket_tools import find_nearest_bucket
from diffusers_helper.hunyuan import vae_encode, vae_decode
from .base_generator import BaseModelGenerator


class VideoBaseModelGenerator(BaseModelGenerator):
    """
    Model generator for the Video extension of the Original HunyuanVideo model.
    This generator accepts video input instead of a single image.
    """

    def __init__(self, **kwargs):
        """
        Initialize the Video model generator.
        """
        super().__init__(**kwargs)
        self.model_name = None
        self.model_path = None
        self.model_repo_id_for_cache = None
        self.full_video_latents = None
        self.resolution = 640
        self.no_resize = False
        self.vae_batch_size = 16

        # decord and tqdm are also imported at module level, so this guard only
        # takes effect if those imports are ever made optional; keep the attributes
        # populated (or None) either way so callers can check availability.
        try:
            import decord
            from tqdm import tqdm
            self.decord = decord
            self.tqdm = tqdm
        except ImportError:
            print("Warning: decord or tqdm not installed. Video processing will not work.")
            self.decord = None
            self.tqdm = None

    def get_model_name(self):
        """
        Get the name of the model.
        """
        return self.model_name

    def load_model(self):
        """
        Load the Video transformer model.
        If offline mode is True, attempts to load from a local snapshot.
        """
        print(f"Loading {self.model_name} Transformer...")

        path_to_load = self.model_path

        if self.offline:
            path_to_load = self._get_offline_load_path()

        self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
            path_to_load,
            torch_dtype=torch.bfloat16
        ).cpu()

        self.transformer.eval()
        self.transformer.to(dtype=torch.bfloat16)
        self.transformer.requires_grad_(False)

        if not self.high_vram:
            # Low-VRAM mode: install dynamic swapping so weights move to the GPU on demand.
            DynamicSwapInstaller.install_model(self.transformer, device=self.gpu)
        else:
            # High-VRAM mode: keep the whole transformer resident on the GPU.
            self.transformer.to(device=self.gpu)

        print(f"{self.model_name} Transformer Loaded from {path_to_load}.")
        return self.transformer

    def min_real_frames_to_encode(self, real_frames_available_count):
        """
        The minimum number of real frames to encode is the maximum number of real
        frames used for generation context.

        The number of latents could be calculated per model type as below, but for now
        this is kept simple by hardcoding the Video F1 value of
        max_latents_used_for_context = 27:

            # Calculate the number of latent frames to encode from the end of the input video
            num_frames_to_encode_from_end = 1  # Default minimum
            if model_type == "Video":
                # Max needed is 1 (clean_latent_pre) + 2 (max 2x) + 16 (max 4x) = 19
                num_frames_to_encode_from_end = 19
            elif model_type == "Video F1":
                ui_num_cleaned_frames = job_params.get('num_cleaned_frames', 5)
                # Max effective_clean_frames based on VideoF1ModelGenerator's logic.
                # Max num_clean_frames from UI is 10 (modules/interface.py).
                # Max effective_clean_frames = 10 - 1 = 9.
                # total_context_frames = num_4x_frames + num_2x_frames + effective_clean_frames
                # Max needed = 16 (max 4x) + 2 (max 2x) + 9 (max effective_clean_frames) = 27
                num_frames_to_encode_from_end = 27

        Note: 27 latents ~ 108 real frames = 3.6 seconds at 30 FPS.
        Note: 19 latents ~ 76 real frames ~ 2.5 seconds at 30 FPS.
        """
        max_latents_used_for_context = 27
        if self.get_model_name() == "Video":
            max_latents_used_for_context = 27
        elif self.get_model_name() == "Video F1":
            max_latents_used_for_context = 27
        else:
            print("======================================================")
            print(f" ***** Warning: Unsupported video extension model type: {self.get_model_name()}.")
            print(f" ***** Using default max latents {max_latents_used_for_context} for context.")
            print(" ***** Please report to the developers if you see this message:")
            print(" ***** Discord: https://discord.gg/8Z2c3a4 or GitHub: https://github.com/colinurbs/FramePack-Studio")
            print("======================================================")

        # Each latent frame corresponds to 4 real frames.
        latent_size_factor = 4
        max_real_frames_used_for_context = max_latents_used_for_context * latent_size_factor

        # Keep only as many trailing frames as are needed to populate the context.
        trimmed_real_frames_count = min(real_frames_available_count, max_real_frames_used_for_context)
        if trimmed_real_frames_count < real_frames_available_count:
            print(f"Truncating video frames from {real_frames_available_count} to {trimmed_real_frames_count}, enough to populate context")

        # Align down to a multiple of the latent size factor.
        frames_to_encode_count = (trimmed_real_frames_count // latent_size_factor) * latent_size_factor
        if frames_to_encode_count != trimmed_real_frames_count:
            print(f"Truncating video frames from {trimmed_real_frames_count} to {frames_to_encode_count}, for latent size compatibility")

        return frames_to_encode_count
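
    # Illustrative arithmetic (added note, not executed): with max_latents_used_for_context = 27
    # and latent_size_factor = 4, at most 27 * 4 = 108 real frames are kept. A 300-frame
    # input would be trimmed to 108 frames, while a 50-frame input is only aligned down
    # to 48 frames, the nearest lower multiple of 4.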

    def extract_video_frames(self, is_for_encode, video_path, resolution, no_resize=False, input_files_dir=None):
        """
        Extract real frames from a video, resized and center cropped, as a numpy array (T, H, W, C).

        Args:
            is_for_encode: If True, the result is capped at the maximum number of frames used for context and aligned to the 4-frame latent requirement.
            video_path: Path to the input video file.
            resolution: Target resolution for resizing frames.
            no_resize: Whether to use the original video resolution.
            input_files_dir: Directory for input files that won't be cleaned up.

        Returns:
            A tuple containing:
            - input_frames_resized_np: All input frames resized and center cropped as numpy array (T, H, W, C)
            - fps: Frames per second of the input video
            - target_height: Target height of the video
            - target_width: Target width of the video
        """
        def time_millis():
            import time
            return time.perf_counter() * 1000.0

        encode_start_time_millis = time_millis()

        video_path = str(pathlib.Path(video_path).resolve())
        print(f"Processing video: {video_path}")

        # If the video lives in a temp directory, prefer (or create) a persistent copy in input_files_dir.
        if input_files_dir and "temp" in video_path:
            filename = os.path.basename(video_path)
            input_file_path = os.path.join(input_files_dir, filename)

            if os.path.exists(input_file_path):
                print(f"Using video from input_files_dir: {input_file_path}")
                video_path = input_file_path
            else:
                try:
                    from diffusers_helper.utils import generate_timestamp
                    safe_filename = f"{generate_timestamp()}_{filename}"
                    input_file_path = os.path.join(input_files_dir, safe_filename)
                    import shutil
                    shutil.copy2(video_path, input_file_path)
                    print(f"Copied video to input_files_dir: {input_file_path}")
                    video_path = input_file_path
                except Exception as e:
                    print(f"Error copying video to input_files_dir: {e}")

        try:
            print("Initializing VideoReader...")
            vr = decord.VideoReader(video_path)
            fps = vr.get_avg_fps()
            num_real_frames = len(vr)
            print(f"Video loaded: {num_real_frames} frames, FPS: {fps}")

            print("Reading video frames...")
            total_frames_in_video_file = len(vr)
            if is_for_encode:
                num_real_frames = self.min_real_frames_to_encode(total_frames_in_video_file)
                print(f"Using minimum real frames to encode: {num_real_frames}")

            # Read only the trailing num_real_frames frames from the video file.
            start_frame_index = total_frames_in_video_file - num_real_frames
            frame_indices_to_extract = range(start_frame_index, total_frames_in_video_file)
            frames = vr.get_batch(frame_indices_to_extract).asnumpy()

            print(f"Frames read: {frames.shape}")

            native_height, native_width = frames.shape[1], frames.shape[2]
            print(f"Native video resolution: {native_width}x{native_height}")

            target_height = native_height
            target_width = native_width

            if not no_resize:
                target_height, target_width = find_nearest_bucket(target_height, target_width, resolution=resolution)
                print(f"Adjusted resolution: {target_width}x{target_height}")
            else:
                print(f"Using native resolution without resizing: {target_width}x{target_height}")

            # Resize and center crop every frame to the target resolution.
            input_frames_resized_np = []
            for frame in tqdm(frames, desc="Processing Video Frames", total=num_real_frames, mininterval=0.1):
                frame_np = resize_and_center_crop(frame, target_width=target_width, target_height=target_height)
                input_frames_resized_np.append(frame_np)
            input_frames_resized_np = np.stack(input_frames_resized_np)
            print(f"Frames preprocessed: {input_frames_resized_np.shape}")

            resized_frames_time_millis = time_millis()
            if False:  # Debug: report memory use and preprocessing time for the resized frames.
                print("======================================================")
                memory_bytes = input_frames_resized_np.nbytes
                memory_kb = memory_bytes / 1024
                memory_mb = memory_kb / 1024
                print(f" ***** input_frames_resized_np: {input_frames_resized_np.shape}")
                print(f" ***** Memory usage: {int(memory_mb)} MB")
                duration_ms = resized_frames_time_millis - encode_start_time_millis
                print(f" ***** Time taken to process frames tensor: {duration_ms / 1000.0:.2f} seconds")
                print("======================================================")

            return input_frames_resized_np, fps, target_height, target_width
        except Exception as e:
            print(f"Error in extract_video_frames: {str(e)}")
            raise
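
    # Hypothetical usage sketch (argument values are illustrative only):
    #
    #     frames_np, fps, height, width = generator.extract_video_frames(
    #         is_for_encode=True,
    #         video_path="input.mp4",
    #         resolution=640,
    #     )
    #     # frames_np has shape (T, H, W, C), with T a multiple of 4 when is_for_encode=True.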

    @torch.no_grad()
    def video_encode(self, video_path, resolution, no_resize=False, vae_batch_size=16, device=None, input_files_dir=None):
        """
        Encode a video into latent representations using the VAE.

        Args:
            video_path: Path to the input video file.
            resolution: Target resolution for resizing frames.
            no_resize: Whether to use the original video resolution.
            vae_batch_size: Number of frames to process per batch.
            device: Device for computation (e.g., "cuda").
            input_files_dir: Directory for input files that won't be cleaned up.

        Returns:
            A tuple containing:
            - start_latent: Latent of the first frame
            - input_image_np: First frame as numpy array
            - history_latents: Latents of all frames
            - fps: Frames per second of the input video
            - target_height: Target height of the video
            - target_width: Target width of the video
            - input_video_pixels: Video frames as tensor
            - end_of_input_video_image_np: Last frame as numpy array
            - input_frames_resized_np: All input frames resized and center cropped as numpy array (T, H, W, C)
        """
        encoding = True
        input_frames_resized_np, fps, target_height, target_width = self.extract_video_frames(encoding, video_path, resolution, no_resize, input_files_dir)

        try:
            if device is None:
                device = self.gpu

            # Fall back to CPU if CUDA was requested but is unavailable.
            if device == "cuda" and not torch.cuda.is_available():
                print("CUDA is not available, falling back to CPU")
                device = "cpu"

            input_image_np = input_frames_resized_np[0]
            end_of_input_video_image_np = input_frames_resized_np[-1]

            # Convert to a (B, C, T, H, W) tensor in the [-1, 1] range expected by the VAE.
            print("Converting frames to tensor...")
            frames_pt = torch.from_numpy(input_frames_resized_np).float() / 127.5 - 1
            frames_pt = frames_pt.permute(0, 3, 1, 2)     # (T, C, H, W)
            frames_pt = frames_pt.unsqueeze(0)            # (1, T, C, H, W)
            frames_pt = frames_pt.permute(0, 2, 1, 3, 4)  # (1, C, T, H, W)
            print(f"Tensor shape: {frames_pt.shape}")

            input_video_pixels = frames_pt.cpu()

            print(f"Moving tensor to device: {device}")
            frames_pt = frames_pt.to(device)
            print("Tensor moved to device")

            print(f"Moving VAE to device: {device}")
            self.vae.to(device)
            print("VAE moved to device")

            # Encode in batches along the time dimension to bound peak memory use.
            print(f"Encoding input video frames in VAE batch size {vae_batch_size}")
            latents = []
            self.vae.eval()
            with torch.no_grad():
                frame_count = frames_pt.shape[2]
                step_count = math.ceil(frame_count / vae_batch_size)
                for i in tqdm(range(0, frame_count, vae_batch_size), desc="Encoding video frames", total=step_count, mininterval=0.1):
                    batch = frames_pt[:, :, i:i + vae_batch_size]
                    try:
                        # Snapshot of allocated CUDA memory in GiB (kept for debugging; not used further).
                        if device == "cuda":
                            free_mem = torch.cuda.memory_allocated() / 1024**3
                        batch_latent = vae_encode(batch, self.vae)

                        if device == "cuda":
                            torch.cuda.synchronize()
                        latents.append(batch_latent)
                    except RuntimeError as e:
                        print(f"Error during VAE encoding: {str(e)}")
                        if device == "cuda" and "out of memory" in str(e).lower():
                            print("CUDA out of memory, try reducing vae_batch_size or using CPU")
                        raise

            print("Concatenating latents...")
            history_latents = torch.cat(latents, dim=2)
            print(f"History latents shape: {history_latents.shape}")

            start_latent = history_latents[:, :, :1]
            print(f"Start latent shape: {start_latent.shape}")

            if False:  # Debug: report memory use of the encoded latents.
                print("======================================================")
                memory_bytes = history_latents.nbytes
                memory_kb = memory_bytes / 1024
                memory_mb = memory_kb / 1024
                print(f" ***** history_latents: {history_latents.shape}")
                print(f" ***** Memory usage: {int(memory_mb)} MB")
                print("======================================================")

            if device == "cuda":
                self.vae.to(self.cpu)
                torch.cuda.empty_cache()
                print("VAE moved back to CPU, CUDA cache cleared")

            return start_latent, input_image_np, history_latents, fps, target_height, target_width, input_video_pixels, end_of_input_video_image_np, input_frames_resized_np

        except Exception as e:
            print(f"Error in video_encode: {str(e)}")
            raise
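
    # Hypothetical usage sketch (variable names are illustrative only):
    #
    #     (start_latent, first_frame_np, video_latents, fps, height, width,
    #      video_pixels, last_frame_np, frames_np) = generator.video_encode(
    #         video_path="input.mp4", resolution=640, vae_batch_size=16)
    #     generator.set_full_video_latents(video_latents)
    #
    # video_latents is the per-batch VAE output concatenated along the time
    # dimension (dim=2); start_latent is its first time step.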

    def prepare_history_latents(self, height, width):
        """
        Prepare the history latents tensor for the Video model.

        Args:
            height: The height of the image
            width: The width of the image

        Returns:
            The initialized history latents tensor
        """
        raise TypeError(
            f"Error: '{self.__class__.__name__}.prepare_history_latents' should not be called "
            "on the Video models. history_latents should be initialized within worker() for all Video models "
            "as history_latents = video_latents."
        )

    def prepare_indices(self, latent_padding_size, latent_window_size):
        """
        Prepare the indices for the Video model.

        Args:
            latent_padding_size: The size of the latent padding
            latent_window_size: The size of the latent window

        Returns:
            A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices)
        """
        raise TypeError(
            f"Error: '{self.__class__.__name__}.prepare_indices' should not be called "
            "on the Video models. Currently video models each have a combined method: <model>_prepare_clean_latents_and_indices."
        )

    def set_full_video_latents(self, video_latents):
        """
        Set the full video latents for context.

        Args:
            video_latents: The full video latents
        """
        self.full_video_latents = video_latents

    def prepare_clean_latents(self, start_latent, history_latents):
        """
        Prepare the clean latents for the Video model.

        Args:
            start_latent: The start latent
            history_latents: The history latents

        Returns:
            A tuple of (clean_latents, clean_latents_2x, clean_latents_4x)
        """
        raise TypeError(
            f"Error: '{self.__class__.__name__}.prepare_clean_latents' should not be called "
            "on the Video models. Currently video models each have a combined method: <model>_prepare_clean_latents_and_indices."
        )

    def get_section_latent_frames(self, latent_window_size, is_last_section):
        """
        Get the number of section latent frames for the Video model.

        Args:
            latent_window_size: The size of the latent window
            is_last_section: Whether this is the last section

        Returns:
            The number of section latent frames
        """
        return latent_window_size * 2
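
    # Illustrative note (added, not from the original source): with a latent_window_size
    # of 9, each section spans 9 * 2 = 18 latent frames; is_last_section is accepted for
    # interface compatibility with other generators but does not change the result here.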

    def combine_videos(self, source_video_path, generated_video_path, output_path):
        """
        Combine the source video with the generated video side by side.

        Args:
            source_video_path: Path to the source video
            generated_video_path: Path to the generated video
            output_path: Path to save the combined video

        Returns:
            Path to the combined video, or None if the videos could not be combined
        """
        try:
            import os
            import subprocess

            print(f"Combining source video {source_video_path} with generated video {generated_video_path}")

            from modules.toolbox.toolbox_processor import VideoProcessor
            from modules.toolbox.message_manager import MessageManager

            message_manager = MessageManager()

            # Prefer the application-wide settings if available; otherwise create a fresh Settings instance.
            try:
                from __main__ import settings
                video_processor = VideoProcessor(message_manager, settings.settings)
            except ImportError:
                from modules.settings import Settings
                settings = Settings()
                video_processor = VideoProcessor(message_manager, settings.settings)

            ffmpeg_exe = video_processor.ffmpeg_exe

            if not ffmpeg_exe:
                print("FFmpeg executable not found. Cannot combine videos.")
                return None

            print(f"Using ffmpeg at: {ffmpeg_exe}")

            # The filter graph is written to a temp file and passed via -filter_complex_script.
            import tempfile
            temp_dir = tempfile.gettempdir()
            filter_script_path = os.path.join(temp_dir, f"filter_script_{os.path.basename(output_path)}.txt")

            def get_video_info(video_path):
                # Parse the video dimensions from ffmpeg's stderr output.
                cmd = [
                    ffmpeg_exe, "-i", video_path,
                    "-hide_banner", "-loglevel", "error"
                ]

                result = subprocess.run(cmd, capture_output=True, text=True)

                width = height = None
                for line in result.stderr.split('\n'):
                    if 'Video:' in line:
                        import re
                        match = re.search(r'(\d+)x(\d+)', line)
                        if match:
                            width = int(match.group(1))
                            height = int(match.group(2))
                            break

                return width, height

            source_width, source_height = get_video_info(source_video_path)
            generated_width, generated_height = get_video_info(generated_video_path)

            if not source_width or not generated_width:
                print("Error: Could not determine video dimensions")
                return None

            print(f"Source video: {source_width}x{source_height}")
            print(f"Generated video: {generated_width}x{generated_height}")

            # Scale both videos to the same height, preserving their aspect ratios.
            target_height = max(source_height, generated_height)
            source_target_width = int(source_width * (target_height / source_height))
            generated_target_width = int(generated_width * (target_height / generated_height))

            # Label each side and stack the two streams horizontally.
            filter_complex = (
                f"[0:v]scale={source_target_width}:{target_height}[left];"
                f"[1:v]scale={generated_target_width}:{target_height}[right];"
                f"[left]drawtext=text='Source':x=({source_target_width}/2-50):y=20:fontsize=24:fontcolor=white:box=1:boxcolor=black@0.5[left_text];"
                f"[right]drawtext=text='Generated':x=({generated_target_width}/2-70):y=20:fontsize=24:fontcolor=white:box=1:boxcolor=black@0.5[right_text];"
                f"[left_text][right_text]hstack=inputs=2[v]"
            )

            with open(filter_script_path, 'w') as f:
                f.write(filter_complex)

            cmd = [
                ffmpeg_exe, "-y",
                "-i", source_video_path,
                "-i", generated_video_path,
                "-filter_complex_script", filter_script_path,
                "-map", "[v]"
            ]

            # Carry over the source audio track if one exists.
            has_audio_cmd = [
                ffmpeg_exe, "-i", source_video_path,
                "-hide_banner", "-loglevel", "error"
            ]
            audio_check = subprocess.run(has_audio_cmd, capture_output=True, text=True)
            has_audio = "Audio:" in audio_check.stderr

            if has_audio:
                cmd.extend(["-map", "0:a"])

            cmd.extend([
                "-c:v", "libx264",
                "-crf", "18",
                "-preset", "medium"
            ])

            if has_audio:
                cmd.extend(["-c:a", "aac"])

            cmd.append(output_path)

            print(f"Running ffmpeg command: {' '.join(cmd)}")
            subprocess.run(cmd, check=True, capture_output=True, text=True)

            if os.path.exists(filter_script_path):
                os.remove(filter_script_path)

            print(f"Combined video saved to {output_path}")
            return output_path

        except Exception as e:
            print(f"Error combining videos: {str(e)}")
            import traceback
            traceback.print_exc()
            return None
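
    # Hypothetical usage sketch (paths are illustrative only):
    #
    #     combined = generator.combine_videos(
    #         source_video_path="input.mp4",
    #         generated_video_path="outputs/job_123.mp4",
    #         output_path="outputs/job_123_combined.mp4",
    #     )
    #     # combined is the output path on success, or None on failure.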