|
import os |
|
import math |
|
from PIL import Image |
|
import cv2 |
|
from tqdm import tqdm |
|
import torch |
|
import torch.nn.functional as F |
|
from torchvision import transforms |
|
from torchvision.utils import save_image |
|
import numpy as np |
|
from moviepy.editor import ImageSequenceClip |
|
|
|
from utils.mpi import mpi_rendering |
|
from utils.mpi.homography_sampler import HomographySample |
|
|
|
|
|
def image_to_tensor(img_path, unsqueeze=True):
    """Load an image file and convert it to a tensor.

    Args:
        img_path: Location of the image, passed straight to ``Image.open``.
        unsqueeze: When true, prepend a dimension of size 1 to the result.

    Returns:
        The output of ``transforms.ToTensor()`` for the opened image, with an
        extra leading dimension if ``unsqueeze`` is set. No mode conversion
        is applied here, so the channel count follows the image file.
    """
    # Image.open is lazy and keeps the underlying file open; the context
    # manager releases the handle as soon as ToTensor has copied the pixels.
    with Image.open(img_path) as img:
        rgb = transforms.ToTensor()(img)
    if unsqueeze:
        rgb = rgb.unsqueeze(0)
    return rgb
|
|
|
|
|
def disparity_to_tensor(disp_path, unsqueeze=True):
    """Load a disparity map stored as an image and scale it by 1 / 65535.

    Args:
        disp_path: Path of the disparity image, read with ``cv2.imread``
            using flag ``-1`` (load the file as stored).
        unsqueeze: When true, prepend a second leading dimension of size 1.

    Returns:
        A float32 tensor holding the pixel values divided by ``2 ** 16 - 1``,
        with one leading dimension added (two if ``unsqueeze`` is set).

    Raises:
        OSError: If OpenCV could not read ``disp_path``.
    """
    raw = cv2.imread(disp_path, -1)
    if raw is None:
        # cv2.imread reports a missing or undecodable file by returning None
        # instead of raising; fail here with the offending path rather than
        # with an opaque TypeError on the division below.
        raise OSError(f"Could not read disparity image: {disp_path!r}")

    # NOTE(review): the divisor suggests 16-bit input images - confirm.
    disp = raw / (2 ** 16 - 1)
    disp = torch.from_numpy(disp)[None, ...]
    if unsqueeze:
        disp = disp.unsqueeze(0)
    return disp.float()
|
|
|
|
|
def gen_swing_path(num_frames=90, r_x=0.14, r_y=0., r_z=0.10):
    """Generate camera poses along one period of a swinging trajectory.

    Every pose is a 4x4 identity matrix whose translation column is set to
    ``(r_x * sin(2*pi*t), r_y * cos(2*pi*t), r_z * (cos(2*pi*t) - 1))`` with
    ``t`` spaced evenly over ``[0, 1]``.

    Args:
        num_frames: Number of poses to generate.
        r_x: Amplitude of the translation along x.
        r_y: Amplitude of the translation along y.
        r_z: Amplitude of the translation along z.

    Returns:
        A tuple of ``num_frames`` tensors, each of shape ``[4, 4]``.
    """
    # Build the frame index directly as a float tensor so the normalisation
    # below never depends on how integer tensors are divided.
    t = torch.arange(num_frames, dtype=torch.get_default_dtype())
    if num_frames > 1:
        t = t / (num_frames - 1)
    # With a single frame, t stays at 0 (start of the path); dividing by
    # ``num_frames - 1 == 0`` would fill the pose with NaNs.

    phase = 2. * math.pi * t
    poses = torch.eye(4).repeat(num_frames, 1, 1)
    poses[:, 0, 3] = r_x * torch.sin(phase)
    poses[:, 1, 3] = r_y * torch.cos(phase)
    poses[:, 2, 3] = r_z * (torch.cos(phase) - 1.)
    return poses.unbind()
