# %%writefile mv_utils_zs.py
"""
Author: yangyangyang127
Github: https://github.com/yangyangyang127
Repo: https://github.com/yangyangyang127/PointCLIP_V2
Path: https://github.com/yangyangyang127/PointCLIP_V2/blob/main/zeroshot_cls/trainers/mv_utils_zs.py#L135
"""

import numpy as np
import torch
import torch.nn as nn
from torch_scatter import scatter

TRANS = -1.5

# realistic projection parameters
params = {
    "maxpoolz": 1,
    "maxpoolxy": 7,
    "maxpoolpadz": 0,
    "maxpoolpadxy": 2,
    "convz": 1,
    "convxy": 3,
    "convsigmaxy": 3,
    "convsigmaz": 1,
    "convpadz": 0,
    "convpadxy": 1,
    "imgbias": 0.0,
    "depth_bias": 0.2,
    "obj_ratio": 0.8,
    "bg_clr": 0.0,
    "resolution": 122,
    "depth": 8,  # default = 8
    "grid_height": 64,
    "grid_width": 64,
}


class Grid2Image(nn.Module):
    """A pytorch implementation to turn 3D grid to 2D image.
    Maxpool: densifying the grid
    Convolution: smoothing via Gaussian
    Maximize: squeezing the depth channel
    """

    def __init__(self):
        super().__init__()
        torch.backends.cudnn.benchmark = False

        self.maxpool = nn.MaxPool3d(
            (params["maxpoolz"], params["maxpoolxy"], params["maxpoolxy"]),
            stride=1,
            padding=(
                params["maxpoolpadz"],
                params["maxpoolpadxy"],
                params["maxpoolpadxy"],
            ),
        )
        self.conv = torch.nn.Conv3d(
            1,
            1,
            kernel_size=(params["convz"], params["convxy"], params["convxy"]),
            stride=1,
            padding=(params["convpadz"], params["convpadxy"], params["convpadxy"]),
            bias=True,
        )
        kn3d = get3DGaussianKernel(
            params["convxy"],
            params["convz"],
            sigma=params["convsigmaxy"],
            zsigma=params["convsigmaz"],
        )
        self.conv.weight.data = torch.Tensor(kn3d).repeat(1, 1, 1, 1, 1)
        self.conv.bias.data.fill_(0)  # type: ignore

    def forward(self, x):
        x = self.maxpool(x.unsqueeze(1))
        x = self.conv(x)
        img = torch.max(x, dim=2)[0]
        img = img / torch.max(torch.max(img, dim=-1)[0], dim=-1)[0][:, :, None, None]
        img = 1 - img
        img = img.repeat(1, 3, 1, 1)
        return img
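

# Usage sketch (illustrative, not part of the original file): Grid2Image
# expects a voxel grid shaped [B, depth, resolution, resolution], e.g. the
# output of points2grid below. Note that the 7x7 max-pool (padding 2,
# stride 1) trims 2 px from each spatial dimension.
#
# g2i = Grid2Image()
# voxels = torch.rand(4, params["depth"], params["resolution"], params["resolution"])
# imgs = g2i(voxels)  # -> [4, 3, resolution - 2, resolution - 2], white background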


def euler2mat(angle):
    """Convert euler angles to rotation matrix.
     :param angle: [3] or [b, 3]
     :return:
        rotmat: [3, 3] or [b, 3, 3]
    source
    https://github.com/ClementPinard/SfmLearner-Pytorch/blob/master/inverse_warp.py
    """
    if len(angle.size()) == 1:
        x, y, z = angle[0], angle[1], angle[2]
        _dim = 0
        _view = [3, 3]
    elif len(angle.size()) == 2:
        b, _ = angle.size()
        x, y, z = angle[:, 0], angle[:, 1], angle[:, 2]
        _dim = 1
        _view = [b, 3, 3]

    else:
        assert False

    cosz = torch.cos(z)
    sinz = torch.sin(z)

    # zero = torch.zeros([b], requires_grad=False, device=angle.device)[0]
    # one = torch.ones([b], requires_grad=False, device=angle.device)[0]
    zero = z.detach() * 0
    one = zero.detach() + 1
    zmat = torch.stack(
        [cosz, -sinz, zero, sinz, cosz, zero, zero, zero, one], dim=_dim
    ).reshape(_view)

    cosy = torch.cos(y)
    siny = torch.sin(y)

    ymat = torch.stack(
        [cosy, zero, siny, zero, one, zero, -siny, zero, cosy], dim=_dim
    ).reshape(_view)

    cosx = torch.cos(x)
    sinx = torch.sin(x)

    xmat = torch.stack(
        [one, zero, zero, zero, cosx, -sinx, zero, sinx, cosx], dim=_dim
    ).reshape(_view)

    rot_mat = xmat @ ymat @ zmat
    # print(rot_mat)
    return rot_mat
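

# Quick check (illustrative, assuming the xmat @ ymat @ zmat convention
# above): a rotation of pi/2 about z maps the x-axis to the y-axis.
#
# R = euler2mat(torch.tensor([0.0, 0.0, np.pi / 2]))
# R @ torch.tensor([1.0, 0.0, 0.0])  # ~ [0, 1, 0]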


def points_to_2d_grid(
    points, grid_h=params["grid_height"], grid_w=params["grid_width"]
):
    """
    Converts a point cloud into a 2D grid based on X, Y coordinates.
    Points are projected onto a plane and quantized into grid cells.

    Args:
        points (torch.tensor): Tensor containing points, shape [B, P, 3]
                               (B: batch size, P: number of points, 3: x, y, z coordinates)
        grid_h (int): Height of the output 2D grid.
        grid_w (int): Width of the output 2D grid.

    Returns:
        grid (torch.tensor): 2D grid representing the occupancy of points,
                             shape [B, grid_h, grid_w].
                             Value 1.0 at cell (y, x) if at least one point falls into it,
                             otherwise the background value (params["bg_clr"]).
    """
    batch, pnum, _ = points.shape
    device = points.device

    # --- Step 1: Normalize point coordinates ---
    # Find min/max for each point cloud in the batch (considering only X, Y for better 2D normalization)
    pmax_xy = points[:, :, :2].max(dim=1)[0]
    pmin_xy = points[:, :, :2].min(dim=1)[0]

    # Compute the center and range based on X, Y
    pcent_xy = (pmax_xy + pmin_xy) / 2
    pcent_xy = pcent_xy[:, None, :]  # Add P dimension for broadcasting [B, 1, 2]

    # Use the larger range between X and Y to maintain aspect ratio
    prange_xy = (pmax_xy - pmin_xy).max(dim=-1)[0][:, None, None]  # [B, 1, 1]

    # Add a small epsilon to avoid division by zero if all points overlap
    epsilon = 1e-8
    # Normalize X, Y into the range [-1, 1] based on the X, Y range
    points_normalized_xy = (points[:, :, :2] - pcent_xy) / (prange_xy + epsilon) * 2.0

    # Adjust the scale according to obj_ratio (if needed)
    points_normalized_xy = points_normalized_xy * params["obj_ratio"]

    # --- Step 2: Map normalized coordinates to 2D grid indices ---
    # Map X from the range [-obj_ratio, obj_ratio] -> [0, grid_w]
    # Map Y from the range [-obj_ratio, obj_ratio] -> [0, grid_h]
    # General formula: (normalized_coord + scale) / (2 * scale) * grid_dim
    _x = (
        (points_normalized_xy[:, :, 0] + params["obj_ratio"])
        / (2 * params["obj_ratio"])
        * grid_w
    )
    _y = (
        (points_normalized_xy[:, :, 1] + params["obj_ratio"])
        / (2 * params["obj_ratio"])
        * grid_h
    )

    # Round down to determine the grid cell indices
    _x = torch.floor(_x).long()
    _y = torch.floor(_y).long()

    # --- Step 3: Clamp indices to valid grid range ---
    # Clip _x to [0, grid_w - 1]
    # Clip _y to [0, grid_h - 1]
    _x = torch.clip(_x, 0, grid_w - 1)
    _y = torch.clip(_y, 0, grid_h - 1)

    # --- Step 4: Create a 2D grid and mark occupied cells ---
    # Initialize the 2D grid with the background value
    grid = torch.full(
        (batch, grid_h, grid_w), params["bg_clr"], dtype=torch.float32, device=device
    )

    # Create batch indices corresponding to each point
    batch_indices = torch.arange(batch, device=device).view(-1, 1).repeat(1, pnum)

    # Flatten indices for easier assignment
    batch_idx_flat = batch_indices.view(-1)
    y_idx_flat = _y.view(-1)
    x_idx_flat = _x.view(-1)

    # Assign a value of 1.0 to grid cells (y, x) corresponding to point positions
    # If multiple points fall into the same cell, the cell still has a value of 1.0
    grid[batch_idx_flat, y_idx_flat, x_idx_flat] = 1.0

    return grid
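

# Usage sketch (illustrative, not part of the original file): project two
# random clouds onto the default 64x64 grid; occupied cells read 1.0 and
# empty cells keep params["bg_clr"].
#
# pts = torch.rand(2, 1024, 3) * 2 - 1  # two clouds in [-1, 1]^3
# bev = points_to_2d_grid(pts)  # -> [2, 64, 64]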


def points2grid(points, resolution=params["resolution"], depth=params["depth"]):
    """Quantize each point cloud to a 3D grid.
    Args:
        points (torch.tensor): of size [B, P, 3]
    Returns:
        grid (torch.tensor): of size [B, depth, resolution, resolution].
                             When called from Realistic_Projection.get_img,
                             B already includes the num_views factor.
    """

    batch, pnum, _ = points.shape

    pmax, pmin = points.max(dim=1)[0], points.min(dim=1)[0]
    pcent = (pmax + pmin) / 2
    pcent = pcent[:, None, :]
    prange = (pmax - pmin).max(dim=-1)[0][:, None, None]
    points = (points - pcent) / prange * 2.0
    points[:, :, :2] = points[:, :, :2] * params["obj_ratio"]

    depth_bias = params["depth_bias"]
    _x = (points[:, :, 0] + 1) / 2 * resolution
    _y = (points[:, :, 1] + 1) / 2 * resolution
    _z = ((points[:, :, 2] + 1) / 2 + depth_bias) / (1 + depth_bias) * (depth - 2)

    _x.ceil_()
    _y.ceil_()
    z_int = _z.ceil()

    _x = torch.clip(_x, 1, resolution - 2)
    _y = torch.clip(_y, 1, resolution - 2)
    _z = torch.clip(_z, 1, depth - 2)

    coordinates = z_int * resolution * resolution + _y * resolution + _x
    grid = (
        torch.ones([batch, depth, resolution, resolution], device=points.device).view(
            batch, -1
        )
        * params["bg_clr"]
    )

    grid = scatter(_z, coordinates.long(), dim=1, out=grid, reduce="max")
    grid = grid.reshape((batch, depth, resolution, resolution)).permute((0, 1, 3, 2))

    return grid
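

# Usage sketch (illustrative, not part of the original file): voxelize a
# cloud at the default resolution. Each cell holds the largest clipped depth
# value among the points that landed in it, which Grid2Image later squeezes
# into a depth-shaded image.
#
# pts = torch.rand(2, 1024, 3) * 2 - 1
# vox = points2grid(pts)  # -> [2, depth, resolution, resolution]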


def points_to_occupancy_grid(
    points, resolution=params["resolution"], depth=params["depth"]
):
    """Quantize each point cloud into a 3D occupancy grid."""

    batch, pnum, _ = points.shape
    device = points.device  # Get device to create new tensors

    # --- Normalization and coordinate mapping remain unchanged ---
    pmax, pmin = points.max(dim=1)[0], points.min(dim=1)[0]
    pcent = (pmax + pmin) / 2
    pcent = pcent[:, None, :]
    prange = (pmax - pmin).max(dim=-1)[0][
        :, None, None
    ] + 1e-8  # Add epsilon to avoid division by zero
    points_norm = (points - pcent) / prange * 2.0
    points_norm[:, :, :2] = points_norm[:, :, :2] * params["obj_ratio"]

    depth_bias = params["depth_bias"]
    _x = (points_norm[:, :, 0] + 1) / 2 * resolution
    _y = (points_norm[:, :, 1] + 1) / 2 * resolution
    _z = ((points_norm[:, :, 2] + 1) / 2 + depth_bias) / (1 + depth_bias) * (depth - 2)

    _x.ceil_()
    _y.ceil_()
    z_int = _z.ceil()

    _x = torch.clip(_x, 1, resolution - 2)
    _y = torch.clip(_y, 1, resolution - 2)
    # z_int should also be clipped if used as coordinate indices
    z_int = torch.clip(z_int, 1, depth - 2)

    # --- Compute flattened coordinates ---
    coordinates = z_int * resolution * resolution + _y * resolution + _x
    coordinates = coordinates.long()  # Convert to long for indexing

    # Offset each batch element: the scatter below indexes the fully flattened
    # [B * depth * resolution * resolution] grid along dim 0, so without this
    # offset every batch would write into batch 0's region.
    cells_per_batch = depth * resolution * resolution
    coordinates = (
        coordinates + torch.arange(batch, device=device).view(-1, 1) * cells_per_batch
    )

    # --- Create Grid and Scatter ---
    # Initialize the grid with the background value (e.g., 0)
    # Use torch.zeros instead of torch.ones and multiply by bg_clr
    bg_clr_value = params.get("bg_clr", 0.0)  # Get bg_clr, default is 0
    grid = torch.full(
        (batch, depth * resolution * resolution),
        bg_clr_value,
        dtype=torch.float32,  # Or appropriate dtype
        device=device,
    )

    # Create a source tensor (src) containing a value of 1.0 for each point
    # The size must match the flattened coordinates: [B * pnum]
    values_to_scatter = torch.ones(batch * pnum, dtype=torch.float32, device=device)

    # Scatter the value 1.0 into the grid at the positions in `coordinates`,
    # using reduce="max": a cell containing at least one point becomes
    # max(1.0, bg_clr), i.e. exactly 1.0 as long as bg_clr <= 1.0.
    # If bg_clr could exceed 1.0, initialize the grid with 0 instead (or
    # post-process after the scatter) to keep the grid strictly binary.
    if bg_clr_value != 0.0:
        print(
            "Warning: bg_clr is not 0.0, occupancy grid might not be strictly binary 0/1 with reduce='max'. Consider initializing grid with 0."
        )

    grid = scatter(
        values_to_scatter,
        coordinates.view(-1),  # Flatten coordinates to [B*pnum]
        dim=0,  # Scatter along dimension 0 of the flattened grid [B*D*R*R]
        out=grid.view(-1),  # Flatten grid to [B*D*R*R] for scatter along dim 0
        reduce="max",
    )  # If a point exists -> cell value is 1, otherwise bg_clr

    # --- Reshape and Permute remain unchanged ---
    # Reshape the grid back to the correct 3D + batch size
    # Note: scatter into a flattened grid requires careful reshaping
    grid = grid.view(batch, depth, resolution, resolution)  # Reshape back
    grid = grid.permute((0, 1, 3, 2))

    return grid
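

# Usage sketch (illustrative, not part of the original file): unlike
# points2grid, this variant stores a binary 1.0 per occupied cell rather
# than the point's depth value.
#
# pts = torch.rand(2, 1024, 3) * 2 - 1
# occ = points_to_occupancy_grid(pts)  # -> [2, depth, resolution, resolution]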


class Realistic_Projection:
    """For creating images from PC based on the view information."""

    def __init__(self):
        _views = np.asarray([
            [[1 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[3 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[5 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[7 * np.pi / 4, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[0 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[1 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[2 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[3 * np.pi / 2, 0, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[0, -np.pi / 2, np.pi / 2], [-0.5, -0.5, TRANS]],
            [[0, np.pi / 2, np.pi / 2], [-0.5, -0.5, TRANS]],
        ])

        # adding some bias to the view angle to reveal more surface
        _views_bias = np.asarray([
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 9, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 15, 0], [-0.5, 0, TRANS]],
            [[0, np.pi / 15, 0], [-0.5, 0, TRANS]],
        ])

        self.num_views = _views.shape[0]

        angle = torch.tensor(_views[:, 0, :]).float()  # .cuda()
        self.rot_mat = euler2mat(angle).transpose(1, 2)
        angle2 = torch.tensor(_views_bias[:, 0, :]).float()  # .cuda()
        self.rot_mat2 = euler2mat(angle2).transpose(1, 2)

        self.translation = torch.tensor(_views[:, 1, :]).float()  # .cuda()
        self.translation = self.translation.unsqueeze(1)

        self.grid2image = Grid2Image()  # .cuda()

    def get_img(self, points):
        b, _, _ = points.shape
        v = self.translation.shape[0]

        _points = self.point_transform(
            points=torch.repeat_interleave(points, v, dim=0),
            rot_mat=self.rot_mat.repeat(b, 1, 1),
            rot_mat2=self.rot_mat2.repeat(b, 1, 1),
            translation=self.translation.repeat(b, 1, 1),
        )

        grid = points2grid(
            points=_points, resolution=params["resolution"], depth=params["depth"]
        ).squeeze()
        img = self.grid2image(grid)
        return img

    @staticmethod
    def point_transform(points, rot_mat, rot_mat2, translation):
        """
        :param points: [batch, num_points, 3]
        :param rot_mat: [batch, 3, 3]
        :param rot_mat2: [batch, 3, 3]
        :param translation: [batch, 1, 3]
        :return:
        """
        rot_mat = rot_mat.to(points.device)
        rot_mat2 = rot_mat2.to(points.device)
        translation = translation.to(points.device)
        points = torch.matmul(points, rot_mat)
        points = torch.matmul(points, rot_mat2)
        points = points - translation
        return points
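

# End-to-end sketch (illustrative, not part of the original file): render
# the 10 predefined views for a batch of point clouds. The view images are
# stacked along the batch dimension.
#
# proj = Realistic_Projection()
# pts = torch.rand(2, 1024, 3) * 2 - 1
# imgs = proj.get_img(pts)  # -> [2 * 10, 3, H, W] with H = W = resolution - 2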


def get2DGaussianKernel(ksize, sigma=0):
    center = ksize // 2
    xs = np.arange(ksize, dtype=np.float32) - center
    kernel1d = np.exp(-(xs**2) / (2 * sigma**2))
    kernel = kernel1d[..., None] @ kernel1d[None, ...]
    kernel = torch.from_numpy(kernel)
    kernel = kernel / kernel.sum()
    return kernel


# Without numpy
# def get2DGaussianKernel(ksize, sigma):
#     xs = torch.linspace(-(ksize // 2), ksize // 2, steps=ksize)
#     kernel1d = torch.exp(-(xs ** 2) / (2 * sigma ** 2))
#     kernel2d = torch.outer(kernel1d, kernel1d)
#     kernel2d /= kernel2d.sum()
#     return kernel2d


def get3DGaussianKernel(ksize, depth, sigma=2, zsigma=2):
    # get2DGaussianKernel returns a torch tensor, so stay in torch here:
    # np.repeat would silently convert it to a numpy array, and torch.sum
    # would then fail on the resulting ndarray.
    kernel2d = get2DGaussianKernel(ksize, sigma)  # [ksize, ksize]
    zs = torch.arange(depth, dtype=torch.float32) - depth // 2
    zkernel = torch.exp(-(zs**2) / (2 * zsigma**2))
    kernel3d = kernel2d[None, :, :].repeat(depth, 1, 1) * zkernel[:, None, None]
    kernel3d = kernel3d / kernel3d.sum()
    return kernel3d
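

# Sanity check (illustrative, not part of the original file): the 3D kernel
# is normalized, so it should sum to 1 for the parameters used by Grid2Image.
#
# k = get3DGaussianKernel(params["convxy"], params["convz"],
#                         sigma=params["convsigmaxy"], zsigma=params["convsigmaz"])
# k.shape  # -> torch.Size([1, 3, 3]); k.sum() ~ 1.0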