# Diffsplat / src/models/gs_render/gs_renderer.py
from typing import Dict, Optional, Tuple

import torch
from torch import Tensor
from einops import rearrange, repeat

from src.models.gs_render.deferred_bp import deferred_bp
from src.models.gs_render.gs_util import GaussianModel, render
from src.options import Options
from src.utils import unproject_depth


class GaussianRenderer:
    def __init__(self, opt: Options):
        self.opt = opt
        # Interpolate linearly from `scale_max` (x = 0.) down to `scale_min` (x = 1.)
        self.scale_activation = lambda x: \
            self.opt.scale_min * x + self.opt.scale_max * (1. - x)  # maps [0, 1] into [s_min, s_max]
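        # For example (illustrative values, not necessarily the repo defaults):
        # with opt.scale_min = 0.005 and opt.scale_max = 0.02, a normalized
        # x = 0. yields a Gaussian scale of 0.02, while x = 1. yields 0.005.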

    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)  # force fp32 even under autocast
    def render(self,
        model_outputs: Dict[str, Tensor],
        input_C2W: Tensor, input_fxfycxcy: Tensor,
        C2W: Tensor, fxfycxcy: Tensor,
        height: Optional[int] = None,
        width: Optional[int] = None,
        bg_color: Tuple[float, float, float] = (1., 1., 1.),
        scaling_modifier: float = 1.,
        opacity_threshold: float = 0.,
        input_normalized: bool = True,
        in_image_format: bool = True,
    ):
        if not in_image_format:
            assert height is not None and width is not None
            assert "xyz" in model_outputs  # `depth` can only be unprojected when in image format

        rgb, scale, rotation, opacity = \
            model_outputs["rgb"], model_outputs["scale"], model_outputs["rotation"], model_outputs["opacity"]
        depth = model_outputs.get("depth", None)
        xyz = model_outputs.get("xyz", None)
        # Exactly one of `depth` and `xyz` must be provided
        assert (depth is not None) != (xyz is not None)

        # Rendering resolution may differ from the input resolution
        H = height if height is not None else rgb.shape[-2]
        W = width if width is not None else rgb.shape[-1]

        # Reshape per-pixel Gaussian attributes into flat point lists for rendering
        if in_image_format:
            rgb = rearrange(rgb, "b v c h w -> b (v h w) c")
            scale = rearrange(scale, "b v c h w -> b (v h w) c")
            rotation = rearrange(rotation, "b v c h w -> b (v h w) c")
            opacity = rearrange(opacity, "b v c h w -> b (v h w) c")

        # Prepare Gaussian centers: take `xyz` directly, or unproject `depth`
        if xyz is None:
            if input_normalized:
                # Normalized depth is predicted relative to the camera-to-origin
                # distance, so shift it by ||camera position|| before unprojection
                depth = depth + torch.norm(input_C2W[:, :, :3, 3], p=2, dim=2, keepdim=True)[..., None, None]  # [-1, 1] -> distance to origin + [-1, 1]
            xyz = unproject_depth(depth.squeeze(2), input_C2W, input_fxfycxcy)  # [-1, 1]
        xyz = xyz + model_outputs.get("offset", torch.zeros_like(xyz))
        if in_image_format:
            xyz = rearrange(xyz, "b v c h w -> b (v h w) c")

        # Map normalized predictions in [-1, 1] to valid value ranges
        if input_normalized:
            rgb = rgb * 0.5 + 0.5  # [-1, 1] -> [0, 1]
            scale = self.scale_activation(scale * 0.5 + 0.5)  # [-1, 1] -> [0, 1] -> [s_min, s_max]
            # `rotation` is left unchanged; quaternions are already L2-normalized
            opacity = opacity * 0.5 + 0.5  # [-1, 1] -> [0, 1]

        # Zero out Gaussians whose opacity is at or below the threshold
        opacity = (opacity > opacity_threshold) * opacity

        (B, V), device = C2W.shape[:2], C2W.device  # target views/resolution might differ from the input's
        images = torch.zeros(B, V, 3, H, W, dtype=torch.float32, device=device)
        alphas = torch.zeros(B, V, 1, H, W, dtype=torch.float32, device=device)
        depths = torch.zeros(B, V, 1, H, W, dtype=torch.float32, device=device)
        normals = torch.zeros(B, V, 3, H, W, dtype=torch.float32, device=device)

        pcs = [
            GaussianModel().set_data(xyz[i], rgb[i], scale[i], rotation[i], opacity[i])
            for i in range(B)
        ]

        if self.opt.render_type == "defered":  # NOTE: "defered" (sic) matches the `Options.render_type` value
            # Patch-wise rendering with deferred backprop (see `deferred_bp`)
            images, alphas, depths, normals = deferred_bp(
                xyz, rgb, scale, rotation, opacity,
                H, W, C2W, fxfycxcy,
                self.opt.deferred_bp_patch_size, GaussianModel(),
                self.opt.znear, self.opt.zfar,
                bg_color,
                scaling_modifier,
                self.opt.coord_weight > 0. or self.opt.normal_weight > 0. or \
                    self.opt.vis_coords or self.opt.vis_normals,  # whether to render depth & normal
            )
        else:  # default: per-view rendering with the standard rasterizer
            for i in range(B):
                pc = pcs[i]
                for j in range(V):
                    render_results = render(
                        pc, H, W, C2W[i, j], fxfycxcy[i, j],
                        self.opt.znear, self.opt.zfar,
                        bg_color,
                        scaling_modifier,
                        self.opt.coord_weight > 0. or self.opt.normal_weight > 0. or \
                            self.opt.vis_coords or self.opt.vis_normals,  # whether to render depth & normal
                    )
                    images[i, j] = render_results["image"]
                    # Fill the auxiliary buffers when the rasterizer returns them; otherwise
                    # they stay zero and `coord`/`normal` below reduce to `bg_color`
                    for key, buf in (("alpha", alphas), ("depth", depths), ("normal", normals)):
                        if key in render_results:
                            buf[i, j] = render_results[key]

        # Composite visualizable coordinates and world-space normals over the background
        if not isinstance(bg_color, Tensor):
            bg_color = torch.tensor(list(bg_color), dtype=torch.float32, device=device)
        bg_color = repeat(bg_color, "c -> b v c h w", b=B, v=V, h=H, w=W)
        coords = (unproject_depth(depths.squeeze(2), C2W, fxfycxcy) * 0.5 + 0.5) \
            * alphas + (1. - alphas) * bg_color  # world coordinates mapped from [-1, 1] to [0, 1]
        normals_ = (torch.einsum("bvrc,bvchw->bvrhw", C2W[:, :, :3, :3], normals) * 0.5 + 0.5) \
            * alphas + (1. - alphas) * bg_color  # camera- to world-space, then [0, 1]

        return {
            "image": images,        # (B, V, 3, H, W), in [0, 1]
            "alpha": alphas,        # (B, V, 1, H, W)
            "coord": coords,        # (B, V, 3, H, W), composited over `bg_color`
            "normal": normals_,     # (B, V, 3, H, W), composited over `bg_color`
            "raw_depth": depths,    # (B, V, 1, H, W)
            "raw_normal": normals,  # (B, V, 3, H, W), camera space
            "pc": pcs,              # list of B `GaussianModel`s
        }
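

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes `Options` is default-constructible with the fields read above
# (scale_min, scale_max, render_type, znear, zfar, deferred_bp_patch_size,
# coord_weight, normal_weight, vis_coords, vis_normals), and that a CUDA
# device plus the Gaussian rasterizer backend are available. Shapes follow
# the rearrange patterns in `render`: B batches, V views, per-pixel
# Gaussians on an HxW grid.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, V, H, W = 1, 2, 128, 128
    device = torch.device("cuda")

    renderer = GaussianRenderer(Options())  # assumes Options() works with defaults

    model_outputs = {
        "rgb":      torch.rand(B, V, 3, H, W, device=device) * 2. - 1.,  # normalized to [-1, 1]
        "depth":    torch.rand(B, V, 1, H, W, device=device) * 2. - 1.,  # normalized to [-1, 1]
        "scale":    torch.rand(B, V, 3, H, W, device=device) * 2. - 1.,  # normalized to [-1, 1]
        "rotation": torch.nn.functional.normalize(
            torch.randn(B, V, 4, H, W, device=device), dim=2),           # unit quaternions
        "opacity":  torch.rand(B, V, 1, H, W, device=device) * 2. - 1.,  # normalized to [-1, 1]
    }

    # Hypothetical cameras: identity rotation, placed 2 units from the origin,
    # with normalized intrinsics (fx, fy, cx, cy)
    C2W = repeat(torch.eye(4, device=device), "r c -> b v r c", b=B, v=V).clone()
    C2W[:, :, 2, 3] = 2.
    fxfycxcy = repeat(torch.tensor([1., 1., .5, .5], device=device), "c -> b v c", b=B, v=V)

    # Render the same views that produced the Gaussians
    out = renderer.render(model_outputs, C2W, fxfycxcy, C2W, fxfycxcy)
    print({k: tuple(v.shape) for k, v in out.items() if isinstance(v, Tensor)})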