Spaces:
Running
on
Zero
Running
on
Zero
from typing import * | |
from torch import Tensor | |
import torch | |
import torch.nn.functional as tF | |
def normalize_normals(normals: Tensor, C2W: Tensor, i: int = 0) -> Tensor:
    """Normalize a batch of multi-view `normals` by the `i`-th view.

    Every normal map (of every view) is multiplied by the inverse of the upper-left 3x3
    block of the `i`-th view's `C2W`. If `normals` are world-space, this re-expresses
    them in the `i`-th camera's frame.

    Inputs:
        - `normals`: (B, V, 3, H, W)
        - `C2W`: (B, V, 4, 4)
        - `i`: the index of the view to normalize by
    Outputs:
        - `normalized_normals`: (B, V, 3, H, W), same dtype as `normals`
    """
    _, _, R, C = C2W.shape  # (B, V, 4, 4)
    assert R == C == 4
    _, _, CC, _, _ = normals.shape  # (B, V, 3, H, W)
    assert CC == 3

    dtype = normals.dtype
    # Compute in float32 for BOTH operands: `torch.inverse` rejects half precision, and
    # `einsum` requires matching dtypes (e.g. float64 poses with float32 normals would fail).
    normals = normals.float()
    transform = torch.inverse(C2W[:, i, :3, :3].float())  # (B, 3, 3)
    # Apply the same (B, 3, 3) transform to every view and pixel; output is a new tensor.
    return torch.einsum("brc,bvchw->bvrhw", transform, normals).to(dtype)  # (B, V, 3, H, W)
def normalize_C2W(C2W: Tensor, i: int = 0, norm_radius: float = 0.) -> Tensor:
    """Re-express a batch of multi-view camera poses relative to the `i`-th view.

    When `norm_radius` is non-zero, every camera translation is first scaled so that the
    `i`-th camera lies at distance `norm_radius` from the world origin. (A reference camera
    exactly at the origin makes this scale non-finite.) Then the 4x4 transform that maps the
    `i`-th pose onto a canonical pose is computed and left-multiplied onto every view. The
    canonical pose has identity rotation and translation (0, 0, `norm_radius`).

    Inputs:
        - `C2W`: (B, V, 4, 4)
        - `i`: the index of the view to normalize by
        - `norm_radius`: the normalization radius (0 disables rescaling)
    Outputs:
        - `normalized_C2W`: (B, V, 4, 4), same dtype as the input; the input is not modified
    """
    _, _, rows, cols = C2W.shape  # (B, V, 4, 4)
    assert rows == cols == 4
    orig_device, orig_dtype = C2W.device, C2W.dtype
    poses = C2W.clone().float()  # work on a float32 copy so the caller's tensor is untouched

    if abs(norm_radius) > 0.:
        ref_dist = poses[:, i, :3, 3].norm(dim=1)  # (B,)
        poses[:, :, :3, 3] *= norm_radius / ref_dist[:, None, None]

    # Canonical c2w in OpenGL world convention for the reference view
    canonical = torch.eye(4, dtype=torch.float32, device=orig_device)
    canonical[2, 3] = norm_radius
    to_canonical = canonical @ torch.inverse(poses[:, i])  # (B, 4, 4)
    return torch.matmul(to_canonical[:, None], poses).to(orig_dtype)  # (B, V, 4, 4)
