import torch
import numpy as np
import json
import imageio
from PIL import Image
from torchvision.transforms import v2
from einops import rearrange
import torchvision
import logging

from config import TEST_DATA_DIR
from camera_utils import Camera, parse_matrix, get_relative_pose

logger = logging.getLogger(__name__)

class VideoProcessor:
    def __init__(self, pipe):
        self.pipe = pipe
        self.default_height = 480
        self.default_width = 832

    def crop_and_resize(self, image, height, width):
        """Resize the image so it fully covers the target dimensions.

        The actual center crop to (height, width) is applied later in the
        per-frame transform.
        """
        width_img, height_img = image.size
        scale = max(width / width_img, height / height_img)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height_img * scale), round(width_img * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        )
        return image

    def load_video_frames(self, video_path, num_frames=81, height=480, width=832):
        """Load a video and convert it to a normalized tensor of shape (1, C, T, H, W)."""
        reader = imageio.get_reader(video_path)
        frames = []

        # Per-frame transform: center-crop to the target size and normalize to [-1, 1]
        frame_process = v2.Compose([
            v2.CenterCrop(size=(height, width)),
            v2.Resize(size=(height, width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        for i in range(num_frames):
            try:
                frame = reader.get_data(i)
                frame = Image.fromarray(frame)
                frame = self.crop_and_resize(frame, height, width)
                frame = frame_process(frame)
                frames.append(frame)
            except Exception:
                # If the video runs out of frames, repeat the last one
                if frames:
                    frames.append(frames[-1])
                else:
                    raise ValueError("Video is too short!")

        reader.close()
        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        video_tensor = frames.unsqueeze(0)  # Add batch dimension
        return video_tensor

    def load_camera_trajectory(self, cam_type, num_frames=81):
        """Load the target camera trajectory for the selected trajectory type."""
        tgt_camera_path = "./camera_trajectories/camera_extrinsics.json"
        with open(tgt_camera_path, 'r') as file:
            cam_data = json.load(file)

        # Sample every 4th frame of the trajectory (81 frames -> 21 poses)
        cam_idx = list(range(num_frames))[::4]
        traj = [parse_matrix(cam_data[f"frame{idx}"][f"cam{int(cam_type):02d}"]) for idx in cam_idx]
        traj = np.stack(traj).transpose(0, 2, 1)

        # Convert each camera-to-world matrix into the model's coordinate convention
        c2ws = []
        for c2w in traj:
            c2w = c2w[:, [1, 2, 0, 3]]  # reorder axis columns
            c2w[:3, 1] *= -1.           # flip the y axis
            c2w[:3, 3] /= 100           # rescale translation
            c2ws.append(c2w)

        # Express every pose relative to the first frame of the trajectory
        tgt_cam_params = [Camera(cam_param) for cam_param in c2ws]
        relative_poses = []
        for i in range(len(tgt_cam_params)):
            relative_pose = get_relative_pose([tgt_cam_params[0], tgt_cam_params[i]])
            relative_poses.append(torch.as_tensor(relative_pose)[:, :3, :][1])

        pose_embedding = torch.stack(relative_poses, dim=0)  # 21 x 3 x 4
        pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
        camera_tensor = pose_embedding.to(torch.bfloat16).unsqueeze(0)  # Add batch dimension
        return camera_tensor

    def process_video(self, video_path, text_prompt, cam_type, num_frames=81, height=480, width=832,
                      seed=0, num_inference_steps=50, cfg_scale=5.0):
        """Process the source video through the ReCamMaster model."""
        # Load and preprocess the source video frames
        video_tensor = self.load_video_frames(video_path, num_frames, height, width)

        # Load the target camera trajectory as a pose embedding
        camera_tensor = self.load_camera_trajectory(cam_type, num_frames)

        # Generate the re-rendered video with ReCamMaster
        video = self.pipe(
            prompt=[text_prompt],
            negative_prompt=["worst quality, low quality, blurry, jittery, distorted"],
            source_video=video_tensor,
            target_camera=camera_tensor,
            height=height,
            width=width,
            num_frames=num_frames,
            cfg_scale=cfg_scale,
            num_inference_steps=num_inference_steps,
            seed=seed,
            tiled=True
        )
        return video
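

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions): `load_pipeline()` is a hypothetical
# helper standing in for however this Space constructs the ReCamMaster
# pipeline, and the pipeline output is assumed to be a sequence of PIL frames.
# Adjust both to match the actual app wiring before running.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from pipeline_setup import load_pipeline  # hypothetical helper, not part of this repo

    pipe = load_pipeline()
    processor = VideoProcessor(pipe)

    frames = processor.process_video(
        video_path="example.mp4",
        text_prompt="a person walking through a forest",
        cam_type="1",  # one of the trajectory ids in camera_extrinsics.json
        num_frames=81,
        seed=42,
    )

    # Assuming the pipeline returns PIL frames, write them out as an mp4
    imageio.mimsave("recam_output.mp4", [np.array(f) for f in frames], fps=30)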