Spaces:
Running
on
Zero
Running
on
Zero
from typing import * | |
from torch import Tensor | |
import torch | |
from einops import rearrange, repeat | |
from src.models.gs_render.deferred_bp import deferred_bp | |
from src.models.gs_render.gs_util import GaussianModel, render | |
from src.options import Options | |
from src.utils import unproject_depth | |
class GaussianRenderer:
    """Render batches of 3D Gaussians from predicted per-Gaussian attributes.

    The renderer reads these fields from `opt`: `scale_min`, `scale_max`,
    `render_type`, `deferred_bp_patch_size`, `znear`, `zfar`, `coord_weight`,
    `normal_weight`, `vis_coords` and `vis_normals`.
    """

    def __init__(self, opt: "Options"):
        """Store the options and build the scale activation.

        Args:
            opt: Project options object (see the class docstring for the
                fields that are read).
        """
        self.opt = opt
        # Linear blend between the two scale bounds.
        # NOTE: x = 0 yields `scale_max` and x = 1 yields `scale_min`.
        self.scale_activation = lambda x: \
            self.opt.scale_min * x + self.opt.scale_max * (1. - x)

    def render(self,
        model_outputs: Dict[str, Tensor],
        input_C2W: Tensor, input_fxfycxcy: Tensor,
        C2W: Tensor, fxfycxcy: Tensor,
        height: Optional[int] = None,
        width: Optional[int] = None,
        bg_color: Union[Tuple[float, float, float], Tensor] = (1., 1., 1.),
        scaling_modifier: float = 1.,
        opacity_threshold: float = 0.,
        input_normalized: bool = True,
        in_image_format: bool = True,
    ) -> Dict[str, Any]:
        """Render the Gaussians described by `model_outputs` from the given cameras.

        Args:
            model_outputs: Must contain "rgb", "scale", "rotation" and
                "opacity", plus exactly one of "depth" or "xyz". An optional
                "offset" is added to positions unprojected from "depth".
                With `in_image_format=True` the attributes are laid out as
                (b, v, c, h, w) and are flattened to (b, v*h*w, c).
            input_C2W: Camera-to-world matrices of the input views, used only
                to unproject "depth" (indexed as `[:, :, :3, 3]`).
            input_fxfycxcy: Intrinsics of the input views, forwarded to
                `unproject_depth`.
            C2W: Camera-to-world matrices of the views to render; the first
                two dimensions are (batch, views).
            fxfycxcy: Intrinsics of the views to render, indexed per
                (batch, view).
            height: Output height; defaults to `rgb.shape[-2]`. Required when
                `in_image_format` is False.
            width: Output width; defaults to `rgb.shape[-1]`. Required when
                `in_image_format` is False.
            bg_color: Background colour, as a 3-tuple or a tensor of 3 values.
            scaling_modifier: Forwarded unchanged to the rasterizer.
            opacity_threshold: Opacities not strictly greater than this value
                are set to zero.
            input_normalized: If True, attributes are assumed to be in
                [-1, 1] and are mapped to their valid ranges first.
            in_image_format: Whether attributes are image-shaped (see
                `model_outputs`).

        Returns:
            A dict with keys "image", "alpha", "coord", "normal",
            "raw_depth", "raw_normal" and "pc" (the list of per-batch
            Gaussian models).

        Raises:
            AssertionError: If `height`/`width`/"xyz" are missing while
                `in_image_format` is False, or if not exactly one of
                "depth"/"xyz" is provided.
        """
        if not in_image_format:
            assert height is not None and width is not None, \
                "`height` and `width` are required when `in_image_format` is False"
            # Depth must be in image format, so flattened inputs need explicit positions
            assert "xyz" in model_outputs, \
                "`xyz` is required in `model_outputs` when `in_image_format` is False"

        rgb = model_outputs["rgb"]
        scale = model_outputs["scale"]
        rotation = model_outputs["rotation"]
        opacity = model_outputs["opacity"]
        depth = model_outputs.get("depth", None)
        xyz = model_outputs.get("xyz", None)
        # Exactly one of `depth` and `xyz` must be provided
        assert (depth is None) != (xyz is None), \
            "exactly one of `depth` and `xyz` must be provided in `model_outputs`"

        # Rendering resolution could be different from input resolution
        H = height if height is not None else rgb.shape[-2]
        W = width if width is not None else rgb.shape[-1]

        # Flatten all views and pixels into one list of Gaussians per batch element
        if in_image_format:
            rgb = rearrange(rgb, "b v c h w -> b (v h w) c")
            scale = rearrange(scale, "b v c h w -> b (v h w) c")
            rotation = rearrange(rotation, "b v c h w -> b (v h w) c")
            opacity = rearrange(opacity, "b v c h w -> b (v h w) c")

        # Prepare XYZ for rendering
        if xyz is None:
            if input_normalized:
                # Shift normalized depth by the distance of each input camera to the origin
                camera_distance = torch.norm(input_C2W[:, :, :3, 3], p=2, dim=2, keepdim=True)
                depth = depth + camera_distance[..., None, None]
            xyz = unproject_depth(depth.squeeze(2), input_C2W, input_fxfycxcy)
            # Only touch `xyz` when an offset exists; avoids allocating a zero tensor per call
            offset = model_outputs.get("offset", None)
            if offset is not None:
                xyz = xyz + offset
        if in_image_format:
            xyz = rearrange(xyz, "b v c h w -> b (v h w) c")

        # From [-1, 1] to valid values; rotation is passed through unchanged
        if input_normalized:
            rgb = rgb * 0.5 + 0.5  # [-1, 1] -> [0, 1]
            scale = self.scale_activation(scale * 0.5 + 0.5)  # [-1, 1] -> [0, 1] -> scale range
            opacity = opacity * 0.5 + 0.5  # [-1, 1] -> [0, 1]

        # Filter by opacity: values at or below the threshold become zero
        opacity = (opacity > opacity_threshold) * opacity

        (B, V), device = C2W.shape[:2], C2W.device

        # Whether the rasterizer should also produce depth & normal
        render_depth_normal = (
            self.opt.coord_weight > 0. or self.opt.normal_weight > 0. or
            self.opt.vis_coords or self.opt.vis_normals
        )

        pcs = [
            GaussianModel().set_data(xyz[i], rgb[i], scale[i], rotation[i], opacity[i])
            for i in range(B)
        ]

        # The spelling "defered" is the value compared against `opt.render_type`; keep it as is
        if self.opt.render_type == "defered":
            images, alphas, depths, normals = deferred_bp(
                xyz, rgb, scale, rotation, opacity,
                H, W, C2W, fxfycxcy,
                self.opt.deferred_bp_patch_size, GaussianModel(),
                self.opt.znear, self.opt.zfar,
                bg_color,
                scaling_modifier,
                render_depth_normal,
            )
        else:  # default: rasterize every (batch, view) pair separately
            images = torch.zeros(B, V, 3, H, W, dtype=torch.float32, device=device)
            alphas = torch.zeros(B, V, 1, H, W, dtype=torch.float32, device=device)
            depths = torch.zeros(B, V, 1, H, W, dtype=torch.float32, device=device)
            normals = torch.zeros(B, V, 3, H, W, dtype=torch.float32, device=device)
            for i, pc in enumerate(pcs):
                for j in range(V):
                    render_results = render(
                        pc, H, W, C2W[i, j], fxfycxcy[i, j],
                        self.opt.znear, self.opt.zfar,
                        bg_color,
                        scaling_modifier,
                        render_depth_normal,
                    )
                    images[i, j] = render_results["image"]
                    # NOTE: alpha / depth / normal are not copied from `render_results` in this
                    # path, so they stay zero and `coord` / `normal` below reduce to the background.
                    # TODO confirm whether `render` provides "alpha" / "depth" / "normal" here.

        # Broadcast the background colour to the full output shape for compositing
        if not isinstance(bg_color, Tensor):
            bg_color = torch.tensor(list(bg_color), dtype=torch.float32, device=device)
        bg_color = repeat(bg_color, "c -> b v c h w", b=B, v=V, h=H, w=W)

        # Unprojected positions mapped with `* 0.5 + 0.5`, composited over the background
        coords = (unproject_depth(depths.squeeze(2), C2W, fxfycxcy)
            * 0.5 + 0.5) * alphas + (1. - alphas) * bg_color
        # Normals multiplied by the 3x3 block of `C2W`, mapped the same way and composited
        normals_ = (torch.einsum("bvrc,bvchw->bvrhw", C2W[:, :, :3, :3], normals)
            * 0.5 + 0.5) * alphas + (1. - alphas) * bg_color

        return {
            "image": images,
            "alpha": alphas,
            "coord": coords,
            "normal": normals_,
            "raw_depth": depths,
            "raw_normal": normals,
            "pc": pcs,
        }