def unproject_depth(depth_map: Tensor, C2W: Tensor, fxfycxcy: Tensor, bug: bool = True) -> Tensor:
    """Unproject depth map to 3D world coordinate.

    Pixel indices are first mapped to normalized image coordinates (see `bug`). They are
    then back-projected through the pinhole intrinsics, scaled by depth, and transformed
    to world space with `C2W`. The grid follows the OpenCV/COLMAP convention
    (x to the right, y down, z = depth).

    Inputs:
        - `depth_map`: (B, V, H, W)
        - `C2W`: (B, V, 4, 4)
        - `fxfycxcy`: (B, V, 4); intrinsics (fx, fy, cx, cy) expressed in the same
          normalized image units as the pixel grid below
        - `bug`: if True (default, legacy behavior), pixel index `k` maps to `k / (size - 1)`.
          If False, it maps to the pixel centre `(k + 0.5) / size`, which aligns with
          `plucker_ray(bug=False)`.
    Outputs:
        - `xyz_world`: (B, V, 3, H, W), same dtype as `depth_map`
    """
    device, dtype = depth_map.device, depth_map.dtype
    B, V, H, W = depth_map.shape
    depth_map = depth_map.reshape(B * V, H, W).float()
    C2W = C2W.reshape(B * V, 4, 4).float()
    fxfycxcy = fxfycxcy.reshape(B * V, 4).float()

    y, x = torch.meshgrid(
        torch.arange(H, device=device, dtype=torch.float32),
        torch.arange(W, device=device, dtype=torch.float32),
        indexing="ij",
    )  # (H, W) each; OpenCV/COLMAP camera convention
    if bug:
        # Legacy corner-aligned mapping; `max(..., 1)` avoids 0/0 (NaN) for single-pixel axes.
        x = x / max(W - 1, 1)
        y = y / max(H - 1, 1)
    else:
        x = (x + 0.5) / W
        y = (y + 0.5) / H

    # Closed-form inverse of K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]] applied to (x*d, y*d, d)
    fx, fy, cx, cy = fxfycxcy[:, :, None, None].unbind(dim=1)  # each (B*V, 1, 1)
    x_cam = (x[None] - cx) / fx * depth_map  # (B*V, H, W)
    y_cam = (y[None] - cy) / fy * depth_map  # (B*V, H, W)
    xyz_cam = torch.stack([x_cam, y_cam, depth_map], dim=-1).reshape(B * V, H * W, 3)

    # Camera -> world: R @ p + t (the bottom row of C2W is not needed)
    xyz_world = xyz_cam @ C2W[:, :3, :3].transpose(1, 2) + C2W[:, None, :3, 3]  # (B*V, H*W, 3)
    return xyz_world.permute(0, 2, 1).reshape(B, V, 3, H, W).to(dtype)
def plucker_ray(h: int, w: int, C2W: Tensor, fxfycxcy: Tensor, bug: bool = True) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
    """Get Plucker ray embeddings.

    For every pixel a camera ray is built. Its origin is the camera centre (translation
    of `C2W`) and its direction is the unit-length world-space direction through the
    pixel. The Plucker embedding is the concatenation (o x d, d).

    Inputs:
        - `h`: image height
        - `w`: image width
        - `C2W`: (B, V, 4, 4)
        - `fxfycxcy`: (B, V, 4); intrinsics in the same normalized image units as the pixel grid
        - `bug`: if True (default), keep the legacy pixel mapping `(k / (size - 1) + 0.5 - c) / f`.
          It is kept for compatibility with
          https://github.com/camenduru/GRM/blob/master/model/visual_encoder/vit_gs.py#L85.
          If False, use pixel centres: `((k + 0.5) / size - c) / f`.
    Outputs:
        - `plucker`: (B, V, 6, `h`, `w`), in the dtype of `C2W`
        - `ray_o`: (B, V, 3, `h`, `w`), float32
        - `ray_d`: (B, V, 3, `h`, `w`), float32
    """
    device, dtype = C2W.device, C2W.dtype
    B, V = C2W.shape[:2]
    C2W = C2W.reshape(B*V, 4, 4).float()
    fxfycxcy = fxfycxcy.reshape(B*V, 4).float()

    y, x = torch.meshgrid(
        torch.arange(h, device=device), torch.arange(w, device=device), indexing="ij"
    )  # OpenCV/COLMAP camera convention
    y = y[None, :, :].expand(B*V, -1, -1).reshape(B*V, -1)  # (B*V, h*w)
    x = x[None, :, :].expand(B*V, -1, -1).reshape(B*V, -1)  # (B*V, h*w)
    if bug:  # BUG !!! same here: https://github.com/camenduru/GRM/blob/master/model/visual_encoder/vit_gs.py#L85
        # `max(..., 1)` avoids 0/0 (NaN) when an axis has a single pixel
        y = y / max(h - 1, 1)
        x = x / max(w - 1, 1)
        x = (x + 0.5 - fxfycxcy[:, 2:3]) / fxfycxcy[:, 0:1]
        y = (y + 0.5 - fxfycxcy[:, 3:4]) / fxfycxcy[:, 1:2]
    else:
        y = (y + 0.5) / h
        x = (x + 0.5) / w
        x = (x - fxfycxcy[:, 2:3]) / fxfycxcy[:, 0:1]
        y = (y - fxfycxcy[:, 3:4]) / fxfycxcy[:, 1:2]
    z = torch.ones_like(x)

    ray_d = torch.stack([x, y, z], dim=2)  # (B*V, h*w, 3), camera space
    ray_d = torch.bmm(ray_d, C2W[:, :3, :3].transpose(1, 2))  # (B*V, h*w, 3), world space
    ray_d = ray_d / torch.norm(ray_d, dim=2, keepdim=True)  # unit directions
    ray_o = C2W[:, :3, 3][:, None, :].expand_as(ray_d)  # (B*V, h*w, 3)

    ray_o = ray_o.reshape(B, V, h, w, 3).permute(0, 1, 4, 2, 3)
    ray_d = ray_d.reshape(B, V, h, w, 3).permute(0, 1, 4, 2, 3)
    plucker = torch.cat([torch.cross(ray_o, ray_d, dim=2).to(dtype), ray_d.to(dtype)], dim=2)
    return plucker, (ray_o, ray_d)
