import json
import logging

import imageio
import numpy as np
import torch
import torchvision
from einops import rearrange
from PIL import Image
from torchvision.transforms import v2

from config import TEST_DATA_DIR
from camera_utils import Camera, parse_matrix, get_relative_pose

logger = logging.getLogger(__name__)


class VideoProcessor:
    def __init__(self, pipe):
        self.pipe = pipe
        self.default_height = 480
        self.default_width = 832

    def crop_and_resize(self, image, height, width):
        """Resize image so it fully covers (height, width); the exact center crop
        is applied later by the v2 pipeline in load_video_frames."""
        width_img, height_img = image.size
        scale = max(width / width_img, height / height_img)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height_img * scale), round(width_img * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        )
        return image

    def load_video_frames(self, video_path, num_frames=81, height=480, width=832):
        """Load num_frames frames from video_path as a (1, C, T, H, W) tensor in [-1, 1]."""
        reader = imageio.get_reader(video_path)
        frames = []

        # Frame processor for the target dimensions
        frame_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.Resize(size=(height, width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        for i in range(num_frames):
            try:
                frame = reader.get_data(i)
                frame = Image.fromarray(frame)
                frame = self.crop_and_resize(frame, height, width)
                frame = frame_process(frame)
                frames.append(frame)
            except Exception:
                # Out of frames (or a decode failure): pad by repeating the last frame
                if frames:
                    frames.append(frames[-1])
                else:
                    raise ValueError("Video is too short!")

        reader.close()

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        video_tensor = frames.unsqueeze(0)  # Add batch dimension
        return video_tensor

    def load_camera_trajectory(self, cam_type, num_frames=81):
        """Load the target camera trajectory for the selected type as a (1, N, 12) bfloat16 tensor."""
        tgt_camera_path = "./camera_trajectories/camera_extrinsics.json"
        with open(tgt_camera_path, 'r') as file:
            cam_data = json.load(file)

        # Sample every 4th frame of the trajectory for the selected camera type
        cam_idx = list(range(num_frames))[::4]
        traj = [parse_matrix(cam_data[f"frame{idx}"][f"cam{int(cam_type):02d}"]) for idx in cam_idx]
        traj = np.stack(traj).transpose(0, 2, 1)

        # Convert each camera-to-world matrix to the expected convention
        c2ws = []
        for c2w in traj:
            c2w = c2w[:, [1, 2, 0, 3]]   # reorder the axis columns
            c2w[:3, 1] *= -1.0           # flip the sign of the second axis
            c2w[:3, 3] /= 100            # scale down the translation by 1/100
            c2ws.append(c2w)

        # Express every pose relative to the first one, keeping the top 3x4 block
        tgt_cam_params = [Camera(cam_param) for cam_param in c2ws]
        relative_poses = []
        for i in range(len(tgt_cam_params)):
            relative_pose = get_relative_pose([tgt_cam_params[0], tgt_cam_params[i]])
            relative_poses.append(torch.as_tensor(relative_pose)[:, :3, :][1])

        pose_embedding = torch.stack(relative_poses, dim=0)  # (N, 3, 4); N = 21 for 81 frames
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
        camera_tensor = pose_embedding.to(torch.bfloat16).unsqueeze(0)  # Add batch dimension
        return camera_tensor

    def process_video(self, video_path, text_prompt, cam_type, num_frames=81,
                      height=480, width=832, seed=0, num_inference_steps=50,
                      cfg_scale=5.0):
        """Run the ReCamMaster pipeline on a source video with a target camera trajectory."""
        # Load video frames
        video_tensor = self.load_video_frames(video_path, num_frames, height, width)

        # Load camera trajectory
        camera_tensor = self.load_camera_trajectory(cam_type, num_frames)

        # Generate video with ReCamMaster
        video = self.pipe(
            prompt=[text_prompt],
            negative_prompt=["worst quality, low quality, blurry, jittery, distorted"],
            source_video=video_tensor,
            target_camera=camera_tensor,
            height=height,
            width=width,
            num_frames=num_frames,
            cfg_scale=cfg_scale,
            num_inference_steps=num_inference_steps,
            seed=seed,
            tiled=True,
        )
        return video
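

# Usage sketch: a minimal smoke test for the two preprocessing helpers. This is an
# illustrative example, not part of the pipeline. It assumes a sample clip named
# "case0.mp4" (hypothetical filename) exists under TEST_DATA_DIR and that
# ./camera_trajectories/camera_extrinsics.json is present, as load_camera_trajectory
# expects. Neither helper touches self.pipe, so the processor is built with pipe=None;
# a real call to process_video additionally needs a loaded ReCamMaster pipeline.
if __name__ == "__main__":
    processor = VideoProcessor(pipe=None)

    camera_tensor = processor.load_camera_trajectory(cam_type=1, num_frames=81)
    print("camera tensor:", tuple(camera_tensor.shape))  # expected (1, 21, 12)

    sample_video = f"{TEST_DATA_DIR}/case0.mp4"  # hypothetical sample file
    video_tensor = processor.load_video_frames(sample_video, num_frames=81)
    print("video tensor:", tuple(video_tensor.shape))  # expected (1, 3, 81, 480, 832)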