import torch
import argparse
import os
import numpy as np
from lightning_fabric import seed_everything
from tqdm import tqdm
import random
import warnings
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from plyfile import PlyData
from pathlib import Path
import multiprocessing
from chamfer_distance import ChamferDistance
from eval.eval_pc_set import *

# Clouds with more than N_POINTS vertices are randomly downsampled to this size.
N_POINTS = 2000


def find_files(folder, extension):
    """Return the files directly inside ``folder`` whose names end with ``extension``.

    The result is a sorted list of ``pathlib.Path`` objects (non-recursive).
    """
    return sorted([Path(os.path.join(folder, f)) for f in os.listdir(folder) if f.endswith(extension)])


def read_ply(path):
    """Read the ``x``/``y``/``z`` vertex properties of a PLY file.

    Returns:
        (N, 3) numpy array with one row per vertex.
    """
    with open(path, 'rb') as f:
        plydata = PlyData.read(f)
        vertices = plydata['vertex']
        # np.array copies the data, so the result stays valid after the file is closed.
        vertex = np.stack([np.array(vertices[axis]) for axis in ('x', 'y', 'z')], axis=1)
    return vertex


def distChamfer(a, b):
    """Squared nearest-neighbour distances between two batches of point clouds.

    Args:
        a: (bs, N, D) tensor.
        b: (bs, M, D) tensor. N and M may differ.

    Returns:
        Tuple ``(d_b, d_a)``:
        - ``d_b`` has shape (bs, M): for each point of ``b``, the squared distance
          to its nearest point in ``a``.
        - ``d_a`` has shape (bs, N): for each point of ``a``, the squared distance
          to its nearest point in ``b``.
    """
    x, y = a, b
    xx = torch.bmm(x, x.transpose(2, 1))
    yy = torch.bmm(y, y.transpose(2, 1))
    zz = torch.bmm(x, y.transpose(2, 1))  # (bs, N, M): x_i . y_j
    # |x_i|^2 and |y_j|^2 taken from each Gram matrix's own diagonal, so the two
    # clouds may have different point counts.
    rx = torch.diagonal(xx, dim1=1, dim2=2).unsqueeze(2)  # (bs, N, 1)
    ry = torch.diagonal(yy, dim1=1, dim2=2).unsqueeze(1)  # (bs, 1, M)
    P = rx + ry - 2 * zz  # (bs, N, M): |x_i - y_j|^2
    return P.min(1)[0], P.min(2)[0]


def _pairwise_CD(sample_pcs, ref_pcs, batch_size):
    """Chamfer distance between every sample cloud and every reference cloud.

    Each sample is compared against ``batch_size`` reference clouds at a time using
    ``ChamferDistance``. The score for one pair is ``mean(dl) + mean(dr)``.

    Args:
        sample_pcs: tensor of sample clouds; each ``sample_pcs[i]`` is viewed as (1, -1, 3).
        ref_pcs: tensor of reference clouds, batched along dim 0.
        batch_size: number of reference clouds per ChamferDistance call.

    Returns:
        (N_sample, N_ref) tensor of pairwise Chamfer scores.
    """
    N_ref = ref_pcs.shape[0]
    chamfer_dist = ChamferDistance()
    all_cd = []
    # Reference indices that are the nearest neighbour of at least one sample so far;
    # used only for the live coverage estimate in the progress bar.
    matched_gt = set()
    pbar = tqdm(range(sample_pcs.shape[0]))
    for sample_idx in pbar:
        sample = sample_pcs[sample_idx].view(1, -1, 3)
        cd_lst = []
        for ref_b_start in range(0, N_ref, batch_size):
            ref_batch = ref_pcs[ref_b_start:ref_b_start + batch_size]
            # The kernel expects equal batch sizes, so the single sample is repeated.
            sample_batch_exp = sample.expand(ref_batch.size(0), -1, -1).contiguous()
            dl, dr, _, _ = chamfer_dist(sample_batch_exp, ref_batch)
            cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
        cd_row = torch.cat(cd_lst, dim=1)  # (1, N_ref)
        all_cd.append(cd_row)
        matched_gt.add(int(np.argmin(cd_row.detach().cpu().numpy()[0])))
        pbar.set_postfix({"cov": len(matched_gt) * 1.0 / N_ref})

    return torch.cat(all_cd, dim=0)  # N_sample, N_ref


