# Provenance (scraped Hugging Face page header): uploaded by paulpanwang
# via huggingface_hub ("Upload folder using huggingface_hub"), commit 7760d2d (verified).
from typing import *
from torch import Tensor
import os
import numpy as np
from plyfile import PlyData, PlyElement
import torch
from diff_gaussian_rasterization import (
GaussianRasterizationSettings,
GaussianRasterizer,
)
class Camera:
    """Pinhole camera expressed in the matrix layout consumed by the rasterizer.

    Args:
        C2W: camera-to-world transform, used as a 4x4 homogeneous matrix
            (it is inverted and its ``[:3, 3]`` column is the camera center).
        fxfycxcy: intrinsics ``(fx, fy, cx, cy)``. The formulas below
            (``tan(fov/2) = 1 / (2 * fx)`` and ``shift = 2 * cx - 1``) imply the
            values are normalized by the image size, i.e. ``cx = 0.5`` is the
            image center.
        h, w: image height and width in pixels (stored only; not used here).
        znear, zfar: near/far clipping distances of the projection.

    ``world_view_transform``, ``projection_matrix`` and ``full_proj_transform``
    are stored transposed, so ``full_proj_transform = W2C^T @ P^T``.
    """

    def __init__(self,
        C2W: Tensor, fxfycxcy: Tensor, h: int, w: int,
        znear: float = 0.01, zfar: float = 100.,
    ):
        self.C2W = C2W.clone().float()
        # The projection matrix is allocated on the intrinsics' device and then
        # multiplied with W2C (on C2W's device); keep both on the same device so
        # that matmul cannot fail on a device mismatch.
        self.fxfycxcy = fxfycxcy.clone().float().to(self.C2W.device)
        self.W2C = self.C2W.inverse()
        self.znear = znear
        self.zfar = zfar
        self.h = h
        self.w = w

        fx, fy, cx, cy = self.fxfycxcy[0], self.fxfycxcy[1], self.fxfycxcy[2], self.fxfycxcy[3]
        self.tanfovX = 1 / (2 * fx)  # tangent of the *half* horizontal FoV
        self.tanfovY = 1 / (2 * fy)  # tangent of the *half* vertical FoV
        self.fovX = 2 * torch.atan(self.tanfovX)
        self.fovY = 2 * torch.atan(self.tanfovY)
        # Principal-point offset from the image center in NDC units:
        # normalized cx in [0, 1] maps to [-1, 1].
        self.shiftX = 2 * cx - 1
        self.shiftY = 2 * cy - 1

        self.world_view_transform = self.W2C.transpose(0, 1)
        self.projection_matrix = self._projection_matrix(
            self.znear, self.zfar, self.fovX, self.fovY, self.shiftX, self.shiftY,
        ).transpose(0, 1)
        self.full_proj_transform = self.world_view_transform @ self.projection_matrix
        self.camera_center = self.C2W[:3, 3]

    @staticmethod
    def _projection_matrix(znear, zfar, fovX, fovY, shiftX, shiftY) -> Tensor:
        """Return the (untransposed) 4x4 perspective projection matrix.

        Depth is mapped so that ``z = znear -> 0`` and ``z = zfar -> 1`` after
        the perspective divide (``w = z`` via ``P[3, 2] = 1``). ``shiftX`` and
        ``shiftY`` translate the image in NDC to model an off-center principal point.
        """
        tanHalfFovY = torch.tan(fovY / 2)
        tanHalfFovX = torch.tan(fovX / 2)
        top = tanHalfFovY * znear
        bottom = -top
        right = tanHalfFovX * znear
        left = -right

        P = torch.zeros(4, 4, device=fovX.device, dtype=fovX.dtype)
        z_sign = 1
        P[0, 0] = 2 * znear / (right - left)
        P[1, 1] = 2 * znear / (top - bottom)
        P[0, 2] = (right + left) / (right - left) + shiftX
        P[1, 2] = (top + bottom) / (top - bottom) + shiftY
        P[3, 2] = z_sign
        P[2, 2] = z_sign * zfar / (zfar - znear)
        P[2, 3] = -(zfar * znear) / (zfar - znear)
        return P
