import torch
def transform_G_xyz(G, xyz, is_return_homo=False):
    """
    Apply a rigid transform G to a batch of 3D points.
    :param G: Bx4x4 (or 4x4)
    :param xyz: Bx3xN (or 3xN)
    :param is_return_homo: if True, return the homogeneous Bx4xN points
    :return: Bx3xN transformed points (Bx4xN if is_return_homo)
    """
    assert len(G.size()) == len(xyz.size())
    if len(G.size()) == 2:
        G_B44 = G.unsqueeze(0)
        xyz_B3N = xyz.unsqueeze(0)
    else:
        G_B44 = G
        xyz_B3N = xyz
    # lift points to homogeneous coordinates and apply the 4x4 transform
    xyz_B4N = torch.cat((xyz_B3N, torch.ones_like(xyz_B3N[:, 0:1, :])), dim=1)
    G_xyz_B4N = torch.matmul(G_B44, xyz_B4N)
    if is_return_homo:
        return G_xyz_B4N
    else:
        return G_xyz_B4N[:, 0:3, :]
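

# Illustrative usage sketch for transform_G_xyz; the shapes, poses and point values
# below are hypothetical and only show the expected Bx4x4 / Bx3xN layout.
def _example_transform_G_xyz():
    B, N = 2, 5
    G = torch.eye(4).unsqueeze(0).repeat(B, 1, 1)  # Bx4x4 identity poses
    G[:, 0, 3] = 1.0                               # translate by 1 along x
    xyz = torch.rand(B, 3, N)                      # Bx3xN points
    xyz_transformed = transform_G_xyz(G, xyz)      # Bx3xN
    return xyz_transformed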
def gather_pixel_by_pxpy(img, pxpy):
    """
    Gather per-pixel values at the given (px, py) pixel coordinates.
    :param img: BxCxHxW
    :param pxpy: Bx2xN pixel coordinates (px, py); float coordinates are rounded
    :return: BxCxN gathered values
    """
    with torch.no_grad():
        B, C, H, W = img.size()
        if pxpy.dtype == torch.float32:
            pxpy_int = torch.round(pxpy).to(torch.int64)
        else:
            pxpy_int = pxpy.to(torch.int64)
        # clamp coordinates to the image bounds
        pxpy_int[:, 0, :] = torch.clamp(pxpy_int[:, 0, :], min=0, max=W - 1)
        pxpy_int[:, 1, :] = torch.clamp(pxpy_int[:, 1, :], min=0, max=H - 1)
        # flatten (px, py) into linear indices over H*W
        pxpy_idx = pxpy_int[:, 0:1, :] + W * pxpy_int[:, 1:2, :]  # Bx1xN_pt
    rgb = torch.gather(img.view(B, C, H * W), dim=2,
                       index=pxpy_idx.repeat(1, C, 1))  # BxCxN_pt
    return rgb
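

# Illustrative usage sketch for gather_pixel_by_pxpy: sample RGB values at a few
# hypothetical pixel coordinates. The image size and coordinates are assumptions
# made up for the example.
def _example_gather_pixel_by_pxpy():
    B, C, H, W = 1, 3, 4, 6
    img = torch.rand(B, C, H, W)
    pxpy = torch.tensor([[[0.2, 5.7],     # px per point
                          [0.0, 3.2]]])   # py per point, Bx2xN with N=2
    rgb = gather_pixel_by_pxpy(img, pxpy)  # BxCxN
    return rgb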
def uniformly_sample_disparity_from_bins(batch_size, disparity_np, device):
    """
    Stratified sampling: draw one uniform sample inside each disparity bin.
    The bin edges must go from large to small in disparity, i.e., depth from small (near) to large (far).
    :param batch_size: B
    :param disparity_np: numpy array of S+1 bin edges, monotonically decreasing
    :param device: torch device of the returned tensor
    :return: BxS sampled disparities
    """
    assert disparity_np[0] > disparity_np[-1]
    S = disparity_np.shape[0] - 1
    B = batch_size
    bin_edges = torch.from_numpy(disparity_np).to(dtype=torch.float32, device=device)  # S+1
    interval = bin_edges[1:] - bin_edges[0:-1]  # S
    bin_edges_start = bin_edges[0:-1].unsqueeze(0).repeat(B, 1)  # S -> BxS
    # bin_edges_end = bin_edges[1:].unsqueeze(0).repeat(B, 1)  # S -> BxS
    interval = interval.unsqueeze(0).repeat(B, 1)  # S -> BxS
    random_float = torch.rand((B, S), dtype=torch.float32, device=device)  # BxS
    disparity_array = bin_edges_start + interval * random_float
    return disparity_array  # BxS
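

# Illustrative usage sketch for uniformly_sample_disparity_from_bins, assuming bin edges
# derived from a hypothetical near/far depth range (disparity = 1 / depth, decreasing).
def _example_sample_disparity_from_bins():
    import numpy as np
    depth_edges = np.linspace(1.0, 10.0, 33)                # 32 depth bins between near and far
    disparity_np = (1.0 / depth_edges).astype(np.float32)   # decreasing disparities, S+1 edges
    return uniformly_sample_disparity_from_bins(batch_size=2, disparity_np=disparity_np,
                                                device=torch.device("cpu"))  # BxS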
def uniformly_sample_disparity_from_linspace_bins(batch_size, num_bins, start, end, device):
    """
    Stratified sampling over num_bins equally spaced disparity bins between start and end.
    The range must go from large to small in disparity, i.e., depth from small (near) to large (far).
    :param batch_size: B
    :param num_bins: S
    :param start: first (largest) disparity
    :param end: last (smallest) disparity
    :param device: torch device of the returned tensor
    :return: BxS sampled disparities
    """
    assert start > end
    B, S = batch_size, num_bins
    bin_edges = torch.linspace(start, end, num_bins + 1, dtype=torch.float32, device=device)  # S+1
    interval = bin_edges[1] - bin_edges[0]  # scalar
    bin_edges_start = bin_edges[0:-1].unsqueeze(0).repeat(B, 1)  # S -> BxS
    # bin_edges_end = bin_edges[1:].unsqueeze(0).repeat(B, 1)  # S -> BxS
    random_float = torch.rand((B, S), dtype=torch.float32, device=device)  # BxS
    disparity_array = bin_edges_start + interval * random_float
    return disparity_array  # BxS
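

# Illustrative usage sketch for uniformly_sample_disparity_from_linspace_bins, with a
# hypothetical disparity range from 1/near to 1/far (so that start > end).
def _example_sample_disparity_from_linspace_bins():
    near, far = 1.0, 10.0
    return uniformly_sample_disparity_from_linspace_bins(batch_size=2, num_bins=32,
                                                         start=1.0 / near, end=1.0 / far,
                                                         device=torch.device("cpu"))  # BxS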
def sample_pdf(values, weights, N_samples):
    """
    Draw samples from the distribution approximated by values and weights via inverse-transform sampling.
    The probability distribution can be denoted as weights = p(values).
    :param values: Bx1xNxS
    :param weights: Bx1xNxS
    :param N_samples: number of samples to draw per ray
    :return: Bx1xNxN_samples sampled values
    """
    B, N, S = weights.size(0), weights.size(2), weights.size(3)
    assert values.size() == (B, 1, N, S)
    # convert values to bin edges: midpoints between consecutive values, padded with the endpoints
    bin_edges = (values[:, :, :, 1:] + values[:, :, :, :-1]) * 0.5  # Bx1xNx(S-1)
    bin_edges = torch.cat((values[:, :, :, 0:1],
                           bin_edges,
                           values[:, :, :, -1:]), dim=3)  # Bx1xNx(S+1)
    # normalize the weights into a pdf and build the cdf, prepending zero
    pdf = weights / (torch.sum(weights, dim=3, keepdim=True) + 1e-5)  # Bx1xNxS
    cdf = torch.cumsum(pdf, dim=3)  # Bx1xNxS
    cdf = torch.cat((torch.zeros((B, 1, N, 1), dtype=cdf.dtype, device=cdf.device),
                     cdf), dim=3)  # Bx1xNx(S+1)
    # uniform samples over the cdf values
    u = torch.rand((B, 1, N, N_samples), dtype=weights.dtype, device=weights.device)  # Bx1xNxN_samples
    # get the index on the cdf array
    cdf_idx = torch.searchsorted(cdf, u, right=True)  # Bx1xNxN_samples
    cdf_idx_lower = torch.clamp(cdf_idx - 1, min=0)  # Bx1xNxN_samples
    cdf_idx_upper = torch.clamp(cdf_idx, max=S)  # Bx1xNxN_samples
    # linear approximation within each bin
    cdf_idx_lower_upper = torch.cat((cdf_idx_lower, cdf_idx_upper), dim=3)  # Bx1xNx(N_samples*2)
    cdf_bounds_N2 = torch.gather(cdf, index=cdf_idx_lower_upper, dim=3)  # Bx1xNx(N_samples*2)
    cdf_bounds = torch.stack((cdf_bounds_N2[..., 0:N_samples], cdf_bounds_N2[..., N_samples:]), dim=4)
    bin_bounds_N2 = torch.gather(bin_edges, index=cdf_idx_lower_upper, dim=3)  # Bx1xNx(N_samples*2)
    bin_bounds = torch.stack((bin_bounds_N2[..., 0:N_samples], bin_bounds_N2[..., N_samples:]), dim=4)
    # avoid zero-length cdf intervals
    cdf_intervals = cdf_bounds[:, :, :, :, 1] - cdf_bounds[:, :, :, :, 0]  # Bx1xNxN_samples
    bin_intervals = bin_bounds[:, :, :, :, 1] - bin_bounds[:, :, :, :, 0]  # Bx1xNxN_samples
    u_cdf_lower = u - cdf_bounds[:, :, :, :, 0]  # Bx1xNxN_samples
    # cdf_intervals can be 0 due to the cdf_idx_lower/upper clamping above; fall back to the bin midpoint there
    t = u_cdf_lower / torch.clamp(cdf_intervals, min=1e-5)
    t = torch.where(cdf_intervals <= 1e-4,
                    torch.full_like(u_cdf_lower, 0.5),
                    t)
    samples = bin_bounds[:, :, :, :, 0] + t * bin_intervals
    return samples
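

# Illustrative usage sketch for sample_pdf: given coarse disparity samples and some
# (here randomly invented) per-sample weights, draw additional fine samples where the
# weights are large, in the spirit of hierarchical sampling. All shapes are hypothetical.
def _example_sample_pdf():
    B, N, S, N_fine = 1, 4, 16, 8
    coarse_disparity = torch.linspace(1.0, 0.1, S).view(1, 1, 1, S).repeat(B, 1, N, 1)  # Bx1xNxS
    weights = torch.rand(B, 1, N, S)                                                     # Bx1xNxS
    fine_disparity = sample_pdf(coarse_disparity, weights, N_fine)                       # Bx1xNxN_fine
    return fine_disparity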