def compute_cov_mmd(sample_pcs, ref_pcs, batch_size):
    """Compute coverage (COV) and minimum matching distance (MMD) under Chamfer distance.

    MMD-CD: for every reference cloud, take the distance to its closest sample,
    then average over references.

    COV-CD: the fraction of reference clouds that are the nearest reference of
    at least one sample.

    Returns:
        ``({'MMD-CD': float, 'COV-CD': float}, min_idx)``. ``min_idx`` is a numpy
        array giving, for each sample, the index of its nearest reference cloud.
    """
    all_dist = _pairwise_CD(sample_pcs, ref_pcs, batch_size)  # (N_sample, N_ref)
    N_ref = all_dist.size(1)

    # Nearest reference for each sample (drives coverage).
    _, min_idx = torch.min(all_dist, dim=1)
    # Distance from each reference to its nearest sample (drives MMD).
    min_val, _ = torch.min(all_dist, dim=0)

    mmd = min_val.mean()
    # Exact ratio; no need to round-trip through a float32 tensor.
    cov = float(min_idx.unique().numel()) / float(N_ref)

    return {
        'MMD-CD': mmd.item(),
        'COV-CD': cov,
    }, min_idx.cpu().numpy()


def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, in_unit_sphere, resolution=28):
    """Compute the JSD between two sets of point clouds.

    The metric was introduced in the paper "Learning Representations And
    Generative Models For 3D Point Clouds". It compares the distributions of
    occupancy-grid cell hits of the two sets.

    Args:
        sample_pcs: (np.ndarray S1xR1x3) S1 point clouds, each of R1 points.
        ref_pcs: (np.ndarray S2xR2x3) S2 point clouds, each of R2 points.
        in_unit_sphere: if True, only grid cells inside the unit sphere are used,
            and a warning is emitted when clouds leave the sphere.
        resolution: (int) grid resolution per axis; affects the granularity of
            the measurement.

    Returns:
        Jensen-Shannon divergence (base 2) between the two per-cell hit-count
        distributions.
    """
    sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
    ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
    return jensen_shannon_divergence(sample_grid_var, ref_grid_var)


def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False):
    """Estimate occupancy-grid statistics for a collection of point clouds.

    Every point is snapped to its nearest grid-cell centre.

    Args:
        pclouds: (numpy array) #point-clouds x points per point-cloud x 3.
        grid_resolution: (int) number of grid cells per axis.
        in_sphere: restrict the grid to cells inside the unit sphere.

    Returns:
        ``(mean_entropy, grid_counters)``.
        - ``mean_entropy`` is the average, over cells, of the Bernoulli entropy
          (natural log) of "cell is hit by a cloud".
        - ``grid_counters`` holds the total number of points that landed in each
          cell, across all clouds.
    """
    epsilon = 10e-4  # tolerance for points slightly outside the [-1, 1] cube
    bound = 1 + epsilon
    if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
        print(abs(np.max(pclouds)), abs(np.min(pclouds)))
        warnings.warn('Point-clouds are not in unit cube.')

    if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
        warnings.warn('Point-clouds are not in unit sphere.')

    grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
    grid_coordinates = grid_coordinates.reshape(-1, 3)
    grid_counters = np.zeros(len(grid_coordinates))
    grid_bernoulli_rvars = np.zeros(len(grid_coordinates))
    nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates)

    for pc in pclouds:
        _, indices = nn.kneighbors(pc)
        # ravel (not squeeze) so a single-point cloud still yields a 1-D array.
        indices = indices.ravel()
        # np.add.at accumulates correctly when the same cell index repeats.
        np.add.at(grid_counters, indices, 1)
        # Each cloud counts at most once per cell for the Bernoulli statistic.
        grid_bernoulli_rvars[np.unique(indices)] += 1

    acc_entropy = 0.0
    n = float(len(pclouds))
    for g in grid_bernoulli_rvars:
        p = 0.0
        if g > 0:
            p = float(g) / n
        acc_entropy += entropy([p, 1.0 - p])

    return acc_entropy / len(grid_counters), grid_counters