def orbit_camera(
    elevs: Tensor, azims: Tensor, radius: Optional[Tensor] = None,
    is_degree: bool = True,
    target: Optional[Tensor] = None,
    opengl: bool=True,
) -> Tensor:
    """Build camera-to-world poses for cameras orbiting a target.

    Camera position relative to the target:
        x = r * cos(elev) * sin(azim), y = -r * sin(elev), z = r * cos(elev) * cos(azim).
    The rotation comes from `look_at` towards the target.

    Inputs:
        - `elevs`: (B,); elevation in (-90, 90); -90 is straight above (+y), 90 straight below (-y)
        - `azims`: (B,); azimuth in (-180, 180); 0 is on +z, 90 is on +x
        - `radius`: (B,); orbit radius; defaults to ones when None
        - `is_degree`: whether `elevs`/`azims` are given in degrees (else radians)
        - `target`: (B, 3); look-at target; defaults to the origin when None
        - `opengl`: whether to use the OpenGL (rather than OpenCV) camera convention
    Outputs:
        - `C2W`: (B, 4, 4); camera pose matrices
    """
    if radius is None:
        radius = torch.ones_like(elevs)
    assert elevs.shape == azims.shape == radius.shape
    if target is None:
        target = torch.zeros(elevs.shape[0], 3, device=elevs.device, dtype=elevs.dtype)

    elev_rad = torch.deg2rad(elevs) if is_degree else elevs
    azim_rad = torch.deg2rad(azims) if is_degree else azims

    cos_elev = torch.cos(elev_rad)
    offsets = torch.stack([
        radius * cos_elev * torch.sin(azim_rad),
        -radius * torch.sin(elev_rad),
        radius * cos_elev * torch.cos(azim_rad),
    ], dim=1)  # (B, 3)
    positions = offsets + target  # (B, 3)

    rotation = look_at(positions, target, opengl=opengl)  # (B, 3, 3)
    top = torch.cat([rotation, positions.unsqueeze(-1)], dim=-1)  # (B, 3, 4)
    bottom = torch.zeros_like(top[:, :1])  # (B, 1, 4) homogeneous row
    bottom[:, 0, 3] = 1.
    return torch.cat([top, bottom], dim=1)  # (B, 4, 4)
def look_at(camposes: Tensor, targets: Tensor, opengl: bool = True) -> Tensor:
    """Construct batched pose rotation matrices by look-at.

    The columns of each returned matrix are (right, up, forward). They are derived from
    the camera->target direction and the world up axis +y.
    - OpenGL: forward points from target to camera; right = up x forward.
    - OpenCV: forward points from camera to target; right = forward x up.
    The result is degenerate when the viewing direction is parallel to +y.

    Inputs:
        - `camposes`: (B, 3); camera positions
        - `targets`: (B, 3); look-at targets
        - `opengl`: whether to use OpenGL convention
    Outputs:
        - `R`: (B, 3, 3); normalized camera pose rotation matrices
    """
    device, dtype = camposes.device, camposes.dtype
    world_up = torch.tensor([0., 1., 0.], device=device, dtype=dtype)
    # NOTE: `dim=-1` is required. Without it `torch.cross` uses the FIRST dimension of size 3,
    # which is the batch axis when B == 3 and silently produces wrong rotations.
    if not opengl:  # OpenCV convention
        # forward is camera -> target
        forward_vectors = tF.normalize(targets - camposes, dim=-1)
        up_vectors = world_up[None, :].expand_as(forward_vectors)
        right_vectors = tF.normalize(torch.cross(forward_vectors, up_vectors, dim=-1), dim=-1)
        up_vectors = tF.normalize(torch.cross(right_vectors, forward_vectors, dim=-1), dim=-1)
    else:
        # forward is target -> camera
        forward_vectors = tF.normalize(camposes - targets, dim=-1)
        up_vectors = world_up[None, :].expand_as(forward_vectors)
        right_vectors = tF.normalize(torch.cross(up_vectors, forward_vectors, dim=-1), dim=-1)
        up_vectors = tF.normalize(torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1)
    R = torch.stack([right_vectors, up_vectors, forward_vectors], dim=-1)  # (B, 3, 3)
    return R