from typing import Dict, Optional, Tuple
from torch import Tensor

import torch
from einops import rearrange, repeat

from src.models.gs_render.deferred_bp import deferred_bp
from src.models.gs_render.gs_util import GaussianModel, render
from src.options import Options
from src.utils import unproject_depth


class GaussianRenderer:
    def __init__(self, opt: Options):
        self.opt = opt

        self.scale_activation = lambda x: \
            self.opt.scale_min * x + self.opt.scale_max * (1. - x)  # decreasing linear map: [0, 1] -> [s_max, s_min]
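        # Worked example with hypothetical bounds scale_min=0.005, scale_max=0.02:
        # x=0 -> 0.02 (largest), x=0.5 -> 0.0125, x=1 -> 0.005 (smallest)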

    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def render(self,
        model_outputs: Dict[str, Tensor],
        input_C2W: Tensor, input_fxfycxcy: Tensor,
        C2W: Tensor, fxfycxcy: Tensor,
        height: Optional[int] = None,
        width: Optional[int] = None,
        bg_color: Tuple[float, float, float] = (1., 1., 1.),
        scaling_modifier: float = 1.,
        opacity_threshold: float = 0.,
        input_normalized: bool = True,
        in_image_format: bool = True,
    ):
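        """Render the model's predicted Gaussians into the target views.

        Shapes, inferred from usage in this file:
          - `model_outputs`: "rgb"/"scale"/"rotation"/"opacity" (plus "depth" or "xyz",
            and optionally "offset") as `[B, V, C, H, W]` maps when `in_image_format`,
            else as flat `[B, N, C]` per-point attributes.
          - `input_C2W` / `input_fxfycxcy`: `[B, V_in, 4, 4]` camera-to-world matrices
            and `[B, V_in, 4]` per-view (fx, fy, cx, cy) intrinsics of the input views,
            used only to unproject predicted depth.
          - `C2W` / `fxfycxcy`: same layout for the views to render.

        `height`/`width` override the rendering resolution; `input_normalized` means the
        model outputs live in [-1, 1] and are first mapped to their valid ranges.
        """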
        if not in_image_format:
            assert height is not None and width is not None
            assert "xyz" in model_outputs  # depth must be in image format

        rgb, scale, rotation, opacity = (
            model_outputs["rgb"], model_outputs["scale"],
            model_outputs["rotation"], model_outputs["opacity"],
        )
        depth = model_outputs.get("depth", None)
        xyz = model_outputs.get("xyz", None)
        # Exactly one of `depth` and `xyz` must be provided
        assert (depth is None) != (xyz is None)

        # Rendering resolution could be different from input resolution
        H = height if height is not None else rgb.shape[-2]
        W = width if width is not None else rgb.shape[-1]

        # Reshape for rendering
        if in_image_format:
            rgb = rearrange(rgb, "b v c h w -> b (v h w) c")
            scale = rearrange(scale, "b v c h w -> b (v h w) c")
            rotation = rearrange(rotation, "b v c h w -> b (v h w) c")
            opacity = rearrange(opacity, "b v c h w -> b (v h w) c")

        # Prepare XYZ for rendering
        if xyz is None:
            if input_normalized:
                # Shift normalized depth by each camera's distance to the origin:
                # [-1, 1] -> (distance to the image plane) + [-1, 1]
                depth = depth + torch.norm(input_C2W[:, :, :3, 3], p=2, dim=2, keepdim=True)[..., None, None]
            xyz = unproject_depth(depth.squeeze(2), input_C2W, input_fxfycxcy)  # world-space XYZ in [-1, 1]
        xyz = xyz + model_outputs.get("offset", torch.zeros_like(xyz))
        if in_image_format:
            xyz = rearrange(xyz, "b v c h w -> b (v h w) c")

        # From [-1, 1] to valid values
        if input_normalized:
            rgb = rgb * 0.5 + 0.5  # [-1, 1] -> [0, 1]
            scale = self.scale_activation(scale * 0.5 + 0.5)  # [-1, 1] -> [0, 1] -> [s_min, s_max]
            # rotation is left unchanged: it is already L2-normalized
            opacity = opacity * 0.5 + 0.5  # [-1, 1] -> [0, 1]

        # Filter by opacity: zero out Gaussians at or below `opacity_threshold`
        # (multiplying by the boolean mask keeps gradients for the survivors)
        opacity = (opacity > opacity_threshold) * opacity

        (B, V), device = C2W.shape[:2], C2W.device  # rendering `H`/`W` may differ from the input resolution
        images = torch.zeros(B, V, 3, H, W, dtype=torch.float32, device=device)
        alphas = torch.zeros(B, V, 1, H, W, dtype=torch.float32, device=device)
        depths = torch.zeros(B, V, 1, H, W, dtype=torch.float32, device=device)
        normals = torch.zeros(B, V, 3, H, W, dtype=torch.float32, device=device)

        # One `GaussianModel` point cloud per batch element
        pcs = [
            GaussianModel().set_data(xyz[i], rgb[i], scale[i], rotation[i], opacity[i])
            for i in range(B)
        ]

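        # Deferred rendering: presumably (an assumption from the name and the
        # `deferred_bp_patch_size` option) the image is rendered and backpropagated
        # patch by patch instead of holding the full-resolution autodiff graph at once.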
        if self.opt.render_type == "defered":
            images, alphas, depths, normals = deferred_bp(
                xyz, rgb, scale, rotation, opacity,
                H, W, C2W, fxfycxcy,
                self.opt.deferred_bp_patch_size, GaussianModel(),
                self.opt.znear, self.opt.zfar,
                bg_color,
                scaling_modifier,
                self.opt.coord_weight > 0. or self.opt.normal_weight > 0. or \
                    self.opt.vis_coords or self.opt.vis_normals,  # whether to render depth & normal
            )
        else:  # default
            for i in range(B):
                pc = pcs[i]
                for j in range(V):
                    render_results = render(
                        pc, H, W, C2W[i, j], fxfycxcy[i, j],
                        self.opt.znear, self.opt.zfar,
                        bg_color,
                        scaling_modifier,
                        self.opt.coord_weight > 0. or self.opt.normal_weight > 0. or \
                            self.opt.vis_coords or self.opt.vis_normals,  # whether to render depth & normal
                    )
                    images[i, j] = render_results["image"]
                    # The composites below read `alphas`/`depths`/`normals`, so fill
                    # them whenever the rasterizer returns them
                    for key, buf in (("alpha", alphas), ("depth", depths), ("normal", normals)):
                        if key in render_results:
                            buf[i, j] = render_results[key]

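        # Broadcast the background color to the full [B, V, 3, H, W] image shape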
        if not isinstance(bg_color, Tensor):
            bg_color = torch.tensor(list(bg_color), dtype=torch.float32, device=device)
        bg_color = repeat(bg_color, "c -> b v c h w", b=B, v=V, h=H, w=W)

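        # Composite visualizations over the background, mapping [-1, 1] -> [0, 1]:
        # `coords` unprojects the rendered depth back to world-space XYZ, and
        # `normals_` rotates the rendered normals from camera to world space via
        # the rotation block of C2W.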
        coords = (unproject_depth(depths.squeeze(2), C2W, fxfycxcy)
            * 0.5 + 0.5) * alphas + (1. - alphas) * bg_color
        normals_ = (torch.einsum("bvrc,bvchw->bvrhw", C2W[:, :, :3, :3], normals)
            * 0.5 + 0.5) * alphas + (1. - alphas) * bg_color

        return {
            "image": images,
            "alpha": alphas,
            "coord": coords,
            "normal": normals_,
            "raw_depth": depths,
            "raw_normal": normals,
            "pc": pcs,
        }
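

# A minimal sketch of what `src.utils.unproject_depth` plausibly computes, inferred
# from how it is called above (depth [B, V, H, W], C2W [B, V, 4, 4], fxfycxcy
# [B, V, 4], returning [B, V, 3, H, W]). The pixel-space intrinsics convention is an
# assumption; the real helper may normalize (fx, fy, cx, cy) by image size instead.
def _unproject_depth_sketch(depth: Tensor, C2W: Tensor, fxfycxcy: Tensor) -> Tensor:
    B, V, H, W = depth.shape
    # Pixel-center grid shared across batch and views
    y, x = torch.meshgrid(
        torch.arange(H, dtype=depth.dtype, device=depth.device) + 0.5,
        torch.arange(W, dtype=depth.dtype, device=depth.device) + 0.5,
        indexing="ij",
    )
    fx, fy, cx, cy = fxfycxcy.unbind(dim=-1)  # each [B, V]
    # Back-project each pixel through the pinhole model into camera space
    x_cam = (x - cx[..., None, None]) / fx[..., None, None] * depth
    y_cam = (y - cy[..., None, None]) / fy[..., None, None] * depth
    pts = torch.stack([x_cam, y_cam, depth, torch.ones_like(depth)], dim=-1)  # [B, V, H, W, 4]
    # Camera -> world with the homogeneous camera-to-world transform
    pts = torch.einsum("bvij,bvhwj->bvhwi", C2W, pts)[..., :3]
    return rearrange(pts, "b v h w c -> b v c h w")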