from typing import Optional, Tuple

import torch
import torch.nn.functional as tF
from torch import Tensor

def normalize_normals(normals: Tensor, C2W: Tensor, i: int = 0) -> Tensor:
"""Normalize a batch of multi-view `normals` by the `i`-th view.
Inputs:
- `normals`: (B, V, 3, H, W)
- `C2W`: (B, V, 4, 4)
- `i`: the index of the view to normalize by
Outputs:
- `normalized_normals`: (B, V, 3, H, W)
"""
_, _, R, C = C2W.shape # (B, V, 4, 4)
assert R == C == 4
_, _, CC, _, _ = normals.shape # (B, V, 3, H, W)
assert CC == 3
dtype = normals.dtype
normals = normals.clone().float()
transform = torch.inverse(C2W[:, i, :3, :3]) # (B, 3, 3)
return torch.einsum("brc,bvchw->bvrhw", transform, normals).to(dtype) # (B, V, 3, H, W)
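
# A minimal usage sketch (added for illustration; shapes are hypothetical):
# with identity camera poses, the rotation aligning normals to view 0 is the
# identity, so the input should pass through unchanged.
def _example_normalize_normals() -> None:
    B, V, H, W = 2, 4, 8, 8
    normals = tF.normalize(torch.randn(B, V, 3, H, W), dim=2)  # random unit normals
    C2W = torch.eye(4).expand(B, V, 4, 4).clone()  # identity poses
    out = normalize_normals(normals, C2W, i=0)
    assert out.shape == (B, V, 3, H, W)
    assert torch.allclose(out, normals, atol=1e-6)
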
def normalize_C2W(C2W: Tensor, i: int = 0, norm_radius: float = 0.) -> Tensor:
"""Normalize a batch of multi-view `C2W` by the `i`-th view.
Inputs:
- `C2W`: (B, V, 4, 4)
- `i`: the index of the view to normalize by
    - `norm_radius`: the target camera distance of the reference view; 0 disables rescaling
Outputs:
- `normalized_C2W`: (B, V, 4, 4)
"""
_, _, R, C = C2W.shape # (B, V, 4, 4)
assert R == C == 4
device, dtype = C2W.device, C2W.dtype
C2W = C2W.clone().float()
if abs(norm_radius) > 0.:
radius = torch.norm(C2W[:, i, :3, 3], dim=1) # (B,)
C2W[:, :, :3, 3] *= (norm_radius / radius.unsqueeze(1).unsqueeze(2))
# The `i`-th view is normalized to a canonical matrix as the reference view
transform = torch.tensor([
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, norm_radius],
[0, 0, 0, 1] # canonical c2w in OpenGL world convention
], dtype=torch.float32, device=device) @ torch.inverse(C2W[:, i, ...]) # (B, 4, 4)
return (transform.unsqueeze(1) @ C2W).to(dtype) # (B, V, 4, 4)
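
# A minimal sanity sketch (illustrative values only, relying on `orbit_camera`
# defined below): after normalization, the reference view should equal the
# canonical OpenGL c2w sitting at (0, 0, norm_radius).
def _example_normalize_C2W() -> None:
    elevs = torch.tensor([30., -10.])
    azims = torch.tensor([45., 120.])
    C2W = orbit_camera(elevs, azims, radius=2. * torch.ones(2)).unsqueeze(0)  # (1, 2, 4, 4)
    out = normalize_C2W(C2W, i=0, norm_radius=2.)
    canonical = torch.eye(4)
    canonical[2, 3] = 2.
    assert torch.allclose(out[:, 0], canonical, atol=1e-5)
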
def unproject_depth(depth_map: Tensor, C2W: Tensor, fxfycxcy: Tensor) -> Tensor:
"""Unproject depth map to 3D world coordinate.
Inputs:
- `depth_map`: (B, V, H, W)
- `C2W`: (B, V, 4, 4)
- `fxfycxcy`: (B, V, 4)
Outputs:
- `xyz_world`: (B, V, 3, H, W)
"""
device, dtype = depth_map.device, depth_map.dtype
B, V, H, W = depth_map.shape
depth_map = depth_map.reshape(B*V, H, W).float()
C2W = C2W.reshape(B*V, 4, 4).float()
fxfycxcy = fxfycxcy.reshape(B*V, 4).float()
K = torch.zeros(B*V, 3, 3, dtype=torch.float32, device=device)
K[:, 0, 0] = fxfycxcy[:, 0]
K[:, 1, 1] = fxfycxcy[:, 1]
K[:, 0, 2] = fxfycxcy[:, 2]
K[:, 1, 2] = fxfycxcy[:, 3]
K[:, 2, 2] = 1
y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij") # OpenCV/COLMAP camera convention
y = y.to(device).unsqueeze(0).repeat(B*V, 1, 1) / (H-1)
x = x.to(device).unsqueeze(0).repeat(B*V, 1, 1) / (W-1)
# NOTE: To align with `plucker_ray(bug=False)`, should be:
# y = (y.to(device).unsqueeze(0).repeat(B*V, 1, 1) + 0.5) / H
# x = (x.to(device).unsqueeze(0).repeat(B*V, 1, 1) + 0.5) / W
    xyz_map = torch.stack([x, y, torch.ones_like(x)], dim=-1) * depth_map[..., None]
xyz = xyz_map.view(B*V, -1, 3)
# Get point positions in camera coordinate
xyz = torch.matmul(xyz, torch.transpose(torch.inverse(K), 1, 2))
xyz_map = xyz.view(B*V, H, W, 3)
# Transform pts from camera to world coordinate
xyz_homo = torch.ones((B*V, H, W, 4), device=device)
xyz_homo[..., :3] = xyz_map
xyz_world = torch.bmm(C2W, xyz_homo.reshape(B*V, -1, 4).permute(0, 2, 1))[:, :3, ...].to(dtype) # (B*V, 3, H*W)
xyz_world = xyz_world.reshape(B, V, 3, H, W)
return xyz_world
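
# A minimal round-trip sketch (hypothetical normalized intrinsics): with an
# identity pose, unit depth, and the principal point at the image center, the
# unprojected z-coordinate should equal the depth everywhere.
def _example_unproject_depth() -> None:
    B, V, H, W = 1, 1, 16, 16
    depth = torch.ones(B, V, H, W)
    C2W = torch.eye(4).expand(B, V, 4, 4).clone()
    fxfycxcy = torch.tensor([1., 1., 0.5, 0.5]).expand(B, V, 4)  # normalized fx, fy, cx, cy
    xyz = unproject_depth(depth, C2W, fxfycxcy)
    assert xyz.shape == (B, V, 3, H, W)
    assert torch.allclose(xyz[:, :, 2], depth, atol=1e-6)
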
def plucker_ray(h: int, w: int, C2W: Tensor, fxfycxcy: Tensor, bug: bool = True) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
"""Get Plucker ray embeddings.
Inputs:
- `h`: image height
- `w`: image width
- `C2W`: (B, V, 4, 4)
- `fxfycxcy`: (B, V, 4)
Outputs:
- `plucker`: (B, V, 6, `h`, `w`)
- `ray_o`: (B, V, 3, `h`, `w`)
- `ray_d`: (B, V, 3, `h`, `w`)
"""
device, dtype = C2W.device, C2W.dtype
B, V = C2W.shape[:2]
C2W = C2W.reshape(B*V, 4, 4).float()
fxfycxcy = fxfycxcy.reshape(B*V, 4).float()
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") # OpenCV/COLMAP camera convention
y, x = y.to(device), x.to(device)
    if bug:  # BUG!!! same pixel-coordinate normalization bug as https://github.com/camenduru/GRM/blob/master/model/visual_encoder/vit_gs.py#L85
y = y[None, :, :].expand(B*V, -1, -1).reshape(B*V, -1) / (h - 1)
x = x[None, :, :].expand(B*V, -1, -1).reshape(B*V, -1) / (w - 1)
x = (x + 0.5 - fxfycxcy[:, 2:3]) / fxfycxcy[:, 0:1]
y = (y + 0.5 - fxfycxcy[:, 3:4]) / fxfycxcy[:, 1:2]
else:
y = (y[None, :, :].expand(B*V, -1, -1).reshape(B*V, -1) + 0.5) / h
x = (x[None, :, :].expand(B*V, -1, -1).reshape(B*V, -1) + 0.5) / w
x = (x - fxfycxcy[:, 2:3]) / fxfycxcy[:, 0:1]
y = (y - fxfycxcy[:, 3:4]) / fxfycxcy[:, 1:2]
z = torch.ones_like(x)
ray_d = torch.stack([x, y, z], dim=2) # (B*V, h*w, 3)
ray_d = torch.bmm(ray_d, C2W[:, :3, :3].transpose(1, 2)) # (B*V, h*w, 3)
ray_d = ray_d / torch.norm(ray_d, dim=2, keepdim=True) # (B*V, h*w, 3)
ray_o = C2W[:, :3, 3][:, None, :].expand_as(ray_d) # (B*V, h*w, 3)
ray_o = ray_o.reshape(B, V, h, w, 3).permute(0, 1, 4, 2, 3)
ray_d = ray_d.reshape(B, V, h, w, 3).permute(0, 1, 4, 2, 3)
plucker = torch.cat([torch.cross(ray_o, ray_d, dim=2).to(dtype), ray_d.to(dtype)], dim=2)
return plucker, (ray_o, ray_d)
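
# A minimal sanity sketch (illustrative cameras): ray directions are unit-norm,
# and the Plücker moment `ray_o x ray_d` is orthogonal to the direction by
# construction.
def _example_plucker_ray() -> None:
    B, V, h, w = 1, 2, 8, 8
    C2W = torch.eye(4).expand(B, V, 4, 4).clone()
    C2W[..., 2, 3] = 2.  # move both cameras to z = 2
    fxfycxcy = torch.tensor([1., 1., 0.5, 0.5]).expand(B, V, 4)
    plucker, (ray_o, ray_d) = plucker_ray(h, w, C2W, fxfycxcy, bug=False)
    assert plucker.shape == (B, V, 6, h, w)
    assert torch.allclose(ray_d.norm(dim=2), torch.ones(B, V, h, w), atol=1e-6)
    moment = plucker[:, :, :3]  # first three channels are ray_o x ray_d
    assert torch.allclose((moment * ray_d).sum(dim=2), torch.zeros(B, V, h, w), atol=1e-6)
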
def orbit_camera(
elevs: Tensor, azims: Tensor, radius: Optional[Tensor] = None,
is_degree: bool = True,
target: Optional[Tensor] = None,
opengl: bool=True,
) -> Tensor:
"""Construct a camera pose matrix orbiting a target with elevation & azimuth angle.
Inputs:
    - `elevs`: (B,); elevation angle in (-90, 90); -90 places the camera at +y (top), +90 at -y (bottom)
    - `azims`: (B,); azimuth angle in (-180, 180); 0 places the camera at +z, 90 at +x
    - `radius`: (B,); camera orbit radius; if None, defaults to 1.
- `is_degree`: bool; whether the input angles are in degree
- `target`: (B, 3); look-at target position
- `opengl`: bool; whether to use OpenGL convention
Outputs:
- `C2W`: (B, 4, 4); camera pose matrix
"""
device, dtype = elevs.device, elevs.dtype
if radius is None:
radius = torch.ones_like(elevs)
assert elevs.shape == azims.shape == radius.shape
if target is None:
target = torch.zeros(elevs.shape[0], 3, device=device, dtype=dtype)
if is_degree:
elevs = torch.deg2rad(elevs)
azims = torch.deg2rad(azims)
x = radius * torch.cos(elevs) * torch.sin(azims)
    y = -radius * torch.sin(elevs)
z = radius * torch.cos(elevs) * torch.cos(azims)
camposes = torch.stack([x, y, z], dim=1) + target # (B, 3)
R = look_at(camposes, target, opengl=opengl) # (B, 3, 3)
C2W = torch.cat([R, camposes[:, :, None]], dim=2) # (B, 3, 4)
C2W = torch.cat([C2W, torch.zeros_like(C2W[:, :1, :])], dim=1) # (B, 4, 4)
C2W[:, 3, 3] = 1.
return C2W
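
# A minimal sketch (illustrative angles): with zero elevation, azimuth 0 places
# the camera on +z and azimuth 90 on +x, both at the requested radius.
def _example_orbit_camera() -> None:
    elevs = torch.zeros(3)
    azims = torch.tensor([0., 90., 180.])
    C2W = orbit_camera(elevs, azims, radius=1.5 * torch.ones(3))
    assert C2W.shape == (3, 4, 4)
    assert torch.allclose(C2W[0, :3, 3], torch.tensor([0., 0., 1.5]), atol=1e-6)
    assert torch.allclose(C2W[1, :3, 3], torch.tensor([1.5, 0., 0.]), atol=1e-6)
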
def look_at(camposes: Tensor, targets: Tensor, opengl: bool = True) -> Tensor:
"""Construct batched pose rotation matrices by look-at.
Inputs:
- `camposes`: (B, 3); camera positions
- `targets`: (B, 3); look-at targets
- `opengl`: whether to use OpenGL convention
Outputs:
- `R`: (B, 3, 3); normalized camera pose rotation matrices
"""
device, dtype = camposes.device, camposes.dtype
if not opengl: # OpenCV convention
# forward is camera -> target
forward_vectors = tF.normalize(targets - camposes, dim=-1)
up_vectors = torch.tensor([0., 1., 0.], device=device, dtype=dtype)[None, :].expand_as(forward_vectors)
        right_vectors = tF.normalize(torch.cross(forward_vectors, up_vectors, dim=-1), dim=-1)
        up_vectors = tF.normalize(torch.cross(right_vectors, forward_vectors, dim=-1), dim=-1)
else:
# forward is target -> camera
forward_vectors = tF.normalize(camposes - targets, dim=-1)
up_vectors = torch.tensor([0., 1., 0.], device=device, dtype=dtype)[None, :].expand_as(forward_vectors)
        right_vectors = tF.normalize(torch.cross(up_vectors, forward_vectors, dim=-1), dim=-1)
        up_vectors = tF.normalize(torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1)
R = torch.stack([right_vectors, up_vectors, forward_vectors], dim=-1)
return R
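
# A minimal sketch: in the OpenGL convention the third column of `R` is the
# camera backward axis, i.e. it points from the target toward the camera, and
# the three columns form an orthonormal basis.
def _example_look_at() -> None:
    camposes = torch.tensor([[0., 0., 3.], [2., 0., 0.]])
    targets = torch.zeros(2, 3)
    R = look_at(camposes, targets, opengl=True)
    backward = tF.normalize(camposes - targets, dim=-1)
    assert torch.allclose(R[:, :, 2], backward, atol=1e-6)
    assert torch.allclose(R @ R.transpose(1, 2), torch.eye(3).expand(2, 3, 3), atol=1e-6)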