# ReCamMaster / video_processor.py
import json
import logging

import imageio
import numpy as np
import torch
import torchvision
from einops import rearrange
from PIL import Image
from torchvision.transforms import v2

from camera_utils import Camera, get_relative_pose, parse_matrix
from config import TEST_DATA_DIR

logger = logging.getLogger(__name__)
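
# Pipeline overview:
#   load_video_frames      -> source video tensor of shape (1, C, T, H, W), values in [-1, 1]
#   load_camera_trajectory -> camera pose embeddings of shape (1, 21, 12) for the default 81 frames
#   process_video          -> runs both through the ReCamMaster pipe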


class VideoProcessor:
    """Prepares source videos and camera trajectories, then runs the ReCamMaster pipe."""

    def __init__(self, pipe):
        self.pipe = pipe
        self.default_height = 480
        self.default_width = 832

    def crop_and_resize(self, image, height, width):
        """Scale the image so it fully covers the target size.

        The exact crop to (height, width) is applied later by CenterCrop
        in load_video_frames.
        """
        width_img, height_img = image.size
        scale = max(width / width_img, height / height_img)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height_img * scale), round(width_img * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        )
        return image

    def load_video_frames(self, video_path, num_frames=81, height=480, width=832):
        """Load num_frames frames from the video and normalize them to [-1, 1]."""
        reader = imageio.get_reader(video_path)
        frames = []
        # Center-crop to the exact target size, then normalize to [-1, 1]
        frame_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.Resize(size=(height, width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        for i in range(num_frames):
            try:
                frame = reader.get_data(i)
                frame = Image.fromarray(frame)
                frame = self.crop_and_resize(frame, height, width)
                frame = frame_process(frame)
                frames.append(frame)
            except Exception:
                # If the video runs out of frames, repeat the last one
                if frames:
                    frames.append(frames[-1])
                else:
                    raise ValueError("Video is too short!")
        reader.close()
        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        video_tensor = frames.unsqueeze(0)  # Add batch dimension -> (1, C, T, H, W)
        return video_tensor

    def load_camera_trajectory(self, cam_type, num_frames=81):
        """Load the camera trajectory for the selected preset and build pose embeddings."""
        tgt_camera_path = "./camera_trajectories/camera_extrinsics.json"
        with open(tgt_camera_path, 'r') as file:
            cam_data = json.load(file)
        # Sample every 4th frame (81 frames -> 21 camera poses)
        cam_idx = list(range(num_frames))[::4]
        traj = [parse_matrix(cam_data[f"frame{idx}"][f"cam{int(cam_type):02d}"]) for idx in cam_idx]
        traj = np.stack(traj).transpose(0, 2, 1)
        c2ws = []
        for c2w in traj:
            # Convert to the model's coordinate convention: reorder the axis
            # columns, flip the y-axis, and rescale translations by 1/100
            c2w = c2w[:, [1, 2, 0, 3]]
            c2w[:3, 1] *= -1.
            c2w[:3, 3] /= 100
            c2ws.append(c2w)
        tgt_cam_params = [Camera(cam_param) for cam_param in c2ws]
        relative_poses = []
        for i in range(len(tgt_cam_params)):
            # Pose of frame i relative to the first frame; keep the 3x4 [R|t] block
            relative_pose = get_relative_pose([tgt_cam_params[0], tgt_cam_params[i]])
            relative_poses.append(torch.as_tensor(relative_pose)[:, :3, :][1])
        pose_embedding = torch.stack(relative_poses, dim=0)  # (21, 3, 4)
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')  # flatten to (21, 12)
        camera_tensor = pose_embedding.to(torch.bfloat16).unsqueeze(0)  # Add batch dimension
        return camera_tensor

    def process_video(self, video_path, text_prompt, cam_type, num_frames=81,
                      height=480, width=832, seed=0, num_inference_steps=50,
                      cfg_scale=5.0):
        """Re-render the input video along the selected camera trajectory."""
        # Load and normalize the source video frames
        video_tensor = self.load_video_frames(video_path, num_frames, height, width)
        # Load the target camera trajectory as pose embeddings
        camera_tensor = self.load_camera_trajectory(cam_type, num_frames)
        # Generate the re-shot video with ReCamMaster
        video = self.pipe(
            prompt=[text_prompt],
            negative_prompt=["worst quality, low quality, blurry, jittery, distorted"],
            source_video=video_tensor,
            target_camera=camera_tensor,
            height=height,
            width=width,
            num_frames=num_frames,
            cfg_scale=cfg_scale,
            num_inference_steps=num_inference_steps,
            seed=seed,
            tiled=True,
        )
        return video
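

if __name__ == "__main__":
    # Hedged usage sketch (illustrative, not part of the original module).
    # `load_recammaster_pipe` is a hypothetical stand-in for however this repo
    # actually constructs the ReCamMaster pipeline; the paths, prompt, and
    # camera type below are placeholders.
    from pipeline_loader import load_recammaster_pipe  # hypothetical helper

    pipe = load_recammaster_pipe()
    processor = VideoProcessor(pipe)
    result = processor.process_video(
        video_path="example_input.mp4",
        text_prompt="a person walking on the beach",
        cam_type=1,  # camera preset index in camera_extrinsics.json
        seed=42,
    )
    # Assuming the pipe returns a sequence of PIL frames:
    imageio.mimsave("example_output.mp4", [np.array(f) for f in result], fps=16)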