# RT-MPINet / utils/mpi/rendering_utils.py
# (initial commit, ff00a24)
import torch
def transform_G_xyz(G, xyz, is_return_homo=False):
    """
    Apply a 4x4 homogeneous transform to a set of 3D points.
    :param G: Bx4x4 (or 4x4) transform matrix
    :param xyz: Bx3xN (or 3xN) points
    :param is_return_homo: if True, return the homogeneous Bx4xN result
    :return: Bx3xN transformed points (Bx4xN when is_return_homo)
    """
    assert G.dim() == xyz.dim()
    # Promote unbatched inputs to a batch of one.
    if G.dim() == 2:
        G = G.unsqueeze(0)
        xyz = xyz.unsqueeze(0)
    # Lift points to homogeneous coordinates: append a row of ones.
    ones_row = torch.ones_like(xyz[:, :1, :])
    xyz_homo = torch.cat([xyz, ones_row], dim=1)  # Bx4xN
    transformed = torch.matmul(G, xyz_homo)       # Bx4xN
    return transformed if is_return_homo else transformed[:, :3, :]
def gather_pixel_by_pxpy(img, pxpy):
    """
    Sample per-pixel values from an image at given (px, py) coordinates.
    Coordinates are rounded to the nearest integer and clamped to the
    image bounds before gathering.
    :param img: BxCxHxW image tensor
    :param pxpy: Bx2xN pixel coordinates; channel 0 is px (x / column),
                 channel 1 is py (y / row). May be floating point or integer.
    :return: BxCxN gathered values
    """
    with torch.no_grad():
        B, C, H, W = img.size()
        # BUG FIX: the original only rounded float32 input and left
        # pxpy_int unassigned for every other dtype (NameError for
        # float64/int inputs). Round any floating dtype, pass integers
        # through unchanged.
        if pxpy.is_floating_point():
            pxpy_int = torch.round(pxpy).to(torch.int64)
        else:
            # .to(torch.int64) may return the caller's tensor unchanged
            # when it is already int64; clone so the clamps below can
            # never mutate the caller's data.
            pxpy_int = pxpy.to(torch.int64).clone()
        # Clamp out-of-bounds coordinates to the image border
        # (out-of-place to keep pxpy_int handling simple and safe).
        px = torch.clamp(pxpy_int[:, 0:1, :], min=0, max=W - 1)
        py = torch.clamp(pxpy_int[:, 1:2, :], min=0, max=H - 1)
        pxpy_idx = px + W * py  # Bx1xN flat index into HxW
    rgb = torch.gather(img.view(B, C, H * W), dim=2,
                       index=pxpy_idx.repeat(1, C, 1))  # BxCxN
    return rgb
def uniformly_sample_disparity_from_bins(batch_size, disparity_np, device):
    """
    Draw one uniform random disparity inside each bin defined by the given edges.
    Edges must be strictly descending in disparity, i.e. depth goes from near to far.
    :param batch_size: number of independent sample rows B
    :param disparity_np: numpy array of S+1 bin edges, descending
    :param device: torch device for the returned tensor
    :return: BxS float32 tensor of sampled disparities
    """
    assert disparity_np[0] > disparity_np[-1]
    num_bins = disparity_np.shape[0] - 1  # S
    edges = torch.from_numpy(disparity_np).to(dtype=torch.float32, device=device)  # S+1
    # Lower edge and (negative) width of each bin, broadcast over the batch.
    bin_starts = edges[:-1].expand(batch_size, num_bins)            # BxS
    bin_widths = (edges[1:] - edges[:-1]).expand(batch_size, num_bins)  # BxS
    jitter = torch.rand((batch_size, num_bins), dtype=torch.float32, device=device)
    return bin_starts + bin_widths * jitter  # BxS
def uniformly_sample_disparity_from_linspace_bins(batch_size, num_bins, start, end, device):
    """
    Draw one uniform random disparity inside each of num_bins equal-width bins
    spanning [end, start]. Disparity runs from large to small, i.e. depth from
    near to far.
    :param batch_size: number of independent sample rows B
    :param num_bins: number of bins S
    :param start: largest disparity (near plane)
    :param end: smallest disparity (far plane)
    :param device: torch device for the returned tensor
    :return: BxS float32 tensor of sampled disparities
    """
    assert start > end
    edges = torch.linspace(start, end, num_bins + 1, dtype=torch.float32, device=device)  # S+1
    step = edges[1] - edges[0]  # constant (negative) bin width
    bin_starts = edges[:-1].unsqueeze(0).expand(batch_size, num_bins)  # BxS
    jitter = torch.rand((batch_size, num_bins), dtype=torch.float32, device=device)
    return bin_starts + step * jitter  # BxS
