# %%writefile mv_utils_zs.py
"""
Author: yangyangyang127
Github: https://github.com/yangyangyang127
Repo: https://github.com/yangyangyang127/PointCLIP_V2
Path: https://github.com/yangyangyang127/PointCLIP_V2/blob/main/zeroshot_cls/trainers/mv_utils_zs.py#L135
"""
import numpy as np
import torch
import torch.nn as nn
from torch_scatter import scatter
TRANS = -1.5  # z-offset shared by every view's camera translation
# realistic projection parameters
params = {
    "maxpoolz": 1,  # max-pool kernel size along depth
    "maxpoolxy": 7,  # max-pool kernel size in the image plane
    "maxpoolpadz": 0,  # max-pool padding along depth
    "maxpoolpadxy": 2,  # max-pool padding in the image plane
    "convz": 1,  # Gaussian kernel size along depth
    "convxy": 3,  # Gaussian kernel size in the image plane
    "convsigmaxy": 3,  # Gaussian sigma in the image plane
    "convsigmaz": 1,  # Gaussian sigma along depth
    "convpadz": 0,  # conv padding along depth
    "convpadxy": 1,  # conv padding in the image plane
    "imgbias": 0.0,  # (unused in this file)
    "depth_bias": 0.2,  # pushes points away from the near end of the depth range
    "obj_ratio": 0.8,  # fraction of the grid the object occupies (background margin)
    "bg_clr": 0.0,  # background value of empty grid cells
    "resolution": 122,  # grid resolution in the image plane
    "depth": 8,  # grid resolution along depth (default = 8)
    "grid_height": 64,  # output height of points_to_2d_grid
    "grid_width": 64,  # output width of points_to_2d_grid
}
class Grid2Image(nn.Module):
"""A pytorch implementation to turn 3D grid to 2D image.
Maxpool: densifying the grid
Convolution: smoothing via Gaussian
Maximize: squeezing the depth channel
"""
def __init__(self):
super().__init__()
torch.backends.cudnn.benchmark = False
self.maxpool = nn.MaxPool3d(
(params["maxpoolz"], params["maxpoolxy"], params["maxpoolxy"]),
stride=1,
padding=(
params["maxpoolpadz"],
params["maxpoolpadxy"],
params["maxpoolpadxy"],
),
)
self.conv = torch.nn.Conv3d(
1,
1,
kernel_size=(params["convz"], params["convxy"], params["convxy"]),
stride=1,
padding=(params["convpadz"], params["convpadxy"], params["convpadxy"]),
bias=True,
)
kn3d = get3DGaussianKernel(
params["convxy"],
params["convz"],
sigma=params["convsigmaxy"],
zsigma=params["convsigmaz"],
)
        # fixed (non-learned) Gaussian weights, shaped [out_c, in_c, kD, kH, kW]
        self.conv.weight.data = torch.Tensor(kn3d).repeat(1, 1, 1, 1, 1)
        self.conv.bias.data.fill_(0)  # type: ignore
    def forward(self, x):
        x = self.maxpool(x.unsqueeze(1))  # add channel dim, densify the grid
        x = self.conv(x)  # Gaussian smoothing
        img = torch.max(x, dim=2)[0]  # squeeze the depth axis
        # normalize each image by its own maximum, then invert so the
        # background is white and occupied cells are dark
        img = img / torch.max(torch.max(img, dim=-1)[0], dim=-1)[0][:, :, None, None]
        img = 1 - img
        img = img.repeat(1, 3, 1, 1)  # grayscale -> 3 channels
        return img
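# Illustrative usage sketch (a hypothetical helper, not part of the original
# pipeline): run a random grid through Grid2Image. With the default params, the
# 7x7 max-pool with padding 2 trims one pixel per side, so resolution 122
# yields 120 x 120 images.
def _example_grid2image():
    g2i = Grid2Image()
    grid = torch.rand(2, params["depth"], params["resolution"], params["resolution"])
    img = g2i(grid)
    print(img.shape)  # torch.Size([2, 3, 120, 120])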
def euler2mat(angle):
    """Convert Euler angles to a rotation matrix.
    :param angle: [3] or [b, 3]
    :return:
        rotmat: [3, 3] or [b, 3, 3]
    Source:
    https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py
    """
if len(angle.size()) == 1:
x, y, z = angle[0], angle[1], angle[2]
_dim = 0
_view = [3, 3]
elif len(angle.size()) == 2:
b, _ = angle.size()
x, y, z = angle[:, 0], angle[:, 1], angle[:, 2]
_dim = 1
_view = [b, 3, 3]
    else:
        raise ValueError(f"angle must have shape [3] or [b, 3], got {tuple(angle.size())}")
cosz = torch.cos(z)
sinz = torch.sin(z)
    # Build 0/1 constants that inherit dtype and device from the input angles
    zero = z.detach() * 0
    one = zero.detach() + 1
zmat = torch.stack(
[cosz, -sinz, zero, sinz, cosz, zero, zero, zero, one], dim=_dim
).reshape(_view)
cosy = torch.cos(y)
siny = torch.sin(y)
ymat = torch.stack(
[cosy, zero, siny, zero, one, zero, -siny, zero, cosy], dim=_dim
).reshape(_view)
cosx = torch.cos(x)
sinx = torch.sin(x)
xmat = torch.stack(
[one, zero, zero, zero, cosx, -sinx, zero, sinx, cosx], dim=_dim
).reshape(_view)
rot_mat = xmat @ ymat @ zmat
return rot_mat
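# Quick sanity check (illustrative only, a hypothetical helper): zero angles
# give the identity, and a batched input returns one rotation matrix per row.
def _example_euler2mat():
    eye = euler2mat(torch.zeros(3))  # [3, 3]
    batched = euler2mat(torch.zeros(4, 3))  # [4, 3, 3]
    assert torch.allclose(eye, torch.eye(3))
    print(batched.shape)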
def points_to_2d_grid(
points, grid_h=params["grid_height"], grid_w=params["grid_width"]
):
"""
Converts a point cloud into a 2D grid based on X, Y coordinates.
Points are projected onto a plane and quantized into grid cells.
Args:
points (torch.tensor): Tensor containing points, shape [B, P, 3]
(B: batch size, P: number of points, 3: x, y, z coordinates)
grid_h (int): Height of the output 2D grid.
grid_w (int): Width of the output 2D grid.
Returns:
grid (torch.tensor): 2D grid representing the occupancy of points,
shape [B, grid_h, grid_w].
Value 1.0 at cell (y, x) if at least one point falls into it,
otherwise the background value (params["bg_clr"]).
"""
batch, pnum, _ = points.shape
device = points.device
# --- Step 1: Normalize point coordinates ---
# Find min/max for each point cloud in the batch (considering only X, Y for better 2D normalization)
pmax_xy = points[:, :, :2].max(dim=1)[0]
pmin_xy = points[:, :, :2].min(dim=1)[0]
# Compute the center and range based on X, Y
pcent_xy = (pmax_xy + pmin_xy) / 2
pcent_xy = pcent_xy[:, None, :] # Add P dimension for broadcasting [B, 1, 2]
# Use the larger range between X and Y to maintain aspect ratio
prange_xy = (pmax_xy - pmin_xy).max(dim=-1)[0][:, None, None] # [B, 1, 1]
# Add a small epsilon to avoid division by zero if all points overlap
epsilon = 1e-8
# Normalize X, Y into the range [-1, 1] based on the X, Y range
points_normalized_xy = (points[:, :, :2] - pcent_xy) / (prange_xy + epsilon) * 2.0
    # Shrink the object by obj_ratio so it keeps a background margin in the grid
    points_normalized_xy = points_normalized_xy * params["obj_ratio"]
    # --- Step 2: Map normalized coordinates to 2D grid indices ---
    # Map X and Y from the full [-1, 1] range -> [0, grid_w) / [0, grid_h),
    # matching points2grid below. Dividing by obj_ratio here instead would
    # cancel the margin introduced above.
    _x = (points_normalized_xy[:, :, 0] + 1) / 2 * grid_w
    _y = (points_normalized_xy[:, :, 1] + 1) / 2 * grid_h
# Round down to determine the grid cell indices
_x = torch.floor(_x).long()
_y = torch.floor(_y).long()
# --- Step 3: Clamp indices to valid grid range ---
# Clip _x to [0, grid_w - 1]
# Clip _y to [0, grid_h - 1]
_x = torch.clip(_x, 0, grid_w - 1)
_y = torch.clip(_y, 0, grid_h - 1)
# --- Step 4: Create a 2D grid and mark occupied cells ---
# Initialize the 2D grid with the background value
grid = torch.full(
(batch, grid_h, grid_w), params["bg_clr"], dtype=torch.float32, device=device
)
# Create batch indices corresponding to each point
batch_indices = torch.arange(batch, device=device).view(-1, 1).repeat(1, pnum)
# Flatten indices for easier assignment
batch_idx_flat = batch_indices.view(-1)
y_idx_flat = _y.view(-1)
x_idx_flat = _x.view(-1)
# Assign a value of 1.0 to grid cells (y, x) corresponding to point positions
# If multiple points fall into the same cell, the cell still has a value of 1.0
grid[batch_idx_flat, y_idx_flat, x_idx_flat] = 1.0
return grid
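# Illustrative usage sketch (fake data, a hypothetical helper not called by the
# original pipeline): project random clouds onto the XY plane; every cell hit
# by at least one point reads 1.0, the rest keep params["bg_clr"].
def _example_points_to_2d_grid():
    pts = torch.rand(2, 1024, 3) * 2 - 1  # fake point clouds in [-1, 1]^3
    grid = points_to_2d_grid(pts)  # [2, grid_height, grid_width]
    print(grid.shape, grid.max().item())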
def points2grid(points, resolution=params["resolution"], depth=params["depth"]):
    """Quantize each point cloud to a 3D grid.
    Args:
        points (torch.tensor): of size [B, P, 3]
    Returns:
        grid (torch.tensor): of size [B, depth, resolution, resolution];
            in Realistic_Projection.get_img, B is batch * num_views.
    """
batch, pnum, _ = points.shape
pmax, pmin = points.max(dim=1)[0], points.min(dim=1)[0]
pcent = (pmax + pmin) / 2
pcent = pcent[:, None, :]
prange = (pmax - pmin).max(dim=-1)[0][:, None, None]
points = (points - pcent) / prange * 2.0
points[:, :, :2] = points[:, :, :2] * params["obj_ratio"]
depth_bias = params["depth_bias"]
_x = (points[:, :, 0] + 1) / 2 * resolution
_y = (points[:, :, 1] + 1) / 2 * resolution
_z = ((points[:, :, 2] + 1) / 2 + depth_bias) / (1 + depth_bias) * (depth - 2)
_x.ceil_()
_y.ceil_()
z_int = _z.ceil()
_x = torch.clip(_x, 1, resolution - 2)
_y = torch.clip(_y, 1, resolution - 2)
_z = torch.clip(_z, 1, depth - 2)
coordinates = z_int * resolution * resolution + _y * resolution + _x
grid = (
torch.ones([batch, depth, resolution, resolution], device=points.device).view(
batch, -1
)
* params["bg_clr"]
)
grid = scatter(_z, coordinates.long(), dim=1, out=grid, reduce="max")
grid = grid.reshape((batch, depth, resolution, resolution)).permute((0, 1, 3, 2))
return grid
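# Illustrative sketch (fake data, a hypothetical helper): voxelize one cloud.
# Each cell holds the maximum clipped normalized depth of the points that land
# in it, which Grid2Image later squeezes along the depth axis.
def _example_points2grid():
    pts = torch.rand(1, 2048, 3) * 2 - 1
    grid = points2grid(pts)  # [1, depth, resolution, resolution]
    print(grid.shape, (grid > 0).sum().item())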
def points_to_occupancy_grid(
points, resolution=params["resolution"], depth=params["depth"]
):
"""Quantize each point cloud into a 3D occupancy grid."""
batch, pnum, _ = points.shape
device = points.device # Get device to create new tensors
# --- Normalization and coordinate mapping remain unchanged ---
pmax, pmin = points.max(dim=1)[0], points.min(dim=1)[0]
pcent = (pmax + pmin) / 2
pcent = pcent[:, None, :]
prange = (pmax - pmin).max(dim=-1)[0][
:, None, None
] + 1e-8 # Add epsilon to avoid division by zero
points_norm = (points - pcent) / prange * 2.0
points_norm[:, :, :2] = points_norm[:, :, :2] * params["obj_ratio"]
depth_bias = params["depth_bias"]
_x = (points_norm[:, :, 0] + 1) / 2 * resolution
_y = (points_norm[:, :, 1] + 1) / 2 * resolution
_z = ((points_norm[:, :, 2] + 1) / 2 + depth_bias) / (1 + depth_bias) * (depth - 2)
_x.ceil_()
_y.ceil_()
z_int = _z.ceil()
_x = torch.clip(_x, 1, resolution - 2)
_y = torch.clip(_y, 1, resolution - 2)
# z_int should also be clipped if used as coordinate indices
z_int = torch.clip(z_int, 1, depth - 2)
    # --- Compute flattened coordinates ---
    # Each point maps to a cell index within its own depth*resolution*resolution
    # volume; add a per-batch offset so different clouds do not collide once the
    # grid is flattened to a single dimension.
    coordinates = (z_int * resolution * resolution + _y * resolution + _x).long()
    batch_offsets = (
        torch.arange(batch, device=device).view(-1, 1) * depth * resolution * resolution
    )
    coordinates = (coordinates + batch_offsets).view(-1)  # [B * pnum]
    # --- Create Grid and Scatter ---
    # Initialize the flattened grid with the background value
    bg_clr_value = params.get("bg_clr", 0.0)
    grid = torch.full(
        (batch * depth * resolution * resolution,),
        bg_clr_value,
        dtype=torch.float32,
        device=device,
    )
    # One source value of 1.0 per point, matching the flattened coordinates
    values_to_scatter = torch.ones(batch * pnum, dtype=torch.float32, device=device)
    # reduce="max" marks every cell hit by at least one point with 1.0. The grid
    # is only strictly binary when bg_clr <= 1; otherwise initialize the grid
    # with 0 and write bg_clr in a post-processing step.
    if bg_clr_value != 0.0:
        print(
            "Warning: bg_clr is not 0.0; the occupancy grid may not be a strict "
            "0/1 grid with reduce='max'. Consider initializing the grid with 0."
        )
    grid = scatter(
        values_to_scatter,
        coordinates,
        dim=0,  # scatter along the flattened grid [B * D * R * R]
        out=grid,
        reduce="max",
    )  # occupied cells become 1.0, the rest keep bg_clr
    # --- Reshape back to [B, depth, resolution, resolution] and match points2grid's layout ---
    grid = grid.view(batch, depth, resolution, resolution)
    grid = grid.permute((0, 1, 3, 2))
return grid
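# Illustrative comparison (fake data, a hypothetical helper): unlike
# points2grid, which stores normalized depths, the occupancy variant stores a
# flat 1.0 in every occupied cell.
def _example_points_to_occupancy_grid():
    pts = torch.rand(2, 1024, 3) * 2 - 1
    occ = points_to_occupancy_grid(pts)  # [2, depth, resolution, resolution]
    print(occ.unique())  # tensor([0., 1.]) with the default bg_clr of 0.0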
class Realistic_Projection:
    """Render images of a point cloud from a set of predefined views."""
def __init__(self):
_views = np.asarray([
[[1 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
[[3 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
[[5 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
[[7 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
[[0 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
[[1 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
[[2 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
[[3 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
[[0, -np.pi / 2, np.pi / 2], [-0.5, -0.5, TRANS]],
[[0, np.pi / 2, np.pi / 2], [-0.5, -0.5, TRANS]],
])
# adding some bias to the view angle to reveal more surface
_views_bias = np.asarray([
[[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
[[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
[[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
[[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
[[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
[[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
[[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
[[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
[[0, np.pi / 15, 0], [-0.5, 0, TRANS]],
[[0, np.pi / 15, 0], [-0.5, 0, TRANS]],
])
self.num_views = _views.shape[0]
angle = torch.tensor(_views[:, 0, :]).float() # .cuda()
self.rot_mat = euler2mat(angle).transpose(1, 2)
angle2 = torch.tensor(_views_bias[:, 0, :]).float() # .cuda()
self.rot_mat2 = euler2mat(angle2).transpose(1, 2)
self.translation = torch.tensor(_views[:, 1, :]).float() # .cuda()
self.translation = self.translation.unsqueeze(1)
self.grid2image = Grid2Image() # .cuda()
    def get_img(self, points):
        b, _, _ = points.shape
        v = self.translation.shape[0]
        # Pair every cloud with every view: clouds are repeat-interleaved so
        # cloud i meets views 0..v-1 in order, matching the tiled view tensors.
        _points = self.point_transform(
            points=torch.repeat_interleave(points, v, dim=0),
            rot_mat=self.rot_mat.repeat(b, 1, 1),
            rot_mat2=self.rot_mat2.repeat(b, 1, 1),
            translation=self.translation.repeat(b, 1, 1),
        )
        grid = points2grid(
            points=_points, resolution=params["resolution"], depth=params["depth"]
        ).squeeze()  # [b * v, depth, resolution, resolution]
        img = self.grid2image(grid)
        return img
    @staticmethod
    def point_transform(points, rot_mat, rot_mat2, translation):
        """
        :param points: [batch, num_points, 3]
        :param rot_mat: [batch, 3, 3]
        :param rot_mat2: [batch, 3, 3]
        :param translation: [batch, 1, 3]
        :return: transformed points, [batch, num_points, 3]
        """
rot_mat = rot_mat.to(points.device)
rot_mat2 = rot_mat2.to(points.device)
translation = translation.to(points.device)
points = torch.matmul(points, rot_mat)
points = torch.matmul(points, rot_mat2)
points = points - translation
return points
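# End-to-end sketch (illustrative, with made-up input; a hypothetical helper):
# render the ten fixed views of a batch of clouds. For GPU use, restore the
# commented-out .cuda() calls in __init__ and move the input points accordingly.
def _example_realistic_projection():
    proj = Realistic_Projection()
    pts = torch.rand(2, 1024, 3) * 2 - 1
    imgs = proj.get_img(pts)  # [batch * num_views, 3, H, W]
    print(imgs.shape)  # torch.Size([20, 3, 120, 120])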
def get2DGaussianKernel(ksize, sigma=2):
    """Return a normalized ksize x ksize Gaussian kernel; sigma must be > 0."""
    center = ksize // 2
    xs = np.arange(ksize, dtype=np.float32) - center
    kernel1d = np.exp(-(xs**2) / (2 * sigma**2))
    kernel = kernel1d[..., None] @ kernel1d[None, ...]
    kernel = torch.from_numpy(kernel)
    kernel = kernel / kernel.sum()
    return kernel
# Without numpy
# def get2DGaussianKernel(ksize, sigma):
# xs = torch.linspace(-(ksize // 2), ksize // 2, steps=ksize)
# kernel1d = torch.exp(-(xs ** 2) / (2 * sigma ** 2))
# kernel2d = torch.outer(kernel1d, kernel1d)
# kernel2d /= kernel2d.sum()
# return kernel2d
def get3DGaussianKernel(ksize, depth, sigma=2, zsigma=2):
    """Return a normalized [depth, ksize, ksize] 3D Gaussian kernel."""
    kernel2d = get2DGaussianKernel(ksize, sigma)
    zs = torch.arange(depth, dtype=torch.float32) - depth // 2
    zkernel = torch.exp(-(zs**2) / (2 * zsigma**2))
    # Outer product of the 2D kernel with the depth profile, then normalize
    kernel3d = kernel2d[None, :, :] * zkernel[:, None, None]
    kernel3d = kernel3d / kernel3d.sum()
    return kernel3d
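# Sanity check (illustrative only, a hypothetical helper): both kernels are
# normalized to sum to 1, and the 3D kernel has shape [depth, ksize, ksize].
def _example_gaussian_kernels():
    k2 = get2DGaussianKernel(params["convxy"], sigma=params["convsigmaxy"])
    k3 = get3DGaussianKernel(
        params["convxy"], params["convz"],
        sigma=params["convsigmaxy"], zsigma=params["convsigmaz"],
    )
    print(k2.sum().item(), k3.shape, k3.sum().item())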