class GaussianModel:
    """Explicit set of 3D Gaussians with PLY (de)serialization.

    Attributes (``None`` until populated by ``set_data`` or ``load_ply``):
        xyz, rgb, scale, rotation, opacity: per-Gaussian tensors. ``save_ply``
            concatenates them along axis 1, so each is expected to be 2-D with
            a shared first (Gaussian) dimension; ``opacity`` is one column.
        sh_degree: spherical-harmonics degree reported to the rasterizer;
            always 0 here.
    """

    def __init__(self):
        self.xyz = None
        self.rgb = None
        self.scale = None
        self.rotation = None
        self.opacity = None
        self.sh_degree = 0

    def set_data(self, xyz: Tensor, rgb: Tensor, scale: Tensor, rotation: Tensor, opacity: Tensor):
        """Assign all Gaussian attributes at once; returns ``self`` for chaining."""
        self.xyz = xyz
        self.rgb = rgb
        self.scale = scale
        self.rotation = rotation
        self.opacity = opacity
        return self

    def to(self, device: torch.device = None, dtype: torch.dtype = None) -> "GaussianModel":
        """Move/cast every attribute tensor, rebinding them on this instance.

        Note this mutates ``self`` (it does not return a copy); returns ``self``.
        """
        self.xyz = self.xyz.to(device, dtype)
        self.rgb = self.rgb.to(device, dtype)
        self.scale = self.scale.to(device, dtype)
        self.rotation = self.rotation.to(device, dtype)
        self.opacity = self.opacity.to(device, dtype)
        return self

    def save_ply(self, path: str, opacity_threshold: float = 0.):
        """Write the Gaussians whose opacity exceeds ``opacity_threshold`` to ``path``.

        Each vertex stores ``x/y/z``, ``f_dc_*`` (the raw ``rgb`` values),
        ``opacity``, ``scale_*``, ``rot_*`` as float32, plus ``red/green/blue``
        as uint8 (``rgb * 255`` clipped to [0, 255]).
        """
        parent = os.path.dirname(path)
        if parent:  # os.makedirs("") raises for a bare filename like "out.ply"
            os.makedirs(parent, exist_ok=True)
        elements = self._build_vertex_array(opacity_threshold)
        el = PlyElement.describe(elements, "vertex")
        PlyData([el]).write(path)

    def _build_vertex_array(self, opacity_threshold: float = 0.) -> np.ndarray:
        """Pack the Gaussians passing the opacity filter into a structured array."""
        xyz = self.xyz.detach().cpu().numpy()
        f_dc = self.rgb.detach().cpu().numpy()
        # Scaled by 255 and clipped, so presumably rgb is in [0, 1] — confirm with callers.
        rgb = (f_dc * 255.).clip(0., 255.).astype(np.uint8)
        opacity = self.opacity.detach().cpu().numpy()
        scale = self.scale.detach().cpu().numpy()
        rotation = self.rotation.detach().cpu().numpy()

        # reshape(-1) rather than squeeze(): with a single Gaussian, squeeze()
        # yields a 0-d mask, and indexing with it adds an axis instead of filtering.
        mask = (opacity > opacity_threshold).reshape(-1)
        float_columns = np.concatenate((xyz, f_dc, opacity, scale, rotation), axis=1)[mask]
        rgb = rgb[mask]

        float_names = self._construct_list_of_attributes()
        color_names = ("red", "green", "blue")
        dtype_full = [(name, "f4") for name in float_names]
        dtype_full.extend((name, "u1") for name in color_names)

        # Vectorized per-field assignment instead of building one Python tuple per row.
        elements = np.empty(float_columns.shape[0], dtype=dtype_full)
        for i, name in enumerate(float_names):
            elements[name] = float_columns[:, i]
        for i, name in enumerate(color_names):
            elements[name] = rgb[:, i]
        return elements

    def load_ply(self, path: str):
        """Load Gaussians written by ``save_ply`` from ``path``; returns ``self``.

        Only the first PLY element is read. Colors come from the ``f_dc_*``
        columns; the 8-bit ``red/green/blue`` columns are ignored.
        """
        vertex = PlyData.read(path).elements[0]

        def stack_columns(*names):
            # One output column per named property, rows = vertices.
            return np.stack([np.asarray(vertex[name]) for name in names], axis=1)

        xyz = stack_columns("x", "y", "z")
        f_dc = stack_columns("f_dc_0", "f_dc_1", "f_dc_2")
        opacity = np.asarray(vertex["opacity"])[..., np.newaxis]
        scale = stack_columns("scale_0", "scale_1", "scale_2")
        rotation = stack_columns("rot_0", "rot_1", "rot_2", "rot_3")

        self.xyz = torch.from_numpy(xyz).float()
        self.rgb = torch.from_numpy(f_dc).float()
        self.opacity = torch.from_numpy(opacity).float()
        self.scale = torch.from_numpy(scale).float()
        self.rotation = torch.from_numpy(rotation).float()
        return self

    def _construct_list_of_attributes(self):
        """Return the float PLY property names, in the column order used by ``save_ply``."""
        names = ["x", "y", "z"]
        names.extend(f"f_dc_{i}" for i in range(self.rgb.shape[1]))
        names.append("opacity")
        names.extend(f"scale_{i}" for i in range(self.scale.shape[1]))
        names.extend(f"rot_{i}" for i in range(self.rotation.shape[1]))
        return names