def sample_pdf(values, weights, N_samples):
    """
    Draw samples via inverse-CDF sampling from the piecewise distribution
    approximated by values and weights, i.e. weights = p(values).
    :param values: Bx1xNxS sample locations (assumed monotonic along dim 3 —
                   TODO confirm against callers)
    :param weights: Bx1xNxS unnormalized probability mass at each location
    :param N_samples: number of samples to draw per (B, N) entry
    :return: Bx1xNxN_samples sampled locations
    """
    B, N, S = weights.size(0), weights.size(2), weights.size(3)
    assert values.size() == (B, 1, N, S)
    # convert values to bin edges: midpoints between neighbors, with the
    # first/last value reused as the outer edges -> S+1 edges per entry
    bin_edges = (values[:, :, :, 1:] + values[:, :, :, :-1]) * 0.5  # Bx1xNxS-1
    bin_edges = torch.cat((values[:, :, :, 0:1],
                           bin_edges,
                           values[:, :, :, -1:]), dim=3)  # Bx1xNxS+1
    # normalize weights into a pdf (epsilon guards all-zero weights),
    # then build the cdf with a leading zero so it has S+1 entries
    pdf = weights / (torch.sum(weights, dim=3, keepdim=True) + 1e-5)  # Bx1xNxS
    cdf = torch.cumsum(pdf, dim=3)  # Bx1xNxS
    cdf = torch.cat((torch.zeros((B, 1, N, 1), dtype=cdf.dtype, device=cdf.device),
                     cdf), dim=3)  # Bx1xNxS+1
    # uniform sample over the cdf values
    u = torch.rand((B, 1, N, N_samples), dtype=weights.dtype, device=weights.device)  # Bx1xNxN_samples
    # get the index on the cdf array; right=True gives the first edge > u,
    # so (idx-1, idx) bracket u; clamps keep indices inside [0, S]
    cdf_idx = torch.searchsorted(cdf, u, right=True)  # Bx1xNxN_samples
    cdf_idx_lower = torch.clamp(cdf_idx-1, min=0)  # Bx1xNxN_samples
    cdf_idx_upper = torch.clamp(cdf_idx, max=S)  # Bx1xNxN_samples
    # linear approximation for each bin: gather both bounds in one call by
    # concatenating the lower/upper index tensors along dim 3
    cdf_idx_lower_upper = torch.cat((cdf_idx_lower, cdf_idx_upper), dim=3)  # Bx1xNx(N_samplesx2)
    cdf_bounds_N2 = torch.gather(cdf, index=cdf_idx_lower_upper, dim=3)  # Bx1xNx(N_samplesx2)
    cdf_bounds = torch.stack((cdf_bounds_N2[..., 0:N_samples], cdf_bounds_N2[..., N_samples:]), dim=4)
    bin_bounds_N2 = torch.gather(bin_edges, index=cdf_idx_lower_upper, dim=3)  # Bx1xNx(N_samplesx2)
    bin_bounds = torch.stack((bin_bounds_N2[..., 0:N_samples], bin_bounds_N2[..., N_samples:]), dim=4)
    # avoid zero cdf_intervals
    cdf_intervals = cdf_bounds[:, :, :, :, 1] - cdf_bounds[:, :, :, :, 0]  # Bx1xNxN_samples
    bin_intervals = bin_bounds[:, :, :, :, 1] - bin_bounds[:, :, :, :, 0]  # Bx1xNxN_samples
    u_cdf_lower = u - cdf_bounds[:, :, :, :, 0]  # Bx1xNxN_samples
    # there is the case that cdf_interval = 0, caused by the cdf_idx_lower/upper clamp above, need special handling
    t = u_cdf_lower / torch.clamp(cdf_intervals, min=1e-5)
    # degenerate (clamped) intervals fall back to the bin midpoint (t=0.5)
    t = torch.where(cdf_intervals <= 1e-4,
                    torch.full_like(u_cdf_lower, 0.5),
                    t)
    # interpolate within the bracketing bin edges
    samples = bin_bounds[:, :, :, :, 0] + t*bin_intervals
    return samples