FPS-Studio / modules /pipelines /video_tools.py
rahul7star's picture
Migrated from GitHub
05fcd0f verified
import torch
import numpy as np
import traceback
from diffusers_helper.utils import save_bcthw_as_mp4
@torch.no_grad()
def combine_videos_sequentially_from_tensors(processed_input_frames_np,
generated_frames_pt,
output_path,
target_fps,
crf_value):
"""
Combines processed input frames (NumPy) with generated frames (PyTorch Tensor) sequentially
and saves the result as an MP4 video using save_bcthw_as_mp4.
Args:
processed_input_frames_np: NumPy array of processed input frames (T_in, H, W_in, C), uint8.
generated_frames_pt: PyTorch tensor of generated frames (B_gen, C_gen, T_gen, H, W_gen), float32 [-1,1].
(This will be history_pixels from worker.py)
output_path: Path to save the combined video.
target_fps: FPS for the output combined video.
crf_value: CRF value for video encoding.
Returns:
Path to the combined video, or None if an error occurs.
"""
try:
# 1. Convert processed_input_frames_np to PyTorch tensor BCTHW, float32, [-1,1]
# processed_input_frames_np shape: (T_in, H, W_in, C)
input_frames_pt = torch.from_numpy(processed_input_frames_np).float() / 127.5 - 1.0 # (T,H,W,C)
input_frames_pt = input_frames_pt.permute(3, 0, 1, 2) # (C,T,H,W)
input_frames_pt = input_frames_pt.unsqueeze(0) # (1,C,T,H,W) -> BCTHW
# Ensure generated_frames_pt is on the same device and dtype for concatenation
input_frames_pt = input_frames_pt.to(device=generated_frames_pt.device, dtype=generated_frames_pt.dtype)
# 2. Dimension Check (Heights and Widths should match)
# They should match, since the input frames should have been processed to match the generation resolution.
# But sanity check to ensure no mismatch occurs when the code is refactored.
if input_frames_pt.shape[3:] != generated_frames_pt.shape[3:]: # Compare (H,W)
print(f"Warning: Dimension mismatch for sequential combination! Input: {input_frames_pt.shape[3:]}, Generated: {generated_frames_pt.shape[3:]}.")
print("Attempting to proceed, but this might lead to errors or unexpected video output.")
# Potentially add resizing logic here if necessary, but for now, assume they match
# 3. Concatenate Tensors along the time dimension (dim=2 for BCTHW)
combined_video_pt = torch.cat([input_frames_pt, generated_frames_pt], dim=2)
# 4. Save
save_bcthw_as_mp4(combined_video_pt, output_path, fps=target_fps, crf=crf_value)
print(f"Sequentially combined video (from tensors) saved to {output_path}")
return output_path
except Exception as e:
print(f"Error in combine_videos_sequentially_from_tensors: {str(e)}")
import traceback
traceback.print_exc()
return None