import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch3d.ops import knn_points
from scipy.spatial.transform import Rotation as R, Slerp

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.utils.geometry import normalize_pointcloud, xy_grid


class MeshOutput:
    """Export predicted point maps as a colored, filtered point cloud plus per-view
    camera poses (recovered with PnP), and save everything for offline visualization."""

    def __init__(self, sam=False):
        self.sam = sam  # if True, mask out the sky with SAM + SegFormer before export

    def __call__(self, gt1, gt2, pred1, pred2, trajectory_pred, render_gt=None, **kw):
        # per-view point maps and confidence maps predicted by the network
        pts3d = pred2['pts3d']
        conf = pred2['conf']
        pts3d = pts3d.detach().cpu()
        B, N, H, W, _ = pts3d.shape
        # keep points above the per-view 10% confidence quantile
        thres = torch.quantile(conf.flatten(2, 3), 0.1, dim=-1)[0]
        masks_conf = conf > thres[None, :, None, None]
        masks_conf = masks_conf.cpu()
        
        images = [view['img_org'] for view in gt1 + gt2]
        shape = torch.stack([view['true_shape'] for view in gt1 + gt2], dim=1)
        images = torch.stack(images, 1).permute(0, 1, 3, 4, 2).detach().cpu().numpy()
        images = images / 2 + 0.5  # un-normalize from [-1, 1] to [0, 1]
        images = images.reshape(B, N, H, W, 3)
        outfile = os.path.join('./output/mesh', gt1[0]['instance'][0].split('/')[-1])
        os.makedirs(outfile, exist_ok=True)
        
        # estimate the focal length (batch size is assumed to be 1 from here on)
        images = images[0]
        pts3d = pts3d[0]
        masks_conf = masks_conf[0]
        xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0)  # homogeneous (x,y,1)
        pp = torch.tensor((W/2, H/2)).to(xy_over_z)
        pixels = xy_grid(W, H, device=xy_over_z.device).view(1, -1, 2) - pp.view(-1, 1, 2)  # B,HW,2
        u, v = pixels[:1].unbind(dim=-1)
        x, y, z = pts3d[:1].reshape(-1,3).unbind(dim=-1)
        fx_votes = (u * z) / x
        fy_votes = (v * z) / y
        # assume square pixels, hence same focal for X and Y
        f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
        focal = torch.nanmedian(f_votes, dim=-1).values
        focal = focal.item()
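        # Note: the focal votes follow from the pinhole model. With pixel offsets (u, v) measured
        # from the principal point, u = f * x / z and v = f * y / z, so each pixel votes
        # f = u * z / x (and f = v * z / y); the median vote keeps the estimate robust to outliers.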
        pts3d = pts3d.numpy()
        # use PNP to estimate camera poses
        pred_poses = []
        for i in range(pts3d.shape[0]):
            shape_input_each = shape[:, i]
            mesh_grid = xy_grid(shape_input_each[0,1], shape_input_each[0,0])
            cur_inlier = conf[0,i] > torch.quantile(conf[0,i], 0.6)
            cur_inlier = cur_inlier.detach().cpu().numpy()
            ransac_thres = 0.5
            confidence = 0.9999
            iterationsCount = 10_000
            cur_pts3d = pts3d[i]
            K = np.float32([(focal, 0, W/2), (0, focal, H/2), (0, 0, 1)])
            success, r_pose, t_pose, _ = cv2.solvePnPRansac(cur_pts3d[cur_inlier].astype(np.float64), mesh_grid[cur_inlier].astype(np.float64), K, None,
                                                            flags=cv2.SOLVEPNP_SQPNP,
                                                            iterationsCount=iterationsCount,
                                                            reprojectionError=1,
                                                            confidence=confidence)
            r_pose = cv2.Rodrigues(r_pose)[0]  # rotation vector -> 3x3 rotation matrix
            RT = np.r_[np.c_[r_pose, t_pose], [(0, 0, 0, 1)]]  # world-to-camera extrinsics
            cam2world = np.linalg.inv(RT)
            pred_poses.append(cam2world)
        pred_poses = np.stack(pred_poses, axis=0)
        pred_poses = torch.tensor(pred_poses)


        # use sam to segment the sky region
        if self.sam:
            from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
            from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
            from sam.scripts.segformer import segformer_segmentation as segformer_func
            from sam.scripts.configs.ade20k_id2label import CONFIG as CONFIG_ADE20K_ID2LABEL
            import pycocotools.mask as maskUtils

            os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
            sam = sam_model_registry["vit_h"](checkpoint='/data0/zsz/mast3recon/checkpoints/sam_vit_h_4b8939.pth').cuda()
            semantic_branch_processor = SegformerFeatureExtractor.from_pretrained(
                "nvidia/segformer-b5-finetuned-ade-640-640")
            semantic_branch_model = SegformerForSemanticSegmentation.from_pretrained(
                "nvidia/segformer-b5-finetuned-ade-640-640").to('cuda')
            id2label = CONFIG_ADE20K_ID2LABEL
            mask_branch_model = SamAutomaticMaskGenerator(
                model=sam,
                points_per_side=64,
                # Foggy driving (zero-shot evaluate) is more challenging than other dataset, so we use a larger points_per_side
                pred_iou_thresh=0.86,
                stability_score_thresh=0.92,
                crop_n_layers=1,
                crop_n_points_downscale_factor=2,
                min_mask_region_area=100,  # Requires open-cv to run post-processing
                output_mode='coco_rle',
            )
            class_masks = []
            for img in images:
                anns = {'annotations': mask_branch_model.generate(img)}
                class_ids = segformer_func((img * 255).astype(np.uint8), semantic_branch_processor, semantic_branch_model, 'cuda')
                semantc_mask = class_ids.clone()
                anns['annotations'] = sorted(anns['annotations'], key=lambda x: x['area'], reverse=True)
                bitmasks, class_names = [], []

                for ann in anns['annotations']:
                    valid_mask = torch.tensor(maskUtils.decode(ann['segmentation'])).bool()
                    # get the class ids of the valid pixels
                    propose_classes_ids = class_ids[valid_mask]
                    num_class_proposals = len(torch.unique(propose_classes_ids))
                    if num_class_proposals == 1:
                        semantc_mask[valid_mask] = propose_classes_ids[0]
                        ann['class_name'] = id2label['id2label'][str(propose_classes_ids[0].item())]
                        ann['class_proposals'] = id2label['id2label'][str(propose_classes_ids[0].item())]
                        class_names.append(ann['class_name'])
                        # bitmasks.append(maskUtils.decode(ann['segmentation']))
                        continue
                    top_1_propose_class_ids = torch.bincount(propose_classes_ids.flatten()).topk(1).indices
                    top_1_propose_class_names = [id2label['id2label'][str(class_id.item())] for class_id in top_1_propose_class_ids]

                    semantc_mask[valid_mask] = top_1_propose_class_ids
                    ann['class_name'] = top_1_propose_class_names[0]
                    ann['class_proposals'] = top_1_propose_class_names[0]
                    class_names.append(ann['class_name'])
                    # bitmasks.append(maskUtils.decode(ann['segmentation']))

                    del valid_mask
                    del propose_classes_ids
                    del num_class_proposals
                    del top_1_propose_class_ids
                    del top_1_propose_class_names
                
                sematic_class_in_img = torch.unique(semantc_mask)
                semantic_bitmasks, semantic_class_names = [], []

                # semantic prediction
                anns['semantic_mask'] = {}
                flag = False
                for i in range(len(sematic_class_in_img)):
                    class_name = id2label['id2label'][str(sematic_class_in_img[i].item())]
                    if class_name != 'sky':
                        continue
                    flag = True             
                    class_mask = semantc_mask == sematic_class_in_img[i]
                    # class_mask = class_mask.cpu().numpy().astype(np.uint8)
                    class_masks.append(class_mask)
                if not flag:
                    # no sky found in this image: append an all-False mask so shapes line up
                    class_mask = torch.zeros_like(semantc_mask) > 0
                    class_masks.append(class_mask)
            class_masks = torch.stack(class_masks, 0)
            class_masks = ~class_masks  # keep the non-sky pixels
        else:
            class_masks = torch.ones_like(masks_conf) > 0

        # use kNN to clean the point cloud: drop points whose mean distance to their
        # K nearest neighbours falls above the 95th percentile (likely floaters)
        K = 10
        points = torch.tensor(pts3d.reshape(1, -1, 3)).cuda()
        knn = knn_points(points, points, K=K)
        dists = knn.dists
        mean_dists = dists.mean(dim=-1)
        masks_dist = mean_dists < torch.quantile(mean_dists.reshape(-1), 0.95)
        masks_dist = masks_dist.detach().cpu().numpy()
        
        import nerfvis.scene as scene_vis
        scene_vis.set_opencv()
        # combine confidence, kNN-distance and (optional) sky masks, then filter points and colors
        masks_conf = (masks_conf.numpy()
                      & masks_dist.reshape(-1, H, W)
                      & class_masks.cpu().numpy().reshape(-1, H, W))
        filtered_points = pts3d[masks_conf].reshape(-1, 3)
        colors = images[masks_conf].reshape(-1, 3)
        scene_vis.add_points('points', filtered_points.reshape(-1, 3), vert_color=colors.reshape(-1, 3), size=1)
        scene_vis.add_images(
            "images/i",
            images,  # can be a list of paths too (requires joblib for that)
            r=pred_poses[:, :3, :3],
            t=pred_poses[:, :3, 3],
            focal_length=focal,
            z=0.1,
            with_camera_frustum=True,
        )
        np.savez(outfile + '/pred.npz', pts3d=pts3d, vert_color=images, poses=pred_poses.detach().cpu(),
                 intrinsic=focal, images=images, mask=masks_conf)
        print(f"saved {outfile}")
        # write a helper script that re-renders the saved predictions, then run it
        save_content = 'CUDA_VISIBLE_DEVICES=1 python visualizer/run_vis.py --result_npz {} --results_folder {}'.format(
            outfile + '/pred.npz', outfile)
        file_path = outfile + '/run_vis.sh'
        with open(file_path, 'a') as file:
            file.write(save_content + '\n')
        print(f"run {file_path} to visualize geometry and poses")
        os.system(save_content)
        scene_vis.display(port=8828)
        return None, None
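
# Example (hypothetical driver code; `gt1`/`gt2`/`pred1`/`pred2` are assumed to follow the
# DUSt3R-style view/prediction dictionaries consumed by MeshOutput.__call__ above):
#   exporter = MeshOutput(sam=False)  # sam=True additionally masks out the sky via SAM + SegFormer
#   exporter(gt1, gt2, pred1, pred2, trajectory_pred=None)
#   # -> writes ./output/mesh/<instance>/pred.npz and serves a nerfvis scene on port 8828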
    

def interpolate_pose(pose1, pose2, t):
    """
    Interpolate between two camera poses (4x4 matrices).

    :param pose1: First pose (4x4 matrix)
    :param pose2: Second pose (4x4 matrix)
    :param t: Interpolation factor, t in [0, 1]
    :return: Interpolated pose (4x4 matrix)
    """

    # Extract translation and rotation from both poses
    translation1 = pose1[:3, 3].detach().cpu().numpy()
    translation2 = pose2[:3, 3].detach().cpu().numpy()
    rotation1 = pose1[:3, :3].detach().cpu().numpy()
    rotation2 = pose2[:3, :3].detach().cpu().numpy()

    # Interpolate the translation (linear interpolation)
    interpolated_translation = (1 - t) * translation1 + t * translation2
    
    # Convert rotation matrices to quaternions
    quat1 = R.from_matrix(rotation1).as_quat()
    quat2 = R.from_matrix(rotation2).as_quat()

    # Slerp for rotation interpolation
    slerp = Slerp([0, 1], R.from_quat([quat1, quat2]))

    interpolated_rotation = slerp(t).as_matrix()
    # Combine the interpolated rotation and translation
    interpolated_pose = np.eye(4)
    interpolated_pose[:3, :3] = interpolated_rotation
    interpolated_pose[:3, 3] = interpolated_translation
    return interpolated_pose
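
# Example (hypothetical usage): sample a smooth 30-frame camera path between two poses
# (`pose_a` and `pose_b` are 4x4 torch tensors, as expected by interpolate_pose):
#   path = [interpolate_pose(pose_a, pose_b, t) for t in np.linspace(0.0, 1.0, 30)]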

def colorize(value, vmin=None, vmax=None, cmap='magma_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
    """Converts a depth map to a color image.

    Args:
        value (torch.Tensor, numpy.ndarray): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed.
        vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
        vmax (float, optional):  vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None.
        cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'.
        invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
        invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
        background_color (tuple[int], optional): 4-tuple RGBA color given to invalid pixels. Defaults to (128, 128, 128, 255).
        gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
        value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.

    Returns:
        numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu().numpy()

    value = value.squeeze()
    if invalid_mask is None:
        invalid_mask = value == invalid_val
    mask = np.logical_not(invalid_mask)

    # normalize
    vmin = np.percentile(value[mask],2) if vmin is None else vmin
    vmax = np.percentile(value[mask],85) if vmax is None else vmax
    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)  # vmin..vmax
    else:
        # Avoid 0-division
        value = value * 0.

    # grey out the invalid values
    value[invalid_mask] = np.nan
    import matplotlib
    cmapper = matplotlib.cm.get_cmap(cmap)
    if value_transform:
        value = value_transform(value)
        # value = value / value.max()
    value = cmapper(value, bytes=True)  # (nxmx4)

    img = value
    img[invalid_mask] = background_color

    if gamma_corrected:
        # gamma correction
        img = img / 255
        img = np.power(img, 2.2)
        img = img * 255
        img = img.astype(np.uint8)
    img = torch.from_numpy(img)/255.
    return img.permute(2,0,1)[:3]
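
# Example (hypothetical usage): turn a predicted depth map into a (3, H, W) float tensor in [0, 1]
# suitable for logging or saving as an image:
#   depth_vis = colorize(depth)  # depth: (H, W) torch.Tensor or numpy array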