def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
    """Return the centre coordinates of a resolution^3 grid spanning [-1, 1]^3.

    Args:
        resolution: number of grid points per axis (must be >= 2).
        clip_sphere: if True, drop the "corner" cells that lie outside the unit
            sphere (radius 1).

    Returns:
        ``(grid, spacing)``.
        - ``grid`` is a float32 array of shape (resolution, resolution,
          resolution, 3), or (K, 3) when ``clip_sphere`` is set.
        - ``spacing`` is the distance between neighbouring grid points.
    """
    spacing = 2.0 / float(resolution - 1)
    axis = np.arange(resolution) * spacing - 1.0
    grid = np.stack(np.meshgrid(axis, axis, axis, indexing='ij'), axis=-1).astype(np.float32)

    if clip_sphere:
        grid = grid.reshape(-1, 3)
        # The grid spans [-1, 1], so the unit sphere has radius 1.
        # A radius of 0.5 would wrongly discard most of the interior cells.
        grid = grid[np.linalg.norm(grid, axis=1) <= 1.0]

    return grid, spacing


def jensen_shannon_divergence(P, Q):
    """Jensen-Shannon divergence (base 2) between two non-negative histograms.

    ``P`` and ``Q`` are normalised to sum to 1 first. The value is cross-checked
    against ``_jsdiv``, and a warning is emitted when the two disagree beyond
    1e-4.

    Raises:
        ValueError: if any entry is negative or the lengths differ.
    """
    if np.any(P < 0) or np.any(Q < 0):
        raise ValueError('Negative values.')
    if len(P) != len(Q):
        raise ValueError('Non equal size.')

    P_ = P / np.sum(P)  # Ensure probabilities.
    Q_ = Q / np.sum(Q)

    e1 = entropy(P_, base=2)
    e2 = entropy(Q_, base=2)
    e_sum = entropy((P_ + Q_) / 2.0, base=2)
    res = e_sum - ((e1 + e2) / 2.0)

    res2 = _jsdiv(P_, Q_)

    if not np.allclose(res, res2, atol=10e-5, rtol=0):
        warnings.warn('Numerical values of two JSD methods don\'t agree.')

    return res


def _jsdiv(P, Q):
    """Compute JSD (base 2) as the mean KL divergence of P and Q to their midpoint."""

    def _kldiv(A, B):
        # Only entries positive in both contribute; 0 * log(0 / b) is taken as 0.
        idx = np.logical_and(A > 0, B > 0)
        a = A[idx]
        b = B[idx]
        return np.sum(a * np.log2(a / b))

    P_ = P / np.sum(P)
    Q_ = Q / np.sum(Q)
    M = 0.5 * (P_ + Q_)
    return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))


def downsample_pc(points, n):
    """Return ``n`` rows of ``points`` chosen uniformly without replacement (uses ``random``)."""
    sample_idx = random.sample(range(points.shape[0]), n)
    return points[sample_idx]


def normalize_pc(points):
    """Centre ``points`` on their mean and scale so the largest |coordinate| is 1.

    A degenerate cloud (all points identical) is only centred; dividing by a
    zero scale would produce NaNs.
    """
    mean = np.mean(points, axis=0)
    points = points - mean
    # fit to unit cube
    scale = np.max(np.abs(points))
    if scale > 0:
        points = points / scale
    return points