def render(
    pc: GaussianModel,
    height: int,
    width: int,
    C2W: Tensor,
    fxfycxcy: Tensor,
    znear: float = 0.01,
    zfar: float = 100.,
    bg_color: Union[Tensor, Tuple[float, float, float]] = (1., 1., 1.),
    scaling_modifier: float = 1.,
    render_dn: bool = False,
):
    """Rasterize the Gaussians in ``pc`` from a single camera.

    Args:
        pc: Gaussians to draw. NOTE: ``pc.to(dtype=torch.float32)`` below casts
            the caller's model in place (``GaussianModel.to`` rebinds its own
            tensors and returns ``self``).
        height, width: output image size in pixels.
        C2W, fxfycxcy, znear, zfar: forwarded to ``Camera``. Every tensor this
            function allocates is placed on ``C2W.device``.
        bg_color: background color as a tensor or an (r, g, b) tuple.
        scaling_modifier: passed to the rasterizer as ``scale_modifier``.
        render_dn: currently unused; ``alpha``, ``depth`` and ``normal`` are
            always returned as ``None``.

    Returns:
        dict with ``"image"`` (the rasterizer's color output) and ``"alpha"``,
        ``"depth"``, ``"normal"`` (all ``None``).
    """
    viewpoint_camera = Camera(C2W, fxfycxcy, height, width, znear, zfar)
    device = C2W.device

    if not isinstance(bg_color, Tensor):
        bg_color = torch.tensor(list(bg_color), dtype=torch.float32, device=device)
    else:
        bg_color = bg_color.to(device, dtype=torch.float32)

    pc = pc.to(dtype=torch.float32)

    image_h = int(viewpoint_camera.h)
    image_w = int(viewpoint_camera.w)
    # Allocated on the camera's device (previously hard-coded to "cuda", which
    # mismatched the other inputs whenever C2W lived on a non-default GPU).
    subpixel_offset = torch.zeros((image_h, image_w, 2), dtype=torch.float32, device=device)

    raster_settings = GaussianRasterizationSettings(
        image_height=image_h,
        image_width=image_w,
        tanfovx=viewpoint_camera.tanfovX,
        tanfovy=viewpoint_camera.tanfovY,
        subpixel_offset=subpixel_offset,
        kernel_size=0.,  # cf. Mip-Splatting; not used
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=False,
    )

    alpha = normal = depth = None
    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    # Colors are supplied directly (colors_precomp) instead of SH coefficients,
    # and covariances are built by the rasterizer from scale + rotation.
    # The returned per-Gaussian screen radii are not used here.
    image, radii = rasterizer(
        means3D=pc.xyz,
        means2D=torch.zeros_like(pc.xyz, dtype=torch.float32, device=pc.xyz.device),
        shs=None,
        colors_precomp=pc.rgb,
        opacities=pc.opacity,
        scales=pc.scale,
        rotations=pc.rotation,
        cov3D_precomp=None,
    )

    return {
        "image": image,
        "alpha": alpha,
        "depth": depth,
        "normal": normal,
    }