import cv2
import numpy as np
import torch
import utils3d
from PIL import Image

from moge.model.v1 import MoGeModel
from moge.utils.panorama import (
    get_panorama_cameras,
    split_panorama_image,
    merge_panorama_depth,
)
from .general_utils import spherical_uv_to_directions


# from https://github.com/lpiccinelli-eth/UniK3D/unik3d/utils/coordinate.py
def coords_grid(b, h, w):
    r"""
    Generate a grid of pixel coordinates in the range [0.5, W-0.5] and [0.5, H-0.5].
    Args:
        b (int): Batch size.
        h (int): Height of the grid.
        w (int): Width of the grid.
    Returns:
        grid (torch.Tensor): A tensor of shape [B, 2, H, W] containing the pixel coordinates.
    """
    # Create pixel coordinates in the range [0.5, W-0.5] and [0.5, H-0.5]
    pixel_coords_x = torch.linspace(0.5, w - 0.5, w)
    pixel_coords_y = torch.linspace(0.5, h - 0.5, h)

    # Stack the pixel coordinates to create a grid
    stacks = [pixel_coords_x.repeat(h, 1), pixel_coords_y.repeat(w, 1).t()]

    grid = torch.stack(stacks, dim=0).float()  # [2, H, W]
    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W]

    return grid
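
# A minimal shape/value sketch for coords_grid (illustrative comment only;
# the calls below are hypothetical and not executed by this module):
#
#   grid = coords_grid(b=2, h=4, w=6)
#   grid.shape        # torch.Size([2, 2, 4, 6])  -> [B, 2, H, W]
#   grid[0, 0, 0, 0]  # tensor(0.5000), x of the top-left pixel center
#   grid[0, 1, 3, 0]  # tensor(3.5000), y of the bottom-left pixel center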


def build_depth_model(device: str | torch.device = "cuda"):
    r"""
    Build the MoGe depth model for panorama depth prediction.
    Args:
        device (str | torch.device): The device to load the model onto (e.g., "cuda" or "cpu").
    Returns:
        model (MoGeModel): The MoGe depth model instance.
    """
    # Load model from pretrained weights
    model = MoGeModel.from_pretrained("Ruicheng/moge-vitl")
    model.eval()
    model = model.to(device)
    return model
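
# Typical usage -- a minimal sketch (assumes a CUDA device is available and
# that the "Ruicheng/moge-vitl" weights can be downloaded on first call):
#
#   model = build_depth_model("cuda")  # returns the model in eval mode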


def smooth_south_pole_depth(
    depth_map, smooth_height_ratio=0.03, lower_quantile=0.1, upper_quantile=0.9
):
    """
    Smooth depth values at the south pole (bottom) of a panorama to address inconsistencies.
    Args:
        depth_map (np.ndarray): Input depth map, shape (H, W).
        smooth_height_ratio (float): Ratio of the height to smooth, typically a small value like 0.03.
        lower_quantile (float): The lower quantile for outlier filtering.
        upper_quantile (float): The upper quantile for outlier filtering.
    Returns:
        np.ndarray: Smoothed depth map.
    """
    height, width = depth_map.shape
    smooth_height = int(height * smooth_height_ratio)

    if smooth_height == 0:
        return depth_map

    # Create copy to avoid modifying original
    smoothed_depth = depth_map.copy()

    # Calculate the reference depth from the bottom rows:
    # if the smoothing region spans more than 3 rows, average the last 3 rows;
    # otherwise, use only the bottom row
    if smooth_height > 3:
        # Calculate the average depth using the last 3 rows
        reference_rows = depth_map[-3:, :]
        reference_data = reference_rows.flatten()
    else:
        # Use the bottom row
        reference_data = depth_map[-1, :]

    # Filter out invalid entries: keep only finite, strictly positive depths
    valid_mask = np.isfinite(reference_data) & (reference_data > 0)

    if np.any(valid_mask):
        valid_depths = reference_data[valid_mask]

        # Use quantiles to filter extreme outliers.
        lower_bound, upper_bound = np.quantile(valid_depths, [lower_quantile, upper_quantile])

        # Further filter out depth values that are too large or too small
        depth_filter_mask = (valid_depths >= lower_bound) & (
            valid_depths <= upper_bound
        )

        if np.any(depth_filter_mask):
            avg_depth = np.mean(valid_depths[depth_filter_mask])
        else:
            # If all values are filtered out, use the median as an alternative
            avg_depth = np.median(valid_depths)
    else:
        avg_depth = np.nanmean(reference_data)

    # Set the bottom row as the average value
    smoothed_depth[-1, :] = avg_depth

    # Smooth upwards to the specified height
    for i in range(1, smooth_height):
        y_idx = height - 1 - i  # Index from bottom to top
        if y_idx < 0:
            break

        # Blend weight: rows nearer the bottom are pulled more strongly
        # toward the row average
        weight = (smooth_height - i) / smooth_height

        # Smooth the current row
        current_row = depth_map[y_idx, :]
        valid_mask = np.isfinite(current_row) & (current_row > 0)

        if np.any(valid_mask):
            valid_row_depths = current_row[valid_mask]

            # Apply outlier filtering to the current row as well
            if len(valid_row_depths) > 1:
                q25, q75 = np.quantile(valid_row_depths, [0.25, 0.75])
                iqr = q75 - q25
                lower_bound = q25 - 1.5 * iqr
                upper_bound = q75 + 1.5 * iqr
                depth_filter_mask = (valid_row_depths >= lower_bound) & (
                    valid_row_depths <= upper_bound
                )

                if np.any(depth_filter_mask):
                    row_avg = np.mean(valid_row_depths[depth_filter_mask])
                else:
                    row_avg = np.median(valid_row_depths)
            else:
                row_avg = (
                    valid_row_depths[0] if len(valid_row_depths) > 0 else avg_depth
                )

            # Linear interpolation: between the original depth and the average depth
            smoothed_depth[y_idx, :] = (1 - weight) * current_row + weight * row_avg

    return smoothed_depth
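
# A minimal usage sketch (hypothetical values, illustrative comment only):
#
#   depth = np.random.uniform(1.0, 5.0, size=(512, 1024)).astype(np.float32)
#   depth = smooth_south_pole_depth(depth, smooth_height_ratio=0.05)
#   # The bottom row is now constant, and rows in the bottom ~5% of the image
#   # are progressively blended toward their per-row averages.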


def pred_pano_depth(
    model,
    image: Image.Image,
    img_name: str,
    scale=1.0,
    resize_to=1920,
    remove_pano_depth_nan=True,
    last_layer_mask=None,
    last_layer_depth=None,
    verbose=False,
) -> dict:
    r"""
    Predict panorama depth using the MoGe model.
    Args:
        model (MoGeModel): The MoGe depth model instance.
        image (Image.Image): Input panorama image.
        img_name (str): Name of the image; "background" and "full_img" additionally receive south-pole smoothing.
        scale (float): Scale factor applied to the predicted distance map.
        resize_to (int): Target size (long side) for resizing the image before prediction.
        remove_pano_depth_nan (bool): Whether to remove NaN values from the predicted depth.
        last_layer_mask (np.ndarray, optional): Mask from the last layer for inpainting.
        last_layer_depth (dict, optional): Last layer depth information containing distance maps and masks.
        verbose (bool): Whether to print verbose information.
    Returns:
        dict: A dictionary containing the predicted depth maps and masks.
    """
    if verbose:
        print("\t - Predicting pano depth with moge")

    # Process input image
    image_origin = np.array(image)
    height_origin, width_origin = image_origin.shape[:2]
    image, height, width = image_origin, height_origin, width_origin

    # Resize if needed
    if resize_to is not None:
        _height, _width = min(
            resize_to, int(resize_to * height_origin / width_origin)
        ), min(resize_to, int(resize_to * width_origin / height_origin))
        if _height < height_origin:
            if verbose:
                print(
                    f"\t - Resizing image from {width_origin}x{height_origin} "
                    f"to {_width}x{_height} for pano depth prediction"
                )
            image = cv2.resize(
                image_origin, (_width, _height), interpolation=cv2.INTER_AREA
            )
            height, width = _height, _width

    # Split panorama into multiple views
    splitted_extrinsics, splitted_intrinsics = get_panorama_cameras()
    splitted_resolution = 512
    splitted_images = split_panorama_image(
        image, splitted_extrinsics, splitted_intrinsics, splitted_resolution
    )
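    # Each element of splitted_images is a perspective view of the panorama,
    # one per camera from get_panorama_cameras(), rendered at
    # splitted_resolution pixels per side (square views assumed here).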

    # Handle inpainting masks if provided
    splitted_inpaint_masks = None
    if last_layer_mask is not None and last_layer_depth is not None:
        splitted_inpaint_masks = split_panorama_image(
            last_layer_mask,
            splitted_extrinsics,
            splitted_intrinsics,
            splitted_resolution,
        )

    # Infer MoGe depth for each split view
    num_splitted_images = len(splitted_images)
    splitted_distance_maps = [None] * num_splitted_images
    splitted_masks = [None] * num_splitted_images

    indices_to_process_model = []
    skipped_count = 0

    # Determine which images need processing
    for i in range(num_splitted_images):
        if splitted_inpaint_masks is not None and splitted_inpaint_masks[i].sum() == 0:
            # Use depth from the previous layer for non-inpainted (masked) regions
            splitted_distance_maps[i] = last_layer_depth["splitted_distance_maps"][i]
            splitted_masks[i] = last_layer_depth["splitted_masks"][i]
            skipped_count += 1
        else:
            indices_to_process_model.append(i)

    pred_count = 0
    # Process images that require model inference in batches
    inference_batch_size = 1
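    # A batch size of 1 keeps peak GPU memory low; larger values should also
    # work if memory allows (untested assumption).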
    for i in range(0, len(indices_to_process_model), inference_batch_size):
        batch_indices = indices_to_process_model[i : i + inference_batch_size]
        if not batch_indices:
            continue
        # Prepare batch
        current_batch_images = [splitted_images[k] for k in batch_indices]
        current_batch_intrinsics = [splitted_intrinsics[k] for k in batch_indices]
        # Convert to tensor and normalize
        image_tensor = torch.tensor(
            np.stack(current_batch_images) / 255,
            dtype=torch.float32,
            device=next(model.parameters()).device,
        ).permute(0, 3, 1, 2)
        # Calculate field of view
        fov_x, _ = np.rad2deg(  # fov_y is not used by model.infer
            utils3d.numpy.intrinsics_to_fov(np.array(current_batch_intrinsics))
        )
        fov_x_tensor = torch.tensor(
            fov_x, dtype=torch.float32, device=next(model.parameters()).device
        )
        # Run inference
        output = model.infer(image_tensor, fov_x=fov_x_tensor, apply_mask=False)
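        # output["points"] holds per-pixel 3D points in camera space; their
        # Euclidean norm below is the per-view ray distance (depth along the ray)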

        batch_distance_maps = output["points"].norm(dim=-1).cpu().numpy()
        batch_masks = output["mask"].cpu().numpy()
        # Store results
        for batch_idx, original_idx in enumerate(batch_indices):
            splitted_distance_maps[original_idx] = batch_distance_maps[batch_idx]
            splitted_masks[original_idx] = batch_masks[batch_idx]
            pred_count += 1

    if verbose:
        # Print processing statistics
        if (pred_count + skipped_count) == 0:
            # Avoid division by zero when there are no images to process
            skip_ratio_info = "N/A (no images to process)"
        else:
            skip_ratio_info = f"{skipped_count / (pred_count + skipped_count):.2%}"
        print(
            f"\t 🔍 Predicted {pred_count} splitted images, "
            f"skipped {skipped_count} splitted images. Skip ratio: {skip_ratio_info}"
        )

    # Merge per-view MoGe depths back into a single panorama depth map
    merging_width, merging_height = width, height
    panorama_depth, panorama_mask = merge_panorama_depth(
        merging_width,
        merging_height,
        splitted_distance_maps,
        splitted_masks,
        splitted_extrinsics,
        splitted_intrinsics,
    )
    # Post-process depth map
    panorama_depth = panorama_depth.astype(np.float32)
    if remove_pano_depth_nan:
        # For depth inpainting, replace invalid (unmasked) regions, e.g. sky,
        # with a near-maximum depth instead of NaN
        panorama_depth[~panorama_mask] = np.nanquantile(panorama_depth, 0.999)
    panorama_depth = cv2.resize(
        panorama_depth, (width_origin, height_origin), interpolation=cv2.INTER_LINEAR
    )
    panorama_mask = (
        cv2.resize(
            panorama_mask.astype(np.uint8),
            (width_origin, height_origin),
            interpolation=cv2.INTER_NEAREST,
        )
        > 0
    )

    # Smooth the south-pole (bottom) depth to remove left/right seam inconsistency
    if img_name in ["background", "full_img"]:
        if verbose:
            print("\t - Smoothing south pole depth for consistency")
        panorama_depth = smooth_south_pole_depth(
            panorama_depth, smooth_height_ratio=0.05
        )

    rays = torch.from_numpy(
        spherical_uv_to_directions(
            utils3d.numpy.image_uv(width=width_origin, height=height_origin)
        )
    ).to(next(model.parameters()).device)

    panorama_depth = (
        torch.from_numpy(panorama_depth).to(next(model.parameters()).device) * scale
    )

    return {
        "type": "",
        "rgb": torch.from_numpy(image_origin).to(next(model.parameters()).device),
        "distance": panorama_depth,
        "rays": rays,
        "mask": panorama_mask,
        "splitted_masks": splitted_masks,
        "splitted_distance_maps": splitted_distance_maps,
    }
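

if __name__ == "__main__":
    # Minimal end-to-end sketch (illustrative only): the image path below is
    # hypothetical, and the first call downloads the MoGe weights.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_depth_model(device)
    pano = Image.open("example_panorama.png").convert("RGB")  # hypothetical path
    with torch.no_grad():
        result = pred_pano_depth(model, pano, img_name="full_img", verbose=True)
    print("distance map shape:", tuple(result["distance"].shape))
    print("valid pixels:", int(result["mask"].sum()))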