import numpy as np
import torch
import torch.nn.functional as F
def se3_inverse(T):
"""
Computes the inverse of a batch of SE(3) matrices.
T: Tensor of shape (B, 4, 4)
"""
    if T.ndim == 2:
        T = T[None]
        squeeze_output = True
    else:
        squeeze_output = False
if torch.is_tensor(T):
R = T[:, :3, :3]
t = T[:, :3, 3].unsqueeze(-1)
R_inv = R.transpose(-2, -1)
t_inv = -torch.matmul(R_inv, t)
T_inv = torch.cat([
torch.cat([R_inv, t_inv], dim=-1),
torch.tensor([0, 0, 0, 1], device=T.device, dtype=T.dtype).repeat(T.shape[0], 1, 1)
], dim=1)
else:
R = T[:, :3, :3]
t = T[:, :3, 3, np.newaxis]
R_inv = np.swapaxes(R, -2, -1)
t_inv = -R_inv @ t
bottom_row = np.zeros((T.shape[0], 1, 4), dtype=T.dtype)
bottom_row[:, :, 3] = 1
top_part = np.concatenate([R_inv, t_inv], axis=-1)
T_inv = np.concatenate([top_part, bottom_row], axis=1)
    if squeeze_output:
        T_inv = T_inv[0]
return T_inv
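# Usage sketch for se3_inverse (illustrative): inverses compose back to identity.
#   T = torch.eye(4).repeat(2, 1, 1)                                  # (2, 4, 4)
#   torch.allclose(T @ se3_inverse(T), torch.eye(4).expand(2, 4, 4))  # True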
def get_pixel(H, W):
    # Build homogeneous pixel-center coordinates (u, v, 1) for an HxW image
    u_a, v_a = np.meshgrid(np.arange(W), np.arange(H))
pixels_a = np.stack([
u_a.flatten() + 0.5,
v_a.flatten() + 0.5,
np.ones_like(u_a.flatten())
], axis=0)
return pixels_a
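# Shape sketch for get_pixel (illustrative): get_pixel(2, 3) -> (3, 6) array
# whose rows are u + 0.5, v + 0.5 and 1 (homogeneous pixel-center coordinates).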
def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, z_far=0, **kw):
"""
Args:
- depthmap (HxW array):
- camera_intrinsics: a 3x3 matrix
- camera_pose: a 4x3 or 4x4 cam2world matrix
Returns:
pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels."""
X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
if z_far > 0:
valid_mask = valid_mask & (depthmap < z_far)
X_world = X_cam # default
    if camera_pose is not None:
        R_cam2world = camera_pose[:3, :3]
        t_cam2world = camera_pose[:3, 3]
        # Express the points in world coordinates (pixels with invalid depth
        # still receive values; use valid_mask to filter them out)
        X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
return X_world, valid_mask
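# Usage sketch (hypothetical inputs): with depth (H, W), K (3, 3) and a 4x4
# cam2world pose, the transform applied per pixel is X_world = R @ X_cam + t:
#   pts, ok = depthmap_to_absolute_camera_coordinates(depth, K, cam2world, z_far=80.0)
#   pts[ok]  # (M, 3) valid world-space points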
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
"""
Args:
- depthmap (HxW array):
- camera_intrinsics: a 3x3 matrix
Returns:
        pointmap in camera coordinates (HxWx3 array), and a mask specifying valid pixels.
"""
camera_intrinsics = np.float32(camera_intrinsics)
H, W = depthmap.shape
# Compute 3D ray associated with each pixel
# Strong assumption: there are no skew terms
# assert camera_intrinsics[0, 1] == 0.0
# assert camera_intrinsics[1, 0] == 0.0
if pseudo_focal is None:
fu = camera_intrinsics[0, 0]
fv = camera_intrinsics[1, 1]
else:
assert pseudo_focal.shape == (H, W)
fu = fv = pseudo_focal
cu = camera_intrinsics[0, 2]
cv = camera_intrinsics[1, 2]
u, v = np.meshgrid(np.arange(W), np.arange(H))
z_cam = depthmap
x_cam = (u - cu) * z_cam / fu
y_cam = (v - cv) * z_cam / fv
X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
    # Mask of valid pixels: strictly positive depth
    valid_mask = depthmap > 0.0
return X_cam, valid_mask
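# Note: the unprojection above is the standard pinhole model, i.e. for pixel
# (u, v) with depth z: x = (u - cu) * z / fu and y = (v - cv) * z / fv.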
def homogenize_points(points):
    """Convert batched points (xyz) to (xyz1)."""
return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
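# Shape sketch for homogenize_points (illustrative):
#   homogenize_points(torch.zeros(2, 5, 3)).shape  # -> (2, 5, 4)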
def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode='bilinear', relative_depth_error_threshold=0.05, H=None, W=None):
    """Compute the ground-truth dense warp from image 1 to image 2, plus a certainty mask."""
    if H is None:
        B, H, W = depth1.shape
    else:
        B = depth1.shape[0]
    with torch.no_grad():
        # Normalized pixel-center grid in [-1 + 1/n, 1 - 1/n] along each axis;
        # the grid over B only supplies the batch dimension (x1_n[0] is unused)
        x1_n = torch.meshgrid(
            *[
                torch.linspace(
                    -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
                )
                for n in (B, H, W)
            ],
            indexing='ij'
        )
        x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
        mask, x2 = warp_kpts(
            x1_n.double(),
            depth1.double(),
            depth2.double(),
            T_1to2.double(),
            K1.double(),
            K2.double(),
            depth_interpolation_mode=depth_interpolation_mode,
            relative_depth_error_threshold=relative_depth_error_threshold,
        )
prob = mask.float().reshape(B, H, W)
x2 = x2.reshape(B, H, W, 2)
return x2, prob
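# Usage sketch for get_gt_warp (assumed shapes): given depths (B, H, W), a
# relative pose T_1to2 (B, 4, 4) and intrinsics (B, 3, 3),
#   warp, certainty = get_gt_warp(d1, d2, T_1to2, K1, K2)
# yields warp (B, H, W, 2) in the normalized (-1, 1) coordinates of image 2
# and certainty (B, H, W) marking covisible, depth-consistent pixels.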
@torch.no_grad()
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask=False, return_relative_depth_error=False, depth_interpolation_mode="bilinear", relative_depth_error_threshold=0.05):
"""Warp kpts0 from I0 to I1 with depth, K and Rt
Also check covisibility and depth consistency.
Depth is consistent if relative error < 0.2 (hard-coded).
# https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here
Args:
kpts0 (torch.Tensor): [N, L, 2] - <x, y>, should be normalized in (-1,1)
depth0 (torch.Tensor): [N, H, W],
depth1 (torch.Tensor): [N, H, W],
T_0to1 (torch.Tensor): [N, 3, 4],
K0 (torch.Tensor): [N, 3, 3],
K1 (torch.Tensor): [N, 3, 3],
Returns:
calculable_mask (torch.Tensor): [N, L]
warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
"""
    n, h, w = depth0.shape
if depth_interpolation_mode == "combined":
# Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
if smooth_mask:
raise NotImplementedError("Combined bilinear and NN warp not implemented")
        valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
                                                  smooth_mask=smooth_mask,
                                                  return_relative_depth_error=return_relative_depth_error,
                                                  depth_interpolation_mode="bilinear",
                                                  relative_depth_error_threshold=relative_depth_error_threshold)
        valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
                                                smooth_mask=smooth_mask,
                                                return_relative_depth_error=return_relative_depth_error,
                                                depth_interpolation_mode="nearest-exact",
                                                relative_depth_error_threshold=relative_depth_error_threshold)
nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
warp = warp_bilinear.clone()
warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
valid = valid_bilinear | valid_nearest
return valid, warp
    kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False)[
        :, 0, :, 0
    ]
kpts0 = torch.stack(
(w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
    # Sample depth; keep only keypoints with strictly positive depth
    nonzero_mask = kpts0_depth > 0
# Unproject
kpts0_h = (
torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
* kpts0_depth[..., None]
) # (N, L, 3)
kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
kpts0_cam = kpts0_n
# Rigid Transform
w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
# Project
w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
w_kpts0 = w_kpts0_h[:, :, :2] / (
w_kpts0_h[:, :, [2]] + 1e-4
) # (N, L, 2), +1e-4 to avoid zero depth
# Covisible Check
h, w = depth1.shape[1:3]
covisible_mask = (
(w_kpts0[:, :, 0] > 0)
* (w_kpts0[:, :, 0] < w - 1)
* (w_kpts0[:, :, 1] > 0)
* (w_kpts0[:, :, 1] < h - 1)
)
w_kpts0 = torch.stack(
(2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
w_kpts0_depth = F.grid_sample(
depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
)[:, 0, :, 0]
relative_depth_error = (
(w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
).abs()
    if not smooth_mask:
        consistent_mask = relative_depth_error < relative_depth_error_threshold
    else:
        # Here smooth_mask acts as a temperature, yielding a soft mask in (0, 1]
        consistent_mask = (-relative_depth_error / smooth_mask).exp()
valid_mask = nonzero_mask * covisible_mask * consistent_mask
if return_relative_depth_error:
return relative_depth_error, w_kpts0
else:
return valid_mask, w_kpts0
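# Usage sketch for warp_kpts (illustrative): with kpts0 (N, L, 2) already in
# normalized (-1, 1) coordinates,
#   valid, w_kpts = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1)
#   w_kpts[valid]  # warped positions that pass the positive-depth,
#                  # covisibility and depth-consistency checks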
def geotrf(Trf, pts, ncol=None, norm=False):
""" Apply a geometric transformation to a list of 3-D points.
H: 3x3 or 4x4 projection matrix (typically a Homography)
p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
ncol: int. number of columns of the result (2 or 3)
norm: float. if != 0, the resut is projected on the z=norm plane.
Returns an array of projected 2d points.
"""
assert Trf.ndim >= 2
if isinstance(Trf, np.ndarray):
pts = np.asarray(pts)
elif isinstance(Trf, torch.Tensor):
pts = torch.as_tensor(pts, dtype=Trf.dtype)
# adapt shape if necessary
output_reshape = pts.shape[:-1]
ncol = ncol or pts.shape[-1]
# optimized code
if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
Trf.ndim == 3 and pts.ndim == 4):
d = pts.shape[3]
if Trf.shape[-1] == d:
pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
elif Trf.shape[-1] == d + 1:
pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
else:
raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
else:
if Trf.ndim >= 3:
n = Trf.ndim - 2
assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
if pts.ndim > Trf.ndim:
# Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
elif pts.ndim == 2:
# Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
pts = pts[:, None, :]
if pts.shape[-1] + 1 == Trf.shape[-1]:
Trf = Trf.swapaxes(-1, -2) # transpose Trf
pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
elif pts.shape[-1] == Trf.shape[-1]:
Trf = Trf.swapaxes(-1, -2) # transpose Trf
pts = pts @ Trf
else:
pts = Trf @ pts.T
if pts.ndim >= 2:
pts = pts.swapaxes(-1, -2)
if norm:
pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
if norm != 1:
pts *= norm
res = pts[..., :ncol].reshape(*output_reshape, ncol)
return res
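# Usage sketch for geotrf (illustrative): applying batched world-to-camera
# transforms to a point map hits the fast einsum path above:
#   pts_cam = geotrf(se3_inverse(cam2world), pts_world)  # (B, H, W, 3)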
def inv(mat):
""" Invert a torch or numpy matrix
"""
if isinstance(mat, torch.Tensor):
return torch.linalg.inv(mat)
if isinstance(mat, np.ndarray):
return np.linalg.inv(mat)
raise ValueError(f'bad matrix type = {type(mat)}')
def opencv_camera_to_plucker(poses, K, H, W):
    """Convert OpenCV-convention cam2world poses (B, 4, 4) and intrinsics
    K (B, 3, 3) into per-pixel Plucker rays of shape (B, H, W, 6)."""
device = poses.device
B = poses.shape[0]
    pixel = torch.from_numpy(get_pixel(H, W).astype(np.float32)).to(device).T.reshape(H, W, 3)[None].repeat(B, 1, 1, 1)  # (B, H, W, 3)
pixel = torch.einsum('bij, bhwj -> bhwi', torch.inverse(K), pixel)
ray_directions = torch.einsum('bij, bhwj -> bhwi', poses[..., :3, :3], pixel)
ray_origins = poses[..., :3, 3][:, None, None].repeat(1, H, W, 1)
ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True)
plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1)
plucker_ray = torch.cat([ray_directions, plucker_normal], dim=-1)
return plucker_ray
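# For reference: a Plucker ray is the 6-vector (d, m) with unit direction d
# and moment m = o x d; m is independent of where the origin o sits on the ray.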
def depth_edge(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor:
    """
    Compute the edge mask of a depth map. An edge is a pixel whose neighborhood
    contains a large difference in depth.
    Args:
        depth (torch.Tensor): shape (..., height, width), linear depth map
        atol (float): absolute tolerance on the local depth range
        rtol (float): relative tolerance on the local depth range
        kernel_size (int): size of the neighborhood window
        mask (torch.Tensor): optional boolean mask of valid pixels
    Returns:
        edge (torch.Tensor): shape (..., height, width) of dtype torch.bool
    """
shape = depth.shape
depth = depth.reshape(-1, 1, *shape[-2:])
    if mask is None:
        # Local depth range: max-pool(depth) + max-pool(-depth) over the window
        diff = (F.max_pool2d(depth, kernel_size, stride=1, padding=kernel_size // 2)
                + F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2))
    else:
        mask = mask.reshape(-1, 1, *shape[-2:])
        # Masked-out pixels become -inf so they never win either max-pooling
        diff = (F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2)
                + F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2))
edge = torch.zeros_like(depth, dtype=torch.bool)
if atol is not None:
edge |= diff > atol
if rtol is not None:
edge |= (diff / depth).nan_to_num_() > rtol
edge = edge.reshape(*shape)
return edge
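# Usage sketch for depth_edge (illustrative thresholds): flag discontinuities
# larger than 10 cm or 3% of local depth, ignoring invalid pixels:
#   edges = depth_edge(depth, atol=0.1, rtol=0.03, mask=depth > 0)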