|
|
|
|
|
def render_3dphoto(
    src_imgs,
    mpi_all_src,
    disparity_all_src,
    k_src,
    k_tgt,
    save_path,
):
    """Render the MPI along the default swing path and save it as a video.

    The source image is first blended into the MPI colour planes using the
    blend weights returned by ``mpi_rendering.render``. One novel view is then
    rendered per pose from ``gen_swing_path()``, and the first batch item of
    each view is written to ``save_path`` as a 30 fps clip.

    Args:
        src_imgs: Source image tensor blended into the MPI colour planes.
        mpi_all_src: 5-D MPI tensor; along dim 2, channels 0-2 are used as
            colour and the remaining channels as sigma.
        disparity_all_src: Plane disparities forwarded to the renderer.
        k_src: Source camera intrinsics (inverted with ``torch.inverse``).
        k_tgt: Target camera intrinsics.
        save_path: Destination path of the video file.
    """
    height, width = mpi_all_src.shape[-2:]
    sampler = HomographySample(height, width, mpi_all_src.device)
    inv_k_src = torch.inverse(k_src)

    # Split the MPI into its colour and sigma parts.
    rgb_planes = mpi_all_src[:, :, 0:3, :, :]
    sigma_planes = mpi_all_src[:, :, 3:, :, :]

    # Render once in the source frame just to obtain the blend weights.
    src_xyz = mpi_rendering.get_src_xyz_from_plane_disparity(
        sampler.meshgrid, disparity_all_src, inv_k_src,
    )
    _, _, weights, _ = mpi_rendering.render(
        rgb_planes, sigma_planes, src_xyz,
        use_alpha=False,
        is_bg_depth_inf=False,
    )
    rgb_planes = weights * src_imgs.unsqueeze(1) + (1 - weights) * rgb_planes

    # Render each pose of the trajectory and collect uint8 HxWxC frames.
    video_frames = []
    for pose in tqdm(gen_swing_path()):
        rendered = render_novel_view(
            rgb_planes, sigma_planes, disparity_all_src,
            pose, inv_k_src, k_tgt, sampler,
        )
        # Only the first batch item goes into the video.
        hwc = rendered[0].permute(1, 2, 0).contiguous().cpu().numpy()
        scaled = np.round(hwc * 255)
        video_frames.append(np.clip(scaled, a_min=0, a_max=255).astype(np.uint8))

    clip = ImageSequenceClip(video_frames, fps=30)
    clip.write_videofile(
        save_path,
        verbose=False,
        codec='mpeg4',
        logger=None,
        bitrate='2000k',
    )
|
|
|
|
|
def render_novel_view(
    mpi_all_rgb_src,
    mpi_all_sigma_src,
    disparity_all_src,
    G_tgt_src,
    K_src_inv,
    K_tgt,
    homography_sampler,
):
    """Render the MPI as seen from the camera pose ``G_tgt_src``.

    Args:
        mpi_all_rgb_src: Colour planes of the source MPI.
        mpi_all_sigma_src: Sigma planes of the source MPI.
        disparity_all_src: Plane disparities of the source MPI.
        G_tgt_src: Pose passed to the renderer as the source-to-target
            transform (naming suggests 4x4 extrinsics - confirm).
        K_src_inv: Inverse intrinsics of the source camera.
        K_tgt: Intrinsics of the target camera.
        homography_sampler: Sampler providing ``meshgrid`` and used by the
            target renderer.

    Returns:
        The first of the three values returned by
        ``mpi_rendering.render_tgt_rgb_depth`` (the synthesised target image).
    """
    # Plane points in the source camera frame ...
    src_points = mpi_rendering.get_src_xyz_from_plane_disparity(
        homography_sampler.meshgrid, disparity_all_src, K_src_inv
    )
    # ... mapped by the given pose for the target view.
    tgt_points = mpi_rendering.get_tgt_xyz_from_plane_disparity(src_points, G_tgt_src)

    synthesized, _, _ = mpi_rendering.render_tgt_rgb_depth(
        homography_sampler,
        mpi_all_rgb_src,
        mpi_all_sigma_src,
        disparity_all_src,
        tgt_points,
        G_tgt_src,
        K_src_inv,
        K_tgt,
        use_alpha=False,
        is_bg_depth_inf=False,
    )
    return synthesized
|
|
|
|
|
class AverageMeter(object):
    """Track the most recent value and the running weighted average of a metric.

    Attributes set by ``reset``/``update``: ``val`` (last value), ``sum``
    (weighted sum of values), ``count`` (total weight) and ``avg``
    (``sum / count``, or 0 while ``count`` is 0).
    """

    def __init__(self, name, fmt=":f"):
        """Create an empty meter.

        Args:
            name: Label printed by ``__str__``.
            fmt: Format spec kept on the instance; ``__str__`` does not use it
                and always prints the average with six decimals.
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average.

        Args:
            val: The new observation.
            n: Weight of the observation (e.g. a batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n=0 on an empty meter leaves count at 0; keep the
        # average at its reset value instead of raising ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Return ``"<name>: <avg>"`` with the average shown to six decimals."""
        return f"{self.name:s}: {self.avg:.6f}"
|
|