Spaces:
Paused
Paused
import torch | |
import numpy as np | |
import traceback | |
from diffusers_helper.utils import save_bcthw_as_mp4 | |
def combine_videos_sequentially_from_tensors(processed_input_frames_np,
                                             generated_frames_pt,
                                             output_path,
                                             target_fps,
                                             crf_value):
    """
    Combine processed input frames with generated frames sequentially and save as MP4.

    The input frames (NumPy) are converted to a BCTHW float tensor in [-1, 1],
    placed before the generated frames along the time axis, and the result is
    written with save_bcthw_as_mp4.

    Args:
        processed_input_frames_np: NumPy array of processed input frames (T_in, H, W_in, C), uint8.
        generated_frames_pt: PyTorch tensor of generated frames (B_gen, C_gen, T_gen, H, W_gen), float32 [-1,1].
            (This will be history_pixels from worker.py)
        output_path: Path to save the combined video.
        target_fps: FPS for the output combined video.
        crf_value: CRF value for video encoding.

    Returns:
        Path to the combined video, or None if an error occurs. Errors are
        printed (with traceback) rather than raised.
    """
    try:
        # 1. Convert processed_input_frames_np to a PyTorch tensor BCTHW, float32, [-1,1].
        # ascontiguousarray copies only when needed; it makes views with negative
        # strides (e.g. frames flipped via [..., ::-1]) acceptable to torch.from_numpy.
        frames_np = np.ascontiguousarray(processed_input_frames_np)
        if frames_np.ndim != 4:
            raise ValueError(
                f"processed_input_frames_np must be 4-D (T, H, W, C), got shape {frames_np.shape}."
            )
        if generated_frames_pt.ndim != 5:
            raise ValueError(
                f"generated_frames_pt must be 5-D (B, C, T, H, W), got shape {tuple(generated_frames_pt.shape)}."
            )

        input_frames_pt = torch.from_numpy(frames_np).float() / 127.5 - 1.0  # (T,H,W,C)
        input_frames_pt = input_frames_pt.permute(3, 0, 1, 2)  # (C,T,H,W)
        input_frames_pt = input_frames_pt.unsqueeze(0)  # (1,C,T,H,W) -> BCTHW
        # Move to the generated frames' device first; the dtype cast happens after the
        # optional resize so the interpolation always runs in float32.
        input_frames_pt = input_frames_pt.to(device=generated_frames_pt.device)

        # 2. Dimension check (heights and widths should match).
        # They should match, since the input frames should have been processed to match
        # the generation resolution. torch.cat cannot concatenate tensors whose
        # non-time dimensions differ, so on a mismatch the input frames are resized
        # to the generated resolution instead of proceeding into a guaranteed failure.
        target_hw = tuple(generated_frames_pt.shape[3:])
        input_hw = tuple(input_frames_pt.shape[3:])
        if input_hw != target_hw:  # Compare (H,W)
            print(f"Warning: Dimension mismatch for sequential combination! Input: {input_hw}, Generated: {target_hw}.")
            print("Resizing input frames to the generated resolution before combining.")
            # interpolate expects (N,C,H,W): treat each time step as a batch item.
            frames_tchw = input_frames_pt[0].permute(1, 0, 2, 3)  # (T,C,H,W)
            frames_tchw = torch.nn.functional.interpolate(
                frames_tchw, size=target_hw, mode="bilinear", align_corners=False
            )
            input_frames_pt = frames_tchw.permute(1, 0, 2, 3).unsqueeze(0)  # back to (1,C,T,H,W)

        # Match the generated frames' dtype for concatenation.
        input_frames_pt = input_frames_pt.to(dtype=generated_frames_pt.dtype)

        # 3. Concatenate tensors along the time dimension (dim=2 for BCTHW).
        combined_video_pt = torch.cat([input_frames_pt, generated_frames_pt], dim=2)

        # 4. Save
        save_bcthw_as_mp4(combined_video_pt, output_path, fps=target_fps, crf=crf_value)
        print(f"Sequentially combined video (from tensors) saved to {output_path}")
        return output_path
    except Exception as e:
        # Best-effort contract: report the failure and signal it with None.
        print(f"Error in combine_videos_sequentially_from_tensors: {e}")
        traceback.print_exc()
        return None