Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,805 Bytes
476e0f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
# Copied from https://github.com/autonomousvision/gaussian-opacity-fields/blob/main/utils/vis_utils.py
# Portions are also copied from nerfstudio and 2DGS.
import torch
from matplotlib import cm
import open3d as o3d
import matplotlib.pyplot as plt
import numpy as np
def apply_colormap(image, cmap="viridis"):
    """Map a scalar image in [0, 1] to RGB colors using a matplotlib colormap.

    Args:
        image: Tensor whose last dimension is indexed at 0 to obtain the scalar
            values to color (e.g. shape ``(..., 1)``). Values must lie in [0, 1].
        cmap: Name of a matplotlib colormap (listed or segmented, e.g.
            "viridis", "turbo", "jet").

    Returns:
        Tensor on ``image.device`` with shape ``image.shape[:-1] + (3,)`` and
        torch's default floating dtype.

    Raises:
        AssertionError: if ``image`` contains values outside [0, 1] once
            quantized to 256 levels.
    """
    # plt.get_cmap exists across matplotlib versions; cm.get_cmap was removed in 3.9.
    colormap = plt.get_cmap(cmap)
    # Sample 256 evenly spaced colors instead of reading `.colors`. That
    # attribute only exists on ListedColormap, so it fails for maps like "jet",
    # and it may hold fewer than 256 entries. For 256-entry listed maps such as
    # viridis/turbo this yields exactly the same table.
    lut = colormap(np.linspace(0.0, 1.0, 256))[:, :3]
    # torch.tensor(list_of_floats) in the original produced the default dtype.
    lut = torch.tensor(lut, dtype=torch.get_default_dtype(), device=image.device)
    image_long = (image * 255).long()
    image_long_min = torch.min(image_long)
    image_long_max = torch.max(image_long)
    assert image_long_min >= 0, f"the min value is {image_long_min}"
    assert image_long_max <= 255, f"the max value is {image_long_max}"
    return lut[image_long[..., 0]]
def apply_depth_colormap(
    depth,
    accumulation,
    near_plane=2.0,
    far_plane=4.0,
    cmap="turbo",
):
    """Colorize a depth map, optionally blending it toward white.

    Depth is linearly rescaled so that ``near_plane`` maps to 0 and
    ``far_plane`` maps to 1. It is then clipped to [0, 1] and passed to
    ``apply_colormap``.

    Args:
        depth: Depth tensor. After normalization it is passed to
            ``apply_colormap``, which indexes its last dimension at 0
            (e.g. shape ``(..., 1)``).
        accumulation: Optional tensor broadcastable against the colored
            output. When given, the result is
            ``color * accumulation + (1 - accumulation)``, i.e. the color is
            blended with white (1.0). Pass ``None`` to skip blending.
        near_plane: Depth mapped to the start of the colormap. ``None`` means
            use the minimum of ``depth``.
        far_plane: Depth mapped to the end of the colormap. ``None`` means use
            the maximum of ``depth``.
        cmap: Name of a matplotlib colormap.

    Returns:
        Colored tensor whose last dimension holds RGB values.
    """
    # `is None` rather than `or`, so that an explicit 0.0 near plane is kept.
    if near_plane is None:
        near_plane = float(torch.min(depth))
    if far_plane is None:
        far_plane = float(torch.max(depth))
    # The epsilon avoids division by zero when near_plane == far_plane.
    depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
    depth = torch.clip(depth, 0, 1)
    colored_image = apply_colormap(depth, cmap=cmap)
    if accumulation is not None:
        colored_image = colored_image * accumulation + (1 - accumulation)
    return colored_image
def save_points(path_save, pts, colors=None, normals=None, BRG2RGB=False):
    """Write points, with optional per-point colors and normals, via Open3D.

    Args:
        path_save: Destination path passed to ``o3d.io.write_point_cloud``.
        pts: Non-empty array of point positions with 3 columns.
        colors: Optional array with 3 columns. If its maximum exceeds 1, the
            whole array is divided by that maximum, since Open3D expects float
            colors in [0, 1].
        normals: Optional per-point normals.
        BRG2RGB: If True, swap the first and third color channels before
            saving (reverses the channel order).
    """
    with_colors = colors is not None

    # Input validation, checked in this order: non-empty, color width, point width.
    assert len(pts) > 0
    if with_colors:
        assert colors.shape[1] == 3
    assert pts.shape[1] == 3

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(pts)

    if with_colors:
        # Open3D assumes the color values are of float type and in range [0, 1]
        peak = np.max(colors)
        if peak > 1:
            colors = colors / peak
        if BRG2RGB:
            colors = np.stack([colors[:, ch] for ch in (2, 1, 0)], axis=-1)
        pcd.colors = o3d.utility.Vector3dVector(colors)

    if normals is not None:
        pcd.normals = o3d.utility.Vector3dVector(normals)

    o3d.io.write_point_cloud(path_save, pcd)
def colormap(img, cmap='jet'):
    """Render ``img`` with a matplotlib colormap and colorbar into an RGB tensor.

    The image is drawn with ``imshow`` on a figure sized to its pixel
    dimensions at 300 dpi, and a colorbar is attached. The rasterized canvas
    is then read back and, if needed, bilinearly resized to the input's
    height and width.

    Args:
        img: Array-like accepted by ``Axes.imshow`` whose first two dimensions
            are (height, width).
        cmap: Name of a matplotlib colormap.

    Returns:
        Float32 tensor of shape ``(3, height, width)`` with values in [0, 1].
    """
    H, W = img.shape[:2]
    dpi = 300
    # figsize is (width, height) in inches.
    fig, ax = plt.subplots(1, figsize=(W / dpi, H / dpi), dpi=dpi)
    try:
        im = ax.imshow(img, cmap=cmap)
        ax.set_axis_off()
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which was deprecated in
        # matplotlib 3.8 and later removed. Its shape also reflects the real
        # raster size; drop the alpha channel.
        data = np.asarray(fig.canvas.buffer_rgba())[..., :3]
        out = torch.from_numpy(data / 255.).float().permute(2, 0, 1)
    finally:
        # Close this specific figure; a bare plt.close() closes the current one.
        plt.close(fig)
    # The colorbar changes the canvas size, so usually a resize is needed.
    if out.shape[1:] != (H, W):
        out = torch.nn.functional.interpolate(
            out[None], (H, W), mode='bilinear', align_corners=False
        )[0]
    return out