# RT-MPINet/utils/mpi/homography_sampler.py
import torch
import numpy as np
from scipy.spatial.transform import Rotation
def inverse(matrices):
    """
    torch.inverse() sometimes produces outputs with NaN when the batch size is 2.
    Ref: https://github.com/pytorch/pytorch/issues/47272
    This function retries the inversion until it succeeds or the maximum number
    of tries is reached.
    :param matrices: Bx3x3
    """
    inv = None
    max_tries = 5
    while (inv is None) or torch.isnan(inv).any():
        # torch.cuda.synchronize()
        inv = torch.inverse(matrices)
        # Stop once the inverse succeeds or there are no more tries left
        max_tries -= 1
        if max_tries == 0:
            break
    # Raise an exception if the inverse still contains NaN
    if torch.isnan(inv).any():
        raise RuntimeError("Matrix inverse contains NaN!")
    return inv
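
# Example (sketch, illustrative values only): invert a batch of 3x3 intrinsics;
# the retry loop guards against the B=2 NaN issue referenced above.
#   K = torch.eye(3).unsqueeze(0).repeat(2, 1, 1)
#   K_inv = inverse(K)  # 2x3x3, identity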
class HomographySample:
def __init__(self, H_tgt, W_tgt, device=None):
if device is None:
self.device = torch.device("cpu")
else:
self.device = device
self.Height_tgt = H_tgt
self.Width_tgt = W_tgt
self.meshgrid = self.grid_generation(self.Height_tgt, self.Width_tgt, self.device)
self.meshgrid = self.meshgrid.permute(2, 0, 1).contiguous() # 3xHxW
self.n = self.plane_normal_generation(self.device)
@staticmethod
def grid_generation(H, W, device):
x = np.linspace(0, W-1, W)
y = np.linspace(0, H-1, H)
xv, yv = np.meshgrid(x, y) # HxW
        xv = torch.from_numpy(xv.astype(np.float32)).to(device=device)
        yv = torch.from_numpy(yv.astype(np.float32)).to(device=device)
ones = torch.ones_like(xv)
meshgrid = torch.stack((xv, yv, ones), dim=2) # HxWx3
return meshgrid
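
    # Layout note: meshgrid[y, x] = (x, y, 1), i.e. homogeneous pixel coordinates;
    # e.g. for H=W=2 the entry at row 0, column 1 is (1.0, 0.0, 1.0).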
@staticmethod
def plane_normal_generation(device):
n = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
return n
@staticmethod
def euler_to_rotation_matrix(x_angle, y_angle, z_angle, seq='xyz', degrees=False):
"""
Note that here we want to return a rotation matrix rot_mtx, which transform the tgt points into src frame,
i.e, rot_mtx * p_tgt = p_src
Therefore we need to add negative to x/y/z_angle
:param roll:
:param pitch:
:param yaw:
:return:
"""
r = Rotation.from_euler(seq,
[-x_angle, -y_angle, -z_angle],
degrees=degrees)
rot_mtx = r.as_matrix().astype(np.float32)
return rot_mtx
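
    # Worked example (sketch): with x_angle = y_angle = 0, z_angle = 90, and
    # degrees=True, the negation yields rot_mtx = Rz(-90), so
    #   rot_mtx @ np.array([1.0, 0.0, 0.0])  # -> approximately (0, -1, 0)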
def sample(self, src_BCHW, d_src_B,
G_tgt_src,
K_src_inv, K_tgt):
"""
Coordinate system: x, y are the image directions, z is pointing to depth direction
:param src_BCHW: torch tensor float, 0-1, rgb/rgba. BxCxHxW
Assume to be at position P=[I|0]
:param d_src_B: distance of image plane to src camera origin
:param G_tgt_src: Bx4x4
:param K_src_inv: Bx3x3
:param K_tgt: Bx3x3
:return: tgt_BCHW
"""
# parameter processing ------ begin ------
        B, channels, Height_src, Width_src = src_BCHW.shape
R_tgt_src = G_tgt_src[:, 0:3, 0:3]
t_tgt_src = G_tgt_src[:, 0:3, 3]
Height_tgt = self.Height_tgt
Width_tgt = self.Width_tgt
        # relationship between FoV and focal length:
        # assume W > H
        # W / 2 = f * tan(theta_h / 2)
        # here we default the horizontal FoV to 53.13 degrees
        # the vertical FoV then follows from H / 2 = f * tan(theta_v / 2)
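        # Worked example: tan(53.13 / 2 deg) = 0.5, so f = (W / 2) / 0.5 = W;
        # e.g. W = 512 gives f = 512 pixels.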
R_tgt_src = R_tgt_src.to(device=src_BCHW.device)
t_tgt_src = t_tgt_src.to(device=src_BCHW.device)
K_src_inv = K_src_inv.to(device=src_BCHW.device)
K_tgt = K_tgt.to(device=src_BCHW.device)
# parameter processing ------ end ------
        # The goal is to compute H_src_tgt, which maps a tgt pixel to a src pixel,
        # so we compute H_tgt_src first and then invert it.
n = self.n.to(device=src_BCHW.device)
n = n.unsqueeze(0).repeat(B, 1) # Bx3
        # Bx3x3 - (Bx3x1 * Bx1x3): plane-induced homography core R + t * n^T / d
        # note here we use -d_src, because the plane equation is n^T * X - d_src = 0
        d_src_B33 = d_src_B.reshape(B, 1, 1).repeat(1, 3, 3)  # B -> Bx3x3
        R_tnd = R_tgt_src - torch.matmul(t_tgt_src.unsqueeze(2), n.unsqueeze(1)) / -d_src_B33
H_tgt_src = torch.matmul(K_tgt,
torch.matmul(R_tnd, K_src_inv))
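        # Direction check (sketch): H_tgt_src maps homogeneous src pixels to tgt
        # pixels, p_tgt ~ H_tgt_src @ p_src; it is inverted below so that each tgt
        # pixel can look up its sampling location in src (backward warping).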
# TODO: fix cuda inverse
with torch.no_grad():
H_src_tgt = inverse(H_tgt_src)
# create tgt image grid, and map to src
meshgrid_tgt_homo = self.meshgrid.to(src_BCHW.device)
# 3xHxW -> Bx3xHxW
meshgrid_tgt_homo = meshgrid_tgt_homo.unsqueeze(0).expand(B, 3, Height_tgt, Width_tgt)
        # warp meshgrid_tgt_homo to meshgrid_src
meshgrid_tgt_homo_B3N = meshgrid_tgt_homo.view(B, 3, -1) # Bx3xHW
meshgrid_src_homo_B3N = torch.matmul(H_src_tgt, meshgrid_tgt_homo_B3N) # Bx3x3 * Bx3xHW -> Bx3xHW
# Bx3xHW -> Bx3xHxW -> BxHxWx3
meshgrid_src_homo = meshgrid_src_homo_B3N.view(B, 3, Height_tgt, Width_tgt).permute(0, 2, 3, 1)
meshgrid_src = meshgrid_src_homo[:, :, :, 0:2] / meshgrid_src_homo[:, :, :, 2:] # BxHxWx2
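        # Keep only tgt pixels whose src lookup lands inside the src image; the open
        # bounds (-1, W) and (-1, H) also admit locations whose bilinear footprint
        # still overlaps the border pixels.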
valid_mask_x = torch.logical_and(meshgrid_src[:, :, :, 0] < Width_src,
meshgrid_src[:, :, :, 0] > -1)
valid_mask_y = torch.logical_and(meshgrid_src[:, :, :, 1] < Height_src,
meshgrid_src[:, :, :, 1] > -1)
valid_mask = torch.logical_and(valid_mask_x, valid_mask_y) # BxHxW
# sample from src_BCHW
# normalize meshgrid_src to [-1,1]
meshgrid_src[:, :, :, 0] = (meshgrid_src[:, :, :, 0]+0.5) / (Width_src * 0.5) - 1
meshgrid_src[:, :, :, 1] = (meshgrid_src[:, :, :, 1]+0.5) / (Height_src * 0.5) - 1
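        # With align_corners=False, grid_sample maps -1/1 to the outer image edges,
        # so (x + 0.5) / (W / 2) - 1 places integer pixel x at its center:
        # x = 0 -> 1/W - 1 and x = W - 1 -> 1 - 1/W.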
tgt_BCHW = torch.nn.functional.grid_sample(src_BCHW, grid=meshgrid_src, padding_mode='border',
align_corners=False)
# BxCxHxW, BxHxW
return tgt_BCHW, valid_mask
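
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: warp a random src
    # image onto a tgt view. The image size, focal length, pose, and plane depth
    # below are made-up values for illustration only.
    B, C, H, W = 2, 3, 64, 80
    src = torch.rand(B, C, H, W)
    f = float(W)  # assumed focal length in pixels (~53.13 deg horizontal FoV)
    K = torch.tensor([[f, 0.0, W / 2.0],
                      [0.0, f, H / 2.0],
                      [0.0, 0.0, 1.0]]).unsqueeze(0).repeat(B, 1, 1)
    K_inv = inverse(K)
    G_tgt_src = torch.eye(4).unsqueeze(0).repeat(B, 1, 1)
    G_tgt_src[:, 0, 3] = 0.1  # small sideways translation between the two views
    d_src = torch.full((B,), 2.0)  # fronto-parallel plane at depth 2 in src frame
    sampler = HomographySample(H, W)
    tgt_BCHW, valid_mask = sampler.sample(src, d_src, G_tgt_src, K_inv, K)
    print(tgt_BCHW.shape, valid_mask.shape)  # (2, 3, 64, 80), (2, 64, 80)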