def align_pc(points):
    """Centre ``points`` and permute axes by bounding-box extent.

    The longest extent goes to x, the medium extent to z and the shortest to y.
    """
    # 1. Center the point cloud
    centroid = np.mean(points, axis=0)
    centered_points = points - centroid

    # 2. Calculate the three edge lengths of bbox
    min_coords = np.min(centered_points, axis=0)
    max_coords = np.max(centered_points, axis=0)
    dimensions = max_coords - min_coords

    # 3. Sort axes by dimension length to get axis order
    axis_order = np.argsort(dimensions)[::-1]  # sort from longest to shortest

    # 4. Create permutation matrix (align longest edge to x, shortest to y).
    # NOTE(review): an odd permutation has determinant -1, i.e. it mirrors the shape.
    perm_matrix = np.zeros((3, 3))
    perm_matrix[0, axis_order[0]] = 1  # longest edge -> x
    perm_matrix[1, axis_order[2]] = 1  # shortest edge -> y
    perm_matrix[2, axis_order[1]] = 1  # medium edge -> z

    # 5. Apply transformation
    aligned_points = np.dot(centered_points, perm_matrix.T)

    # 6. Ensure same centroid faces direction.
    # NOTE(review): the points were centred on their mean in step 1, so this mean
    # is ~0 and the flip is decided by floating-point rounding. A different
    # statistic (e.g. the bbox centre or skew) was probably intended - confirm
    # before changing, since it alters results.
    if np.mean(aligned_points[:, 2]) < 0:
        aligned_points[:, 2] *= -1

    return aligned_points


def collect_pc(cad_folder):
    """Load the last ``*final_pcd.ply`` under ``<cad_folder>/pcd``, then downsample and normalise it.

    Returns an empty list when no such file exists.
    """
    pc_path = find_files(os.path.join(cad_folder, 'pcd'), 'final_pcd.ply')
    if len(pc_path) == 0:
        return []
    pc_path = pc_path[-1]  # final pcd
    pc = read_ply(pc_path)
    if pc.shape[0] > N_POINTS:
        pc = downsample_pc(pc, N_POINTS)
    pc = normalize_pc(pc)
    return pc


def collect_pc2(cad_folder):
    """Load a PLY file, downsample it to at most N_POINTS, normalise it and axis-align it."""
    pc = read_ply(cad_folder)
    if pc.shape[0] > N_POINTS:
        pc = downsample_pc(pc, N_POINTS)
    pc = normalize_pc(pc)
    pc = align_pc(pc)
    return pc


theta_x = np.radians(90)  # Rotation angle around X-axis
theta_y = np.radians(90)  # Rotation angle around Y-axis
theta_z = np.radians(180)  # Rotation angle around Z-axis

# Create individual rotation matrices
Rx = np.array([[1, 0, 0],
               [0, np.cos(theta_x), -np.sin(theta_x)],
               [0, np.sin(theta_x), np.cos(theta_x)]])

Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)],
               [0, 1, 0],
               [-np.sin(theta_y), 0, np.cos(theta_y)]])

Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0],
               [np.sin(theta_z), np.cos(theta_z), 0],
               [0, 0, 1]])

# Combined rotation: Rx first, then Ry, then Rz.
rotation_matrix = np.dot(np.dot(Rz, Ry), Rx)


def collect_pc3(cad_folder):
    """Load a PLY file, downsample and normalise it, then apply ``rotation_matrix``.

    Returns a float32 array.
    """
    pc = read_ply(cad_folder)
    if pc.shape[0] > N_POINTS:
        pc = downsample_pc(pc, N_POINTS)
    pc = normalize_pc(pc)
    # Row vectors, so multiply by the transpose.
    rotated_point_cloud = np.dot(pc, rotation_matrix.T).astype(np.float32)
    return rotated_point_cloud


def load_data_with_prefix(root_folder, prefix):
    """Recursively collect files under ``root_folder`` whose names END with ``prefix``.

    Despite its name, ``prefix`` is matched as a suffix. Returns a sorted list of
    path strings.
    """
    data_files = [
        os.path.join(root, filename)
        for root, _dirs, files in os.walk(root_folder)
        for filename in files
        if filename.endswith(prefix)
    ]
    data_files.sort()
    return data_files


