Spaces:
Running
on
Zero
Running
on
Zero
from typing import * | |
from torch import Tensor | |
import os | |
import numpy as np | |
from plyfile import PlyData, PlyElement | |
import torch | |
from diff_gaussian_rasterization import ( | |
GaussianRasterizationSettings, | |
GaussianRasterizer, | |
) | |
class Camera:
    """Pinhole camera holding the matrices consumed by the Gaussian rasterizer.

    Args:
        C2W: 4x4 camera-to-world matrix (it is inverted and its ``[:3, 3]``
            column is read as the camera position).
        fxfycxcy: Tensor indexed as ``[fx, fy, cx, cy]``.
            NOTE(review): the formulas ``tan(fov / 2) = 1 / (2 * f)`` and
            ``shift = 2 * c - 1`` suggest intrinsics normalized by the image
            size -- confirm against the caller.
        h: Image height.
        w: Image width.
        znear: Near clipping distance.
        zfar: Far clipping distance.

    The matrices ``world_view_transform``, ``projection_matrix`` and
    ``full_proj_transform`` are stored transposed, so they are combined as
    ``world_view_transform @ projection_matrix``.
    """

    def __init__(
        self,
        C2W: Tensor, fxfycxcy: Tensor, h: int, w: int,
        znear: float = 0.01, zfar: float = 100.,
    ):
        # Work on private float32 copies so the caller's tensors stay untouched.
        self.fxfycxcy = fxfycxcy.clone().float()
        self.C2W = C2W.clone().float()
        self.W2C = self.C2W.inverse()

        self.znear, self.zfar = znear, zfar
        self.h, self.w = h, w

        fx, fy, cx, cy = (self.fxfycxcy[i] for i in range(4))

        # Despite the attribute names these are tan(fov / 2), not tan(fov).
        self.tanfovX = 1 / (2 * fx)
        self.tanfovY = 1 / (2 * fy)
        self.fovX = 2 * torch.atan(self.tanfovX)
        self.fovY = 2 * torch.atan(self.tanfovY)

        # Principal-point offset mapped from [0, 1] to [-1, 1].
        self.shiftX = 2 * cx - 1
        self.shiftY = 2 * cy - 1

        # --- Perspective projection (built row-major, transposed below) ---
        half_tan_y = torch.tan((self.fovY / 2))
        half_tan_x = torch.tan((self.fovX / 2))

        top = half_tan_y * self.znear
        bottom = -top
        right = half_tan_x * self.znear
        left = -right

        depth_sign = 1
        proj = torch.zeros(4, 4, device=self.fovX.device)
        proj[0, 0] = 2 * self.znear / (right - left)
        proj[1, 1] = 2 * self.znear / (top - bottom)
        proj[0, 2] = (right + left) / (right - left) + self.shiftX
        proj[1, 2] = (top + bottom) / (top - bottom) + self.shiftY
        proj[3, 2] = depth_sign
        proj[2, 2] = depth_sign * self.zfar / (self.zfar - self.znear)
        proj[2, 3] = -(self.zfar * self.znear) / (self.zfar - self.znear)

        self.world_view_transform = self.W2C.transpose(0, 1)
        self.projection_matrix = proj.transpose(0, 1)
        self.full_proj_transform = self.world_view_transform @ self.projection_matrix
        self.camera_center = self.C2W[:3, 3]
