import torch


def transform_G_xyz(G, xyz, is_return_homo=False):
    """Apply a 4x4 transform to a set of 3D points.

    The points are lifted to homogeneous coordinates (w = 1) and multiplied by
    ``G``. No perspective division is performed.

    :param G: Bx4x4 transform, or a single 4x4 matrix.
    :param xyz: Bx3xN points, or a single 3xN point set. Must have the same
        number of dimensions as ``G``.
    :param is_return_homo: if True, return all four homogeneous rows;
        otherwise return only the first three.
    :return: Bx4xN if ``is_return_homo`` else Bx3xN. Unbatched (2-D) inputs
        are treated as a batch of one, and the result keeps that leading
        batch dimension (1x4xN / 1x3xN).
    """
    assert len(G.size()) == len(xyz.size())
    if len(G.size()) == 2:
        # Promote unbatched inputs to a batch of one.
        G_B44 = G.unsqueeze(0)
        xyz_B3N = xyz.unsqueeze(0)
    else:
        G_B44 = G
        xyz_B3N = xyz
    ones_B1N = torch.ones_like(xyz_B3N[:, 0:1, :])
    xyz_B4N = torch.cat((xyz_B3N, ones_B1N), dim=1)
    G_xyz_B4N = torch.matmul(G_B44, xyz_B4N)
    if is_return_homo:
        return G_xyz_B4N
    return G_xyz_B4N[:, 0:3, :]


def gather_pixel_by_pxpy(img, pxpy):
    """Look up image values at pixel locations (nearest neighbour).

    Coordinates outside the image are clamped to the border. The lookup runs
    under ``torch.no_grad()``, so the returned tensor carries no gradient.

    :param img: BxCxHxW image (e.g. C = 3 for RGB). May be non-contiguous.
    :param pxpy: Bx2xN pixel coordinates; row 0 is x (column index, clamped
        to [0, W-1]) and row 1 is y (row index, clamped to [0, H-1]).
        Floating-point coordinates of any float dtype are rounded to the
        nearest pixel; integer coordinates are used as-is. ``pxpy`` itself
        is never modified.
    :return: BxCxN values gathered from ``img``.
    """
    with torch.no_grad():
        B, C, H, W = img.size()
        if pxpy.is_floating_point():
            # Round every float dtype, not only float32 (others used to truncate).
            pxpy = torch.round(pxpy)
        pxpy_int = pxpy.to(torch.int64)
        # Out-of-place clamps: if pxpy is already int64, ``.to`` returns the
        # caller's tensor itself, so in-place writes would corrupt the input.
        px = torch.clamp(pxpy_int[:, 0:1, :], min=0, max=W - 1)  # Bx1xN
        py = torch.clamp(pxpy_int[:, 1:2, :], min=0, max=H - 1)  # Bx1xN
        pxpy_idx = px + W * py  # Bx1xN, flat index into the H*W plane
        # reshape (rather than view) also accepts non-contiguous images.
        img_flat = img.reshape(B, C, H * W)
        rgb = torch.gather(img_flat, dim=2, index=pxpy_idx.expand(-1, C, -1))  # BxCxN
    return rgb


def uniformly_sample_disparity_from_bins(batch_size, disparity_np, device):
    """Draw one uniform random disparity inside each bin of ``disparity_np``.

    In the disparity dimension the edges must go from large to small, i.e.
    depth from small (near) to large (far). Only
    ``disparity_np[0] > disparity_np[-1]`` is checked.

    :param batch_size: number of independent draws B.
    :param disparity_np: 1-D numpy array of S+1 bin edges (converted to
        float32 on ``device``).
    :param device: torch device for the bin edges and the result.
    :return: BxS float32 tensor; entry [b, i] equals
        ``edge[i] + (edge[i+1] - edge[i]) * r`` with ``r ~ U[0, 1)``.
    """
    assert disparity_np[0] > disparity_np[-1]
    S = disparity_np.shape[0] - 1
    bin_edges = torch.from_numpy(disparity_np).to(dtype=torch.float32, device=device)  # S+1
    bin_start = bin_edges[:-1]  # S
    bin_width = bin_edges[1:] - bin_edges[:-1]  # S, negative for decreasing edges
    random_float = torch.rand((batch_size, S), dtype=torch.float32, device=device)  # BxS
    # The S-length vectors broadcast against BxS; no explicit repeat needed.
    return bin_start + bin_width * random_float  # BxS