def _load_point_clouds(paths, num_workers, loader):
    """Load ``paths`` in parallel with ``loader``.

    Returns:
        ``(stacked_clouds, kept_paths)``. ``kept_paths[i]`` is the file that
        produced ``stacked_clouds[i]``; files for which ``loader`` returned an
        empty result are skipped in both.
    """
    pcs, kept_paths = [], []
    # The context manager terminates the worker processes once loading is done.
    with multiprocessing.Pool(num_workers) as pool:
        # imap preserves input order, so zipping with paths keeps them aligned.
        for path, pc in tqdm(zip(paths, pool.imap(loader, paths)), total=len(paths)):
            if len(pc) > 0:
                pcs.append(pc)
                kept_paths.append(path)
    if not pcs:
        raise ValueError("no point clouds could be loaded from {} file(s)".format(len(paths)))
    return np.stack(pcs, axis=0), kept_paths


def main():
    """Run the generated-vs-reference point-cloud evaluation (COV/MMD/JSD)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--fake", type=str, required=True)
    parser.add_argument("--real", type=str, required=True)
    parser.add_argument("--n_test", type=int, default=1000)
    parser.add_argument("--multi", type=float, default=3)
    parser.add_argument("--times", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=64)
    args = parser.parse_args()
    if args.times < 1:
        parser.error("--times must be at least 1")

    print("n_test: {}, multiplier: {}, repeat times: {}".format(args.n_test, args.multi, args.times))
    args.output = args.fake + '_results.txt'
    seed_everything(0)

    num_cpus = multiprocessing.cpu_count()

    # Load reference pcd. ref_paths stays index-aligned with ref_pcs.
    gt_shape_paths = load_data_with_prefix(args.real, '.ply')
    ref_pcs, ref_paths = _load_point_clouds(gt_shape_paths, num_cpus, collect_pc2)
    print("real point clouds: {}".format(ref_pcs.shape))

    # Load fake pcd
    shape_paths = load_data_with_prefix(args.fake, '.ply')
    sample_pcs, _ = _load_point_clouds(shape_paths, num_cpus, collect_pc2)
    print("fake point clouds: {}".format(sample_pcs.shape))

    n_sample = int(args.multi * args.n_test)
    if n_sample > len(sample_pcs) or args.n_test > len(ref_pcs):
        raise ValueError(
            "need {} fake and {} real point clouds, but loaded {} and {}".format(
                n_sample, args.n_test, len(sample_pcs), len(ref_pcs)))

    # Testing
    cov_on_gt = set()  # indices into ref_pcs covered in any iteration
    result_list = []
    with open(args.output, "w") as fp:
        for i in range(args.times):
            print("iteration {}...".format(i))
            select_idx1 = random.sample(range(len(sample_pcs)), n_sample)
            rand_sample_pcs = sample_pcs[select_idx1]
            select_idx2 = random.sample(range(len(ref_pcs)), args.n_test)
            rand_ref_pcs = ref_pcs[select_idx2]

            jsd = jsd_between_point_cloud_sets(rand_sample_pcs, rand_ref_pcs, in_unit_sphere=False)

            with torch.no_grad():
                rand_sample_pcs = torch.tensor(rand_sample_pcs).cuda().float()
                rand_ref_pcs = torch.tensor(rand_ref_pcs).cuda().float()
                result, idx = compute_cov_mmd(rand_sample_pcs, rand_ref_pcs, batch_size=args.batch_size)
            result.update({"JSD": float(jsd)})

            # idx indexes the random subset; map it back to indices into ref_pcs.
            cov_on_gt.update(np.asarray(select_idx2)[np.unique(idx)].tolist())

            print(result)
            print(result, file=fp)
            result_list.append(result)

        avg_result = {
            "avg-" + k: float(np.mean([x[k] for x in result_list]))
            for k in result_list[0]
        }
        print("average result:")
        print(avg_result)
        print(avg_result, file=fp)

    covered_paths = [ref_paths[i] for i in sorted(cov_on_gt)]
    np.save(args.fake + '_cov_on_gt.npy', covered_paths)


if __name__ == '__main__':
    main()