|
import torch |
|
import numpy as np |
|
import traceback |
|
|
|
from diffusers_helper.utils import save_bcthw_as_mp4 |
|
|
|
@torch.no_grad()
def combine_videos_sequentially_from_tensors(processed_input_frames_np,
                                             generated_frames_pt,
                                             output_path,
                                             target_fps,
                                             crf_value):
    """
    Concatenate input frames and generated frames in time and save them as an MP4.

    The input frames (NumPy) are converted to the same layout, device and dtype
    as the generated frames (PyTorch), placed *before* the generated frames along
    the time axis, and the result is written with ``save_bcthw_as_mp4``.

    Args:
        processed_input_frames_np: Array-like of processed input frames laid out
            as (T_in, H, W, C). Values are mapped with ``x / 127.5 - 1.0``, i.e.
            uint8 [0, 255] becomes [-1, 1]. Non-contiguous or negative-stride
            views (e.g. a channel-flipped array) are accepted.
        generated_frames_pt: PyTorch tensor of generated frames laid out as
            (B_gen, C, T_gen, H, W), expected in [-1, 1].
            (Per the original author: this is history_pixels from worker.py.)
        output_path: Path to save the combined video.
        target_fps: FPS for the output combined video.
        crf_value: CRF value for video encoding.

    Returns:
        ``output_path`` on success, or None if any error occurs (the error and
        its traceback are printed rather than raised).
    """
    try:
        if generated_frames_pt.ndim != 5:
            raise ValueError(
                "generated_frames_pt must be 5-D (B, C, T, H, W), "
                f"got shape {tuple(generated_frames_pt.shape)}"
            )

        # torch.from_numpy rejects arrays with negative strides (for example a
        # BGR->RGB ``[..., ::-1]`` view); ascontiguousarray only copies when needed.
        frames_np = np.ascontiguousarray(processed_input_frames_np)
        if frames_np.ndim != 4:
            raise ValueError(
                "processed_input_frames_np must be 4-D (T, H, W, C), "
                f"got shape {frames_np.shape}"
            )

        # Move the compact source data to the target device first, then normalise
        # in float32 and finally cast to the generated tensor's dtype.
        input_frames_pt = torch.from_numpy(frames_np).to(generated_frames_pt.device)
        input_frames_pt = input_frames_pt.float() / 127.5 - 1.0

        # (T, H, W, C) -> (C, T, H, W) -> (1, C, T, H, W)
        input_frames_pt = input_frames_pt.permute(3, 0, 1, 2).unsqueeze(0)
        input_frames_pt = input_frames_pt.to(dtype=generated_frames_pt.dtype)

        # The input clip has batch size 1; repeat it (as a view) for every
        # generated batch element so the concatenation below is well-defined.
        batch_size = generated_frames_pt.shape[0]
        if input_frames_pt.shape[0] != batch_size:
            input_frames_pt = input_frames_pt.expand(batch_size, -1, -1, -1, -1)

        # Spatial size (H, W) must agree for concatenation along time.
        if input_frames_pt.shape[3:] != generated_frames_pt.shape[3:]:
            print(f"Warning: Dimension mismatch for sequential combination! Input: {input_frames_pt.shape[3:]}, Generated: {generated_frames_pt.shape[3:]}.")
            print("Attempting to proceed, but this might lead to errors or unexpected video output.")

        # dim=2 is the time axis in BCTHW layout: input frames first, then generated.
        combined_video_pt = torch.cat([input_frames_pt, generated_frames_pt], dim=2)

        save_bcthw_as_mp4(combined_video_pt, output_path, fps=target_fps, crf=crf_value)
        print(f"Sequentially combined video (from tensors) saved to {output_path}")
        return output_path
    except Exception as e:
        # Deliberate best-effort boundary: report the failure and signal it with None.
        print(f"Error in combine_videos_sequentially_from_tensors: {e}")
        traceback.print_exc()
        return None