def uniformly_sample_disparity_from_linspace_bins(batch_size, num_bins, start, end, device):
    """Draw one uniform random disparity inside each of ``num_bins`` equal bins.

    In the disparity dimension it has to go from large to small, i.e. depth
    from small (near) to large (far), hence ``start > end``.

    :param batch_size: number of independent draws B.
    :param num_bins: number of bins S between ``start`` and ``end``.
    :param start: largest disparity (first bin edge).
    :param end: smallest disparity (last bin edge).
    :param device: torch device for the result.
    :return: BxS float32 tensor; entry [b, i] lies between bin edge i and
        bin edge i+1.
    """
    assert start > end
    bin_edges = torch.linspace(start, end, num_bins + 1, dtype=torch.float32, device=device)  # S+1
    # All linspace bins share one width (negative, since start > end).
    interval = bin_edges[1] - bin_edges[0]
    random_float = torch.rand((batch_size, num_bins), dtype=torch.float32, device=device)  # BxS
    return bin_edges[:-1] + interval * random_float  # BxS


def sample_pdf(values, weights, N_samples):
    """Draw samples from the piecewise-constant distribution weights = p(values).

    Each value owns the bin between the midpoints to its neighbours. The first
    and last bins are closed off at ``values[..., 0]`` and ``values[..., -1]``.
    Samples are drawn by inverse-transform sampling with linear interpolation
    inside a bin.

    Weights are normalized independently for every (b, n) row over the last
    dimension. A row whose weights sum to at most 1e-5 falls back to a uniform
    pdf over its bins.

    :param values: Bx1xNxS sample locations, assumed monotonic along the last
        dimension.
    :param weights: Bx1xNxS weights, assumed non-negative.
    :param N_samples: number of samples to draw per row.
    :return: Bx1xNxN_samples samples (not sorted).
    """
    B, N, S = weights.size(0), weights.size(2), weights.size(3)
    assert values.size() == (B, 1, N, S)
    eps = 1e-5

    # Convert values to bin edges: midpoints plus the two outermost values.
    mid = (values[..., 1:] + values[..., :-1]) * 0.5  # Bx1xNxS-1
    bin_edges = torch.cat((values[..., 0:1], mid, values[..., -1:]), dim=3)  # Bx1xNxS+1

    # Exact normalization. Adding eps to the denominator would leave the CDF
    # short of 1, and the leftover draws would all land on values[..., -1].
    weight_sum = torch.sum(weights, dim=3, keepdim=True)  # Bx1xNx1
    pdf = torch.where(
        weight_sum > eps,
        weights / torch.clamp(weight_sum, min=eps),
        torch.full_like(weights, 1.0 / S),
    )  # Bx1xNxS
    cdf = torch.cumsum(pdf, dim=3)  # Bx1xNxS
    # Pin the CDF to [0, ..., 1] and keep it monotone despite float rounding.
    cdf = torch.cat(
        (torch.zeros_like(cdf[..., 0:1]), torch.clamp(cdf[..., :-1], max=1.0), torch.ones_like(cdf[..., 0:1])),
        dim=3,
    )  # Bx1xNxS+1

    # Uniform samples over the cdf values.
    u = torch.rand((B, 1, N, N_samples), dtype=weights.dtype, device=weights.device)  # Bx1xNxN_samples

    # cdf[idx-1] <= u < cdf[idx]; since cdf[0] = 0 and cdf[S] = 1, idx is in [1, S].
    cdf_idx = torch.searchsorted(cdf, u, right=True)  # Bx1xNxN_samples
    cdf_idx_lower = torch.clamp(cdf_idx - 1, min=0)
    cdf_idx_upper = torch.clamp(cdf_idx, max=S)

    # Gather lower/upper bounds of each sample's bin in one call per tensor.
    idx_lower_upper = torch.cat((cdf_idx_lower, cdf_idx_upper), dim=3)  # Bx1xNx(2*N_samples)
    cdf_bounds = torch.gather(cdf, index=idx_lower_upper, dim=3)
    bin_bounds = torch.gather(bin_edges, index=idx_lower_upper, dim=3)
    cdf_lower, cdf_upper = cdf_bounds[..., :N_samples], cdf_bounds[..., N_samples:]
    bin_lower, bin_upper = bin_bounds[..., :N_samples], bin_bounds[..., N_samples:]

    cdf_intervals = cdf_upper - cdf_lower  # Bx1xNxN_samples
    bin_intervals = bin_upper - bin_lower  # Bx1xNxN_samples

    # Linear interpolation inside the bin; (near-)empty CDF steps use the bin centre.
    t = (u - cdf_lower) / torch.clamp(cdf_intervals, min=1e-5)
    t = torch.where(cdf_intervals <= 1e-4, torch.full_like(t, 0.5), t)
    samples = bin_lower + t * bin_intervals
    return samples