Spaces:
Running
on
Zero
Running
on
Zero
#!/usr/bin/env python | |
# coding=utf-8 | |
import torch | |
def rotation_matrix(a, b):
    """Compute the rotation matrix that rotates vector a to vector b.

    The rotation is about the axis ``a x b`` by the angle between the two
    vectors, built with Rodrigues' formula. When ``a`` and ``b`` are parallel
    or opposite, the cross product no longer defines an axis. An axis
    perpendicular to ``a`` is then chosen deterministically, which yields the
    identity for parallel vectors and a 180-degree rotation for opposite ones.

    Args:
        a: The vector to rotate; a 1-D tensor with 3 elements (any length).
        b: The vector to rotate to; same dtype and device as ``a`` (any length).

    Returns:
        A (3, 3) rotation matrix ``R`` with ``R @ (a / |a|) == b / |b|``,
        in the dtype and on the device of the inputs.
    """
    a = a / torch.linalg.norm(a)
    b = b / torch.linalg.norm(b)
    v = torch.cross(a, b, dim=-1)  # sin(angle) * unit axis
    c = torch.dot(a, b)  # cos(angle)
    s = torch.linalg.norm(v)  # sin(angle)

    # Strip any rounding-induced component of v along `a`. The axis must be
    # orthogonal to `a` for R @ a to land on b. This matters most when a and b
    # are nearly opposite and v is small.
    axis = v - torch.dot(v, a) * a
    axis_norm = torch.linalg.norm(axis)
    if axis_norm <= torch.finfo(axis.dtype).eps:
        # Parallel or opposite vectors: the cross product does not define an
        # axis. Any unit axis perpendicular to `a` works because sin(angle) ~ 0.
        # Cross `a` with the coordinate axis it is least aligned with, which
        # keeps the result deterministic.
        helper = torch.zeros_like(a)
        helper[torch.argmin(torch.abs(a))] = 1
        axis = torch.cross(a, helper, dim=-1)
        axis_norm = torch.linalg.norm(axis)
    axis = axis / axis_norm

    # Cross-product (skew-symmetric) matrix of the unit axis. It is built with
    # torch.stack so dtype, device and autograd history are preserved.
    zero = torch.zeros((), dtype=axis.dtype, device=axis.device)
    skew_sym_mat = torch.stack(
        [
            torch.stack([zero, -axis[2], axis[1]]),
            torch.stack([axis[2], zero, -axis[0]]),
            torch.stack([-axis[1], axis[0], zero]),
        ]
    )
    eye = torch.eye(3, dtype=axis.dtype, device=axis.device)
    # Rodrigues' formula: R = I + sin(angle) K + (1 - cos(angle)) K^2.
    return eye + s * skew_sym_mat + (1 - c) * (skew_sym_mat @ skew_sym_mat)
def auto_orient_and_center_poses(poses, method="up", center_poses=True):
    """Orient and optionally center a set of camera poses.

    Three orientation methods are supported:

    - ``"pca"``: rotate so the principal components of the camera positions
      line up with the coordinate axes. This works well when all of the cameras
      are in the same plane.
    - ``"up"``: rotate so the average up vector (the mean of the poses' second
      rotation column) is aligned with the +z axis. This works well when images
      are not taken at arbitrary angles.
    - ``"none"``: no rotation; only the centering translation is applied.

    Args:
        poses: Stack of homogeneous poses. A (3, 4) transform is left-multiplied
            onto it, so the expected shape is (N, 4, 4). These are presumably
            camera-to-world matrices (translation in column 3); confirm with
            callers.
        method: One of ``"pca"``, ``"up"`` or ``"none"``.
        center_poses: If True, translate so the mean camera position is at the
            origin; otherwise no translation is applied.

    Returns:
        A tuple ``(oriented_poses, transform)``. ``oriented_poses`` has shape
        (N, 3, 4). ``transform`` is the (3, 4) matrix with
        ``oriented_poses == transform @ poses``.

    Raises:
        ValueError: If ``method`` is not one of the supported values.
    """
    translation = poses[..., :3, 3]
    mean_translation = torch.mean(translation, dim=0)
    translation_diff = translation - mean_translation

    # Offset removed from every camera position.
    if center_poses:
        translation = mean_translation
    else:
        translation = torch.zeros_like(mean_translation)

    if method == "pca":
        # torch.linalg.eigh returns eigenvalues in ascending order. Flip the
        # columns so the direction of largest spread comes first.
        _, eigvec = torch.linalg.eigh(translation_diff.T @ translation_diff)
        eigvec = torch.flip(eigvec, dims=(-1,))
        # Force a right-handed basis so the matrix is a proper rotation.
        if torch.linalg.det(eigvec) < 0:
            eigvec[:, 2] = -eigvec[:, 2]
        transform = torch.cat([eigvec, eigvec @ -translation[..., None]], dim=-1)
        oriented_poses = transform @ poses
        # If the mean up vector points down, rotate 180 degrees about x by
        # negating the y and z rows. Apply the same flip to `transform` so the
        # returned transform still maps `poses` to `oriented_poses`.
        if oriented_poses.mean(dim=0)[2, 1] < 0:
            oriented_poses[:, 1:3] = -1 * oriented_poses[:, 1:3]
            transform[1:3, :] = -1 * transform[1:3, :]
    elif method == "up":
        up = torch.mean(poses[:, :3, 1], dim=0)
        up = up / torch.linalg.norm(up)
        # Build the target axis in the poses' dtype/device so the math stays consistent.
        z_axis = torch.tensor([0.0, 0.0, 1.0], dtype=up.dtype, device=up.device)
        rotation = rotation_matrix(up, z_axis)
        transform = torch.cat([rotation, rotation @ -translation[..., None]], dim=-1)
        oriented_poses = transform @ poses
    elif method == "none":
        transform = torch.eye(4, dtype=poses.dtype, device=poses.device)
        transform[:3, 3] = -translation
        transform = transform[:3, :]
        oriented_poses = transform @ poses
    else:
        raise ValueError(
            f"Unknown orientation method {method!r}; expected 'pca', 'up' or 'none'."
        )

    return oriented_poses, transform