from typing import Optional, Tuple
from torch import Tensor

import torch
import torch.nn.functional as tF


def normalize_normals(normals: Tensor, C2W: Tensor, i: int = 0) -> Tensor:
    """Normalize a batch of multi-view `normals` by the `i`-th view.

    Inputs:
        - `normals`: (B, V, 3, H, W)
        - `C2W`: (B, V, 4, 4)
        - `i`: the index of the view to normalize by

    Outputs:
        - `normalized_normals`: (B, V, 3, H, W)
    """
    _, _, R, C = C2W.shape  # (B, V, 4, 4)
    assert R == C == 4
    _, _, CC, _, _ = normals.shape  # (B, V, 3, H, W)
    assert CC == 3

    dtype = normals.dtype
    normals = normals.clone().float()
    transform = torch.inverse(C2W[:, i, :3, :3])  # (B, 3, 3); rotation mapping world to view i's camera frame

    return torch.einsum("brc,bvchw->bvrhw", transform, normals).to(dtype)  # (B, V, 3, H, W)
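
# Usage sketch for `normalize_normals` (an added example; the shapes below are
# illustrative assumptions, not requirements of the function):
#
#   normals = torch.randn(2, 4, 3, 32, 32)           # (B=2, V=4, 3, H=32, W=32)
#   C2W = torch.eye(4).repeat(2, 4, 1, 1)            # (B, V, 4, 4), identity poses
#   normals0 = normalize_normals(normals, C2W, i=0)  # rotated into view 0's frame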


def normalize_C2W(C2W: Tensor, i: int = 0, norm_radius: float = 0.) -> Tensor:
    """Normalize a batch of multi-view `C2W` by the `i`-th view.

    Inputs:
        - `C2W`: (B, V, 4, 4)
        - `i`: the index of the view to normalize by
        - `norm_radius`: the normalization radius; if 0, radius rescaling is skipped

    Outputs:
        - `normalized_C2W`: (B, V, 4, 4)
    """
    _, _, R, C = C2W.shape  # (B, V, 4, 4)
    assert R == C == 4

    device, dtype = C2W.device, C2W.dtype
    C2W = C2W.clone().float()

    if abs(norm_radius) > 0.:
        radius = torch.norm(C2W[:, i, :3, 3], dim=1)  # (B,)
        C2W[:, :, :3, 3] *= (norm_radius / radius.unsqueeze(1).unsqueeze(2))

    # Map the `i`-th view to a canonical pose and use it as the reference view
    transform = torch.tensor([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, norm_radius],
        [0, 0, 0, 1]  # canonical c2w in OpenGL world convention
    ], dtype=torch.float32, device=device) @ torch.inverse(C2W[:, i, ...])  # (B, 4, 4)

    return (transform.unsqueeze(1) @ C2W).to(dtype)  # (B, V, 4, 4)
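
# Usage sketch (added; values are illustrative). Using `orbit_camera` (defined
# below) to build a pose, normalization maps view 0 to the canonical OpenGL
# pose at [0, 0, norm_radius] looking at the origin:
#
#   C2W = orbit_camera(torch.tensor([30.]), torch.tensor([45.]),
#                      torch.tensor([2.]))[None]  # (B=1, V=1, 4, 4)
#   C2W_norm = normalize_C2W(C2W, i=0, norm_radius=2.)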


def unproject_depth(depth_map: Tensor, C2W: Tensor, fxfycxcy: Tensor) -> Tensor:
    """Unproject depth map to 3D world coordinate.

    Inputs:
        - `depth_map`: (B, V, H, W)
        - `C2W`: (B, V, 4, 4)
        - `fxfycxcy`: (B, V, 4)

    Outputs:
        - `xyz_world`: (B, V, 3, H, W)
    """
    device, dtype = depth_map.device, depth_map.dtype
    B, V, H, W = depth_map.shape

    depth_map = depth_map.reshape(B*V, H, W).float()
    C2W = C2W.reshape(B*V, 4, 4).float()
    fxfycxcy = fxfycxcy.reshape(B*V, 4).float()
    K = torch.zeros(B*V, 3, 3, dtype=torch.float32, device=device)
    K[:, 0, 0] = fxfycxcy[:, 0]
    K[:, 1, 1] = fxfycxcy[:, 1]
    K[:, 0, 2] = fxfycxcy[:, 2]
    K[:, 1, 2] = fxfycxcy[:, 3]
    K[:, 2, 2] = 1

    y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")  # OpenCV/COLMAP camera convention
    y = y.to(device).unsqueeze(0).repeat(B*V, 1, 1) / (H-1)
    x = x.to(device).unsqueeze(0).repeat(B*V, 1, 1) / (W-1)
    # NOTE: To align with `plucker_ray(bug=False)`, should be:
    # y = (y.to(device).unsqueeze(0).repeat(B*V, 1, 1) + 0.5) / H
    # x = (x.to(device).unsqueeze(0).repeat(B*V, 1, 1) + 0.5) / W
    xyz_map = torch.stack([x, y, torch.ones_like(x)], dim=-1) * depth_map[..., None]
    xyz = xyz_map.view(B*V, -1, 3)

    # Get point positions in camera coordinate
    xyz = torch.matmul(xyz, torch.transpose(torch.inverse(K), 1, 2))
    xyz_map = xyz.view(B*V, H, W, 3)

    # Transform pts from camera to world coordinate
    xyz_homo = torch.ones((B*V, H, W, 4), device=device)
    xyz_homo[..., :3] = xyz_map
    xyz_world = torch.bmm(C2W, xyz_homo.reshape(B*V, -1, 4).permute(0, 2, 1))[:, :3, ...].to(dtype)  # (B*V, 3, H*W)
    xyz_world = xyz_world.reshape(B, V, 3, H, W)
    return xyz_world
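
# Usage sketch (added; shapes and intrinsics are illustrative assumptions).
# `fxfycxcy` is expressed in normalized image coordinates, consistent with the
# pixel-grid normalization above:
#
#   depth = torch.full((1, 2, 16, 16), 2.)                     # (B, V, H, W)
#   C2W = torch.eye(4).repeat(1, 2, 1, 1)                      # (B, V, 4, 4)
#   fxfycxcy = torch.tensor([1., 1., .5, .5]).repeat(1, 2, 1)  # (B, V, 4)
#   xyz_world = unproject_depth(depth, C2W, fxfycxcy)          # (B, V, 3, H, W)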


def plucker_ray(h: int, w: int, C2W: Tensor, fxfycxcy: Tensor, bug: bool = True) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
    """Get Plucker ray embeddings.

    Inputs:
        - `h`: image height
        - `w`: image width
        - `C2W`: (B, V, 4, 4)
        - `fxfycxcy`: (B, V, 4)
        - `bug`: if True, reproduce the original off-by-half pixel-center bug (kept for backward compatibility)

    Outputs:
        - `plucker`: (B, V, 6, `h`, `w`)
        - `ray_o`: (B, V, 3, `h`, `w`)
        - `ray_d`: (B, V, 3, `h`, `w`)
    """
    device, dtype = C2W.device, C2W.dtype
    B, V = C2W.shape[:2]

    C2W = C2W.reshape(B*V, 4, 4).float()
    fxfycxcy = fxfycxcy.reshape(B*V, 4).float()

    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")  # OpenCV/COLMAP camera convention
    y, x = y.to(device), x.to(device)
    if bug:  # known off-by-half pixel-center bug, kept for backward compatibility; same issue as: https://github.com/camenduru/GRM/blob/master/model/visual_encoder/vit_gs.py#L85
        y = y[None, :, :].expand(B*V, -1, -1).reshape(B*V, -1) / (h - 1)
        x = x[None, :, :].expand(B*V, -1, -1).reshape(B*V, -1) / (w - 1)
        x = (x + 0.5 - fxfycxcy[:, 2:3]) / fxfycxcy[:, 0:1]
        y = (y + 0.5 - fxfycxcy[:, 3:4]) / fxfycxcy[:, 1:2]
    else:
        y = (y[None, :, :].expand(B*V, -1, -1).reshape(B*V, -1) + 0.5) / h
        x = (x[None, :, :].expand(B*V, -1, -1).reshape(B*V, -1) + 0.5) / w
        x = (x - fxfycxcy[:, 2:3]) / fxfycxcy[:, 0:1]
        y = (y - fxfycxcy[:, 3:4]) / fxfycxcy[:, 1:2]
    z = torch.ones_like(x)
    ray_d = torch.stack([x, y, z], dim=2)  # (B*V, h*w, 3)
    ray_d = torch.bmm(ray_d, C2W[:, :3, :3].transpose(1, 2))  # (B*V, h*w, 3)
    ray_d = ray_d / torch.norm(ray_d, dim=2, keepdim=True)  # (B*V, h*w, 3)
    ray_o = C2W[:, :3, 3][:, None, :].expand_as(ray_d)  # (B*V, h*w, 3)

    ray_o = ray_o.reshape(B, V, h, w, 3).permute(0, 1, 4, 2, 3)
    ray_d = ray_d.reshape(B, V, h, w, 3).permute(0, 1, 4, 2, 3)
    plucker = torch.cat([torch.cross(ray_o, ray_d, dim=2).to(dtype), ray_d.to(dtype)], dim=2)

    return plucker, (ray_o, ray_d)
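
# Usage sketch (added; shapes are illustrative assumptions), using the corrected
# pixel-center convention (`bug=False`):
#
#   C2W = torch.eye(4).repeat(1, 2, 1, 1)                      # (B, V, 4, 4)
#   fxfycxcy = torch.tensor([1., 1., .5, .5]).repeat(1, 2, 1)  # (B, V, 4)
#   plucker, (ray_o, ray_d) = plucker_ray(32, 32, C2W, fxfycxcy, bug=False)
#   # plucker: (1, 2, 6, 32, 32); ray_o/ray_d: (1, 2, 3, 32, 32)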


def orbit_camera(
    elevs: Tensor, azims: Tensor, radius: Optional[Tensor] = None,
    is_degree: bool = True,
    target: Optional[Tensor] = None,
    opengl: bool = True,
) -> Tensor:
    """Construct batched camera pose matrices orbiting a target at given elevation & azimuth angles.

    Inputs:
        - `elevs`: (B,); elevation in (-90, 90); -90 points toward +y, +90 toward -y
        - `azims`: (B,); azimuth in (-180, 180); 0 points toward +z, +90 toward +x
        - `radius`: (B,); camera orbit radius; if None, defaults to 1
        - `is_degree`: bool; whether the input angles are in degrees
        - `target`: (B, 3); look-at target position; if None, defaults to the origin
        - `opengl`: bool; whether to use the OpenGL camera convention

    Outputs:
        - `C2W`: (B, 4, 4); camera pose matrices
    """
    device, dtype = elevs.device, elevs.dtype

    if radius is None:
        radius = torch.ones_like(elevs)
    assert elevs.shape == azims.shape == radius.shape
    if target is None:
        target = torch.zeros(elevs.shape[0], 3, device=device, dtype=dtype)

    if is_degree:
        elevs = torch.deg2rad(elevs)
        azims = torch.deg2rad(azims)

    x = radius * torch.cos(elevs) * torch.sin(azims)
    y = - radius * torch.sin(elevs)
    z = radius * torch.cos(elevs) * torch.cos(azims)

    camposes = torch.stack([x, y, z], dim=1) + target  # (B, 3)
    R = look_at(camposes, target, opengl=opengl)  # (B, 3, 3)
    C2W = torch.cat([R, camposes[:, :, None]], dim=2)  # (B, 3, 4)
    C2W = torch.cat([C2W, torch.zeros_like(C2W[:, :1, :])], dim=1)  # (B, 4, 4)
    C2W[:, 3, 3] = 1.
    return C2W
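
# Usage sketch (added; angles are illustrative): a ring of 8 poses at 30 degrees
# elevation and radius 2, orbiting the origin:
#
#   azims = torch.linspace(-180., 180., 8 + 1)[:-1]               # (8,)
#   elevs = torch.full_like(azims, 30.)
#   C2W = orbit_camera(elevs, azims, torch.full_like(azims, 2.))  # (8, 4, 4)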


def look_at(camposes: Tensor, targets: Tensor, opengl: bool = True) -> Tensor:
    """Construct batched pose rotation matrices by look-at.

    Inputs:
        - `camposes`: (B, 3); camera positions
        - `targets`: (B, 3); look-at targets
        - `opengl`: whether to use OpenGL convention

    Outputs:
        - `R`: (B, 3, 3); normalized camera pose rotation matrices
    """
    device, dtype = camposes.device, camposes.dtype

    if not opengl:  # OpenCV convention
        # forward is camera -> target
        forward_vectors = tF.normalize(targets - camposes, dim=-1)
        up_vectors = torch.tensor([0., 1., 0.], device=device, dtype=dtype)[None, :].expand_as(forward_vectors)
        right_vectors = tF.normalize(torch.cross(forward_vectors, up_vectors, dim=-1), dim=-1)
        up_vectors = tF.normalize(torch.cross(right_vectors, forward_vectors, dim=-1), dim=-1)
    else:
        # forward is target -> camera
        forward_vectors = tF.normalize(camposes - targets, dim=-1)
        up_vectors = torch.tensor([0., 1., 0.], device=device, dtype=dtype)[None, :].expand_as(forward_vectors)
        right_vectors = tF.normalize(torch.cross(up_vectors, forward_vectors, dim=-1), dim=-1)
        up_vectors = tF.normalize(torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1)

    R = torch.stack([right_vectors, up_vectors, forward_vectors], dim=-1)
    return R
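

if __name__ == "__main__":
    # Minimal sanity checks with illustrative values. This block is an added
    # sketch, not part of the original module; shapes, angles, and tolerances
    # are assumptions chosen only to exercise the functions above.
    B = 4
    elevs = torch.linspace(-60., 60., B)
    azims = torch.linspace(-120., 120., B)
    radius = torch.full((B,), 2.)

    # Orbit poses should sit at the requested distance from the origin.
    C2W = orbit_camera(elevs, azims, radius)  # (B, 4, 4)
    assert torch.allclose(C2W[:, :3, 3].norm(dim=1), radius, atol=1e-5)

    # Look-at rotations should be orthonormal.
    R = look_at(C2W[:, :3, 3], torch.zeros(B, 3))
    assert torch.allclose(R @ R.transpose(1, 2), torch.eye(3).expand(B, -1, -1), atol=1e-5)

    # Normalizing by view 0 should map view 0 to the canonical OpenGL pose.
    C2W_mv = C2W[None]  # (1, B, 4, 4), treating the batch as V views
    C2W_norm = normalize_C2W(C2W_mv, i=0, norm_radius=2.)
    canonical = torch.tensor([
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 2.],
        [0., 0., 0., 1.],
    ])
    assert torch.allclose(C2W_norm[0, 0], canonical, atol=1e-5)

    # Plucker embeddings should have 6 channels per pixel.
    fxfycxcy = torch.tensor([1., 1., .5, .5]).expand(1, B, 4)
    plucker, (ray_o, ray_d) = plucker_ray(16, 16, C2W_mv, fxfycxcy, bug=False)
    assert plucker.shape == (1, B, 6, 16, 16)

    print("all sanity checks passed")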