class GaussianModel:
    """Container for the per-Gaussian attributes of a Gaussian splatting scene.

    Attributes are ``None`` until populated by :meth:`set_data` or
    :meth:`load_ply`. The PLY routines treat every attribute as a 2-D array
    with one row per Gaussian (they are concatenated along ``axis=1``).
    """

    def __init__(self):
        self.xyz = None
        self.rgb = None
        self.scale = None
        self.rotation = None
        self.opacity = None
        self.sh_degree = 0

    def set_data(self, xyz: Tensor, rgb: Tensor, scale: Tensor, rotation: Tensor, opacity: Tensor):
        """Store the given tensors (no copy is made) and return ``self``."""
        self.xyz = xyz
        self.rgb = rgb
        self.scale = scale
        self.rotation = rotation
        self.opacity = opacity
        return self

    def to(self, device: torch.device = None, dtype: torch.dtype = None) -> "GaussianModel":
        """Move/cast every attribute tensor in place and return ``self``.

        Note that this rebinds the attributes of this instance; it does not
        create a new model.
        """
        self.xyz = self.xyz.to(device, dtype)
        self.rgb = self.rgb.to(device, dtype)
        self.scale = self.scale.to(device, dtype)
        self.rotation = self.rotation.to(device, dtype)
        self.opacity = self.opacity.to(device, dtype)
        return self

    @staticmethod
    def _opacity_mask(opacity: np.ndarray, opacity_threshold: float) -> np.ndarray:
        """Return a 1-D boolean row mask of opacities above the threshold."""
        # `reshape(-1)` instead of `squeeze()`: with a single Gaussian,
        # `squeeze()` yields a 0-d mask, and indexing with a 0-d boolean adds
        # an axis instead of selecting rows.
        return (opacity > opacity_threshold).reshape(-1)

    def save_ply(self, path: str, opacity_threshold: float = 0.):
        """Write the Gaussians to a PLY file at ``path``.

        Each vertex stores the float fields listed by
        :meth:`_construct_list_of_attributes` followed by 8-bit
        ``red``/``green``/``blue`` derived from ``rgb * 255`` (clipped to
        ``[0, 255]``).

        Args:
            path: Destination file. Missing parent directories are created.
            opacity_threshold: Only Gaussians whose opacity is strictly
                greater than this value are written.
        """
        # `os.path.dirname` is "" for a bare filename and `os.makedirs("")`
        # raises, so only create the directory when there is one.
        out_dir = os.path.dirname(path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        xyz = self.xyz.detach().cpu().numpy()
        f_dc = self.rgb.detach().cpu().numpy()
        rgb = (f_dc * 255.).clip(0., 255.).astype(np.uint8)
        opacity = self.opacity.detach().cpu().numpy()
        scale = self.scale.detach().cpu().numpy()
        rotation = self.rotation.detach().cpu().numpy()

        # Filter out points with low opacity
        mask = self._opacity_mask(opacity, opacity_threshold)
        xyz = xyz[mask]
        f_dc = f_dc[mask]
        opacity = opacity[mask]
        scale = scale[mask]
        rotation = rotation[mask]
        rgb = rgb[mask]

        dtype_full = [(attribute, "f4") for attribute in self._construct_list_of_attributes()]
        dtype_full.extend([("red", "u1"), ("green", "u1"), ("blue", "u1")])

        # Column order here must match the field order of `dtype_full`.
        elements = np.empty(xyz.shape[0], dtype=dtype_full)
        attributes = np.concatenate((xyz, f_dc, opacity, scale, rotation, rgb), axis=1)
        elements[:] = list(map(tuple, attributes))

        el = PlyElement.describe(elements, "vertex")
        PlyData([el]).write(path)

    def load_ply(self, path: str) -> "GaussianModel":
        """Populate the model from the first element of the PLY file at ``path``.

        Reads ``x/y/z``, ``f_dc_0..2``, ``opacity``, ``scale_0..2`` and
        ``rot_0..3``; the resulting attributes are float32 CPU tensors.
        Returns ``self`` so calls can be chained like :meth:`set_data`.
        """
        vertex = PlyData.read(path).elements[0]

        def stack_fields(names):
            # One column per named PLY property -> array of shape (N, len(names)).
            return np.stack([np.asarray(vertex[name]) for name in names], axis=1)

        xyz = stack_fields(("x", "y", "z"))
        f_dc = stack_fields(("f_dc_0", "f_dc_1", "f_dc_2"))
        opacity = np.asarray(vertex["opacity"])[..., np.newaxis]
        scale = stack_fields(("scale_0", "scale_1", "scale_2"))
        rotation = stack_fields(("rot_0", "rot_1", "rot_2", "rot_3"))

        self.xyz = torch.from_numpy(xyz).float()
        self.rgb = torch.from_numpy(f_dc).float()
        self.opacity = torch.from_numpy(opacity).float()
        self.scale = torch.from_numpy(scale).float()
        self.rotation = torch.from_numpy(rotation).float()
        return self

    def _construct_list_of_attributes(self):
        """Return the float PLY field names in the order `save_ply` writes them."""
        names = ["x", "y", "z"]
        names.extend(f"f_dc_{i}" for i in range(self.rgb.shape[1]))
        names.append("opacity")
        names.extend(f"scale_{i}" for i in range(self.scale.shape[1]))
        names.extend(f"rot_{i}" for i in range(self.rotation.shape[1]))
        return names
def render(
    pc: GaussianModel,
    height: int,
    width: int,
    C2W: Tensor,
    fxfycxcy: Tensor,
    znear: float = 0.01,
    zfar: float = 100.,
    bg_color: Union[Tensor, Tuple[float, float, float]] = (1., 1., 1.),
    scaling_modifier: float = 1.,
    render_dn: bool = False,
):
    """Rasterize a Gaussian model from a single camera.

    Args:
        pc: Model to render. NOTE: it is cast to float32 via ``pc.to``, which
            rebinds the model's attributes in place.
        height: Output image height.
        width: Output image width.
        C2W: Camera-to-world matrix passed to ``Camera``; its device is used
            for every auxiliary tensor created here.
        fxfycxcy: Camera intrinsics passed to ``Camera``.
        znear: Near clipping distance.
        zfar: Far clipping distance.
        bg_color: Background color, as a tensor or a 3-tuple of floats.
        scaling_modifier: Forwarded to the rasterizer as ``scale_modifier``.
        render_dn: Currently ignored by this function.

    Returns:
        A dict with keys ``"image"``, ``"alpha"``, ``"depth"`` and
        ``"normal"``. Only ``"image"`` is populated from the rasterizer
        output; the other three are always ``None``.
    """
    viewpoint_camera = Camera(C2W, fxfycxcy, height, width, znear, zfar)
    device = C2W.device

    if not isinstance(bg_color, Tensor):
        bg_color = torch.tensor(list(bg_color), dtype=torch.float32, device=device)
    else:
        bg_color = bg_color.to(device, dtype=torch.float32)

    pc = pc.to(dtype=torch.float32)

    image_height = int(viewpoint_camera.h)
    image_width = int(viewpoint_camera.w)

    # Allocate on the camera's device (not a hard-coded "cuda") so that this
    # tensor lives next to `bg_color` and the camera matrices, e.g. on cuda:1.
    subpixel_offset = torch.zeros(
        (image_height, image_width, 2), dtype=torch.float32, device=device,
    )

    raster_settings = GaussianRasterizationSettings(
        image_height=image_height,
        image_width=image_width,
        tanfovx=viewpoint_camera.tanfovX,
        tanfovy=viewpoint_camera.tanfovY,
        subpixel_offset=subpixel_offset,
        kernel_size=0.,  # cf. Mip-Splatting; not used
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=False,
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    # Rasterize visible Gaussians to image, obtain their radii (on screen).
    # Colors are supplied precomputed (`colors_precomp`), so no SH are passed.
    image, _ = rasterizer(  # the second output (radii) is not used
        means3D=pc.xyz,
        means2D=torch.zeros_like(pc.xyz, dtype=torch.float32, device=pc.xyz.device),
        shs=None,
        colors_precomp=pc.rgb,
        opacities=pc.opacity,
        scales=pc.scale,
        rotations=pc.rotation,
        cov3D_precomp=None,
    )

    return {
        "image": image,
        "alpha": None,
        "depth": None,
        "normal": None,
    }