Spaces:
Running
Running
from typing import Optional | |
import torch | |
from torch.nn import functional as F | |
def aa_to_rotmat(theta: torch.Tensor):
    """
    Turn a batch of axis-angle vectors into rotation matrices.

    Each vector is first rewritten as a (w, x, y, z) quaternion, which is then
    handed to ``quat_to_rotmat`` for the final conversion.

    Args:
        theta (torch.Tensor): Axis-angle vectors of shape (B, 3).
    Returns:
        torch.Tensor: Rotation matrices of shape (B, 3, 3).
    """
    # 1e-8 is added to every component before the norm is taken so that an
    # all-zero input does not yield a zero divisor below.
    magnitude = (theta + 1e-8).norm(p=2, dim=1).unsqueeze(-1)
    axis = theta / magnitude

    # Quaternion of a rotation by `magnitude` about `axis`:
    # (cos(angle / 2), sin(angle / 2) * axis).
    half_angle = magnitude * 0.5
    quaternion = torch.cat((half_angle.cos(), half_angle.sin() * axis), dim=1)
    return quat_to_rotmat(quaternion)
def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
    """
    Build rotation matrices from a batch of quaternions.

    The input does not have to be unit length: every quaternion is divided by
    its L2 norm before the matrix entries are evaluated.

    Args:
        quat (torch.Tensor): Quaternions of shape (B, 4), ordered (w, x, y, z).
    Returns:
        torch.Tensor: Rotation matrices of shape (B, 3, 3).
    """
    unit = quat / quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = (unit[:, k] for k in range(4))

    # Squares and pairwise products shared by several matrix entries.
    ww, xx, yy, zz = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    # Entries grouped by matrix row; flattened row-major before reshaping.
    rows = (
        (ww + xx - yy - zz, 2 * xy - 2 * wz, 2 * wy + 2 * xz),
        (2 * wz + 2 * xy, ww - xx + yy - zz, 2 * yz - 2 * wx),
        (2 * xz - 2 * wy, 2 * wx + 2 * yz, ww - xx - yy + zz),
    )
    flat = torch.stack([entry for row in rows for entry in row], dim=1)
    return flat.view(quat.size(0), 3, 3)
def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
    """
    Map 6-D rotation representations to 3x3 rotation matrices.

    Follows Zhou et al., "On the Continuity of Rotation Representations in
    Neural Networks", CVPR 2019: the two 3-vectors packed in each 6-D code are
    orthonormalised with a Gram-Schmidt step, and the third axis is their
    cross product.

    Args:
        x (torch.Tensor): (B, 6) batch of 6-D rotation codes; the first three
            values of each code form the first vector, the last three the second.
    Returns:
        torch.Tensor: (B, 3, 3) rotation matrices whose columns are the three
            orthonormal axes.
    """
    pairs = x.reshape(-1, 2, 3)
    first, second = pairs[:, 0], pairs[:, 1]

    col0 = F.normalize(first)
    # Remove the component of the second vector that lies along the first axis.
    overlap = torch.einsum('bi,bi->b', col0, second).unsqueeze(-1)
    col1 = F.normalize(second - overlap * col0)
    col2 = torch.cross(col0, col1, dim=1)
    return torch.stack((col0, col1, col2), dim=-1)
def perspective_projection(points: torch.Tensor,
                           translation: torch.Tensor,
                           focal_length: torch.Tensor,
                           camera_center: Optional[torch.Tensor] = None,
                           rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Project a batch of 3D points onto the image plane of a pinhole camera.

    The points are rotated, translated, divided by their depth (last
    coordinate) and finally mapped to pixels with the intrinsic matrix built
    from ``focal_length`` and ``camera_center``.

    Args:
        points (torch.Tensor): (B, N, 3) input 3D points.
        translation (torch.Tensor): (B, 3) camera translation.
        focal_length (torch.Tensor): (B, 2) focal length in pixels.
        camera_center (torch.Tensor): (B, 2) camera center in pixels; zeros
            when omitted.
        rotation (torch.Tensor): (B, 3, 3) camera rotation; identity when
            omitted.
    Returns:
        torch.Tensor: (B, N, 2) projections of the input points.
    """
    num_batches = points.shape[0]
    # Every tensor created here lives on the same device/dtype as the points.
    tensor_kwargs = dict(device=points.device, dtype=points.dtype)

    if camera_center is None:
        camera_center = torch.zeros(num_batches, 2, **tensor_kwargs)
    if rotation is None:
        rotation = torch.eye(3, **tensor_kwargs).unsqueeze(0).expand(num_batches, -1, -1)

    # Intrinsic matrix: focal lengths on the diagonal, principal point in the
    # last column, 1 in the bottom-right corner.
    intrinsics = torch.zeros([num_batches, 3, 3], **tensor_kwargs)
    intrinsics[:, 0, 0] = focal_length[:, 0]
    intrinsics[:, 1, 1] = focal_length[:, 1]
    intrinsics[:, 2, 2] = 1.
    intrinsics[:, :-1, -1] = camera_center

    # Rigid transform: rotate every point, then shift by the translation.
    cam_points = torch.einsum('bij,bnj->bni', rotation, points)
    cam_points = cam_points + translation.unsqueeze(1)

    # Perspective divide by the depth coordinate.
    depth = cam_points[:, :, -1].unsqueeze(-1)
    normalized = cam_points / depth

    # Map to pixel coordinates and drop the homogeneous coordinate.
    pixels = torch.einsum('bij,bnj->bni', intrinsics, normalized)
    return pixels[:, :, :-1]