import torch
import numpy as np
from scipy.spatial.transform import Rotation


def inverse(matrices):
    """
    torch.inverse() sometimes produces outputs containing NaNs when the batch size is 2.
    Ref: https://github.com/pytorch/pytorch/issues/47272
    This function re-runs the inversion until it succeeds or the maximum number of tries is reached.
    :param matrices: Bx3x3
    """
    inverse = None
    max_tries = 5
    while (inverse is None) or torch.isnan(inverse).any():
        inverse = torch.inverse(matrices)
        max_tries -= 1
        if max_tries == 0:
            break

    if torch.isnan(inverse).any():
        raise Exception("Matrix inverse contains nan!")
    return inverse
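
# Minimal usage sketch (hypothetical shapes, not part of the original module):
#   mats = torch.eye(3).unsqueeze(0).repeat(2, 1, 1)  # Bx3x3, B=2 is the problematic batch size above
#   mats_inv = inverse(mats)                          # Bx3x3; raises if NaNs persist after all retries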


class HomographySample:
    def __init__(self, H_tgt, W_tgt, device=None):
        if device is None:
            self.device = torch.device("cpu")
        else:
            self.device = device

        self.Height_tgt = H_tgt
        self.Width_tgt = W_tgt
        # Homogeneous pixel grid of the target image, permuted from HxWx3 to 3xHxW.
        self.meshgrid = self.grid_generation(self.Height_tgt, self.Width_tgt, self.device)
        self.meshgrid = self.meshgrid.permute(2, 0, 1).contiguous()

        self.n = self.plane_normal_generation(self.device)

    @staticmethod
    def grid_generation(H, W, device):
        x = np.linspace(0, W - 1, W)
        y = np.linspace(0, H - 1, H)
        xv, yv = np.meshgrid(x, y)
        xv = torch.from_numpy(xv.astype(np.float32)).to(dtype=torch.float32, device=device)
        yv = torch.from_numpy(yv.astype(np.float32)).to(dtype=torch.float32, device=device)
        ones = torch.ones_like(xv)
        meshgrid = torch.stack((xv, yv, ones), dim=2)
        return meshgrid
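
    # Worked example (hypothetical call): for H=2, W=3,
    #   grid_generation(2, 3, torch.device("cpu"))[1, 2] -> tensor([2., 1., 1.])
    # i.e. entry [y, x] holds the homogeneous pixel coordinate [x, y, 1].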

    @staticmethod
    def plane_normal_generation(device):
        n = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)
        return n

    @staticmethod
    def euler_to_rotation_matrix(x_angle, y_angle, z_angle, seq='xyz', degrees=False):
        """
        Return a rotation matrix rot_mtx that transforms tgt points into the src frame,
        i.e., rot_mtx * p_tgt = p_src.
        Therefore the x/y/z angles are negated before building the rotation.
        :param x_angle:
        :param y_angle:
        :param z_angle:
        :param seq: rotation order passed to scipy's Rotation.from_euler
        :param degrees: interpret the angles as degrees if True
        :return: 3x3 rotation matrix, float32
        """
        r = Rotation.from_euler(seq,
                                [-x_angle, -y_angle, -z_angle],
                                degrees=degrees)
        rot_mtx = r.as_matrix().astype(np.float32)
        return rot_mtx
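
    # Usage sketch (hypothetical values): a +90 degree z_angle yields a -90 degree rotation,
    # mapping tgt points into the src frame as described in the docstring.
    #   R = HomographySample.euler_to_rotation_matrix(0, 0, 90, degrees=True)
    #   R @ np.array([1.0, 0.0, 0.0])  # approximately [0., -1., 0.]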

    def sample(self, src_BCHW, d_src_B,
               G_tgt_src,
               K_src_inv, K_tgt):
        """
        Coordinate system: x and y are the image axes, z points along the depth direction.
        :param src_BCHW: torch float tensor in [0, 1], rgb/rgba, BxCxHxW.
                         The source camera is assumed to be at pose P = [I | 0].
        :param d_src_B: distance from the source camera origin to the image plane, one value per batch element
        :param G_tgt_src: Bx4x4
        :param K_src_inv: Bx3x3
        :param K_tgt: Bx3x3
        :return: tgt_BCHW, valid_mask
        """
        B, channels, Height_src, Width_src = src_BCHW.size(0), src_BCHW.size(1), src_BCHW.size(2), src_BCHW.size(3)
        R_tgt_src = G_tgt_src[:, 0:3, 0:3]
        t_tgt_src = G_tgt_src[:, 0:3, 3]

        Height_tgt = self.Height_tgt
        Width_tgt = self.Width_tgt
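
        # Overview: build the plane-induced homography from src to tgt (of the form
        # K_tgt (R - t n^T / d) K_src_inv), invert it so each target pixel can be traced back
        # into the source image, then backward-warp src_BCHW with grid_sample.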

        # Move all inputs onto the same device as the source image.
        R_tgt_src = R_tgt_src.to(device=src_BCHW.device)
        t_tgt_src = t_tgt_src.to(device=src_BCHW.device)
        K_src_inv = K_src_inv.to(device=src_BCHW.device)
        K_tgt = K_tgt.to(device=src_BCHW.device)

        n = self.n.to(device=src_BCHW.device)
        n = n.unsqueeze(0).repeat(B, 1)

        # Plane-induced homography from src to tgt; note the division by -d_src_B,
        # i.e. the sign convention of the plane distance is left to the caller.
        d_src_B33 = d_src_B.reshape(B, 1, 1).repeat(1, 3, 3)
        R_tnd = R_tgt_src - torch.matmul(t_tgt_src.unsqueeze(2), n.unsqueeze(1)) / -d_src_B33
        H_tgt_src = torch.matmul(K_tgt,
                                 torch.matmul(R_tnd, K_src_inv))
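
        # Sanity check (hypothetical case): with R_tgt_src = I and t_tgt_src = 0, R_tnd is the
        # identity, so H_tgt_src reduces to K_tgt @ K_src_inv.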

        # Invert the homography so that each target pixel can be mapped back into the source image.
        with torch.no_grad():
            H_src_tgt = inverse(H_tgt_src)

        # Bx3xHxW grid of homogeneous target pixel coordinates.
        meshgrid_tgt_homo = self.meshgrid.to(src_BCHW.device)
        meshgrid_tgt_homo = meshgrid_tgt_homo.unsqueeze(0).expand(B, 3, Height_tgt, Width_tgt)

        # Warp the target pixels into the source image and de-homogenize.
        meshgrid_tgt_homo_B3N = meshgrid_tgt_homo.view(B, 3, -1)
        meshgrid_src_homo_B3N = torch.matmul(H_src_tgt, meshgrid_tgt_homo_B3N)

        meshgrid_src_homo = meshgrid_src_homo_B3N.view(B, 3, Height_tgt, Width_tgt).permute(0, 2, 3, 1)
        meshgrid_src = meshgrid_src_homo[:, :, :, 0:2] / meshgrid_src_homo[:, :, :, 2:]

        # Mark target pixels whose source coordinates fall outside the source image.
        valid_mask_x = torch.logical_and(meshgrid_src[:, :, :, 0] < Width_src,
                                         meshgrid_src[:, :, :, 0] > -1)
        valid_mask_y = torch.logical_and(meshgrid_src[:, :, :, 1] < Height_src,
                                         meshgrid_src[:, :, :, 1] > -1)
        valid_mask = torch.logical_and(valid_mask_x, valid_mask_y)

        # Normalize the sampling coordinates to [-1, 1], as expected by grid_sample
        # with align_corners=False.
        meshgrid_src[:, :, :, 0] = (meshgrid_src[:, :, :, 0] + 0.5) / (Width_src * 0.5) - 1
        meshgrid_src[:, :, :, 1] = (meshgrid_src[:, :, :, 1] + 0.5) / (Height_src * 0.5) - 1
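        # Worked example (hypothetical Width_src=4): pixel x=0 maps to (0 + 0.5) / 2 - 1 = -0.75
        # and x=3 maps to (3 + 0.5) / 2 - 1 = 0.75, the pixel-center locations grid_sample
        # expects with align_corners=False.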
        tgt_BCHW = torch.nn.functional.grid_sample(src_BCHW, grid=meshgrid_src, padding_mode='border',
                                                   align_corners=False)

        return tgt_BCHW, valid_mask
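

if __name__ == "__main__":
    # Minimal smoke test (hypothetical values, not part of the original module): warp a random
    # source image onto an identical target view. With an identity relative pose the output
    # should match the input wherever valid_mask is True.
    B, C, H, W = 1, 3, 32, 32
    sampler = HomographySample(H, W, device=torch.device("cpu"))
    src = torch.rand(B, C, H, W)
    K = torch.tensor([[32.0, 0.0, 16.0],
                      [0.0, 32.0, 16.0],
                      [0.0, 0.0, 1.0]]).unsqueeze(0)  # Bx3x3 intrinsics (assumed values)
    G_tgt_src = torch.eye(4).unsqueeze(0)             # Bx4x4 identity relative pose
    d_src = torch.ones(B)                             # plane at unit distance from the src camera
    tgt, valid = sampler.sample(src, d_src, G_tgt_src, inverse(K), K)
    print(tgt.shape, valid.all().item())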