import cv2 import numpy as np import torch import utils3d from PIL import Image from moge.model.v1 import MoGeModel from moge.utils.panorama import ( get_panorama_cameras, split_panorama_image, merge_panorama_depth, ) from .general_utils import spherical_uv_to_directions # from https://github.com/lpiccinelli-eth/UniK3D/unik3d/utils/coordinate.py def coords_grid(b, h, w): r""" Generate a grid of pixel coordinates in the range [0.5, W-0.5] and [0.5, H-0.5]. Args: b (int): Batch size. h (int): Height of the grid. w (int): Width of the grid. Returns: grid (torch.Tensor): A tensor of shape [B, 2, H, W] containing the pixel coordinates. """ # Create pixel coordinates in the range [0.5, W-0.5] and [0.5, H-0.5] pixel_coords_x = torch.linspace(0.5, w - 0.5, w) pixel_coords_y = torch.linspace(0.5, h - 0.5, h) # Stack the pixel coordinates to create a grid stacks = [pixel_coords_x.repeat(h, 1), pixel_coords_y.repeat(w, 1).t()] grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] return grid def build_depth_model(device: torch.device = "cuda"): r""" Build the MoGe depth model for panorama depth prediction. Args: device (torch.device): The device to load the model onto (e.g., "cuda" or "cpu"). Returns: model (MoGeModel): The MoGe depth model instance. """ # Load model from pretrained weights model = MoGeModel.from_pretrained("Ruicheng/moge-vitl") model.eval() model = model.to(device) return model def smooth_south_pole_depth(depth_map, smooth_height_ratio=0.03, lower_quantile=0.1, upper_quantile=0.9): """ Smooth depth values at the south pole (bottom) of a panorama to address inconsistencies. Args: depth_map (np.ndarray): Input depth map, shape (H, W). smooth_height_ratio (float): Ratio of the height to smooth, typically a small value like 0.03. lower_quantile (float): The lower quantile for outlier filtering. upper_quantile (float): The upper quantile for outlier filtering. Returns: np.ndarray: Smoothed depth map. """ height, width = depth_map.shape smooth_height = int(height * smooth_height_ratio) if smooth_height == 0: return depth_map # Create copy to avoid modifying original smoothed_depth = depth_map.copy() # Calculate reference depth from bottom rows: # When the number of rows is greater than 3, use the last 3 rows; otherwise, use the bottom row if smooth_height > 3: # Calculate the average depth using the last 3 rows reference_rows = depth_map[-3:, :] reference_data = reference_rows.flatten() else: # Use the bottom row reference_data = depth_map[-1, :] # Filter outliers: including invalid values, depth that is too large or too small valid_mask = np.isfinite(reference_data) & (reference_data > 0) if np.any(valid_mask): valid_depths = reference_data[valid_mask] # Use quantiles to filter extreme outliers. lower_bound, upper_bound = np.quantile(valid_depths, [lower_quantile, upper_quantile]) # Further filter out depth values that are too large or too small depth_filter_mask = (valid_depths >= lower_bound) & ( valid_depths <= upper_bound ) if np.any(depth_filter_mask): avg_depth = np.mean(valid_depths[depth_filter_mask]) else: # If all values are filtered out, use the median as an alternative avg_depth = np.median(valid_depths) else: avg_depth = np.nanmean(reference_data) # Set the bottom row as the average value smoothed_depth[-1, :] = avg_depth # Smooth upwards to the specified height for i in range(1, smooth_height): y_idx = height - 1 - i # Index from bottom to top if y_idx < 0: break # Calculate smoothness weight: The closer to the bottom, the stronger the smoothness weight = (smooth_height - i) / smooth_height # Smooth the current row current_row = depth_map[y_idx, :] valid_mask = np.isfinite(current_row) & (current_row > 0) if np.any(valid_mask): valid_row_depths = current_row[valid_mask] # Apply outlier filtering to the current row as well if len(valid_row_depths) > 1: q25, q75 = np.quantile(valid_row_depths, [0.25, 0.75]) iqr = q75 - q25 lower_bound = q25 - 1.5 * iqr upper_bound = q75 + 1.5 * iqr depth_filter_mask = (valid_row_depths >= lower_bound) & ( valid_row_depths <= upper_bound ) if np.any(depth_filter_mask): row_avg = np.mean(valid_row_depths[depth_filter_mask]) else: row_avg = np.median(valid_row_depths) else: row_avg = ( valid_row_depths[0] if len(valid_row_depths) > 0 else avg_depth ) # Linear interpolation: between the original depth and the average depth smoothed_depth[y_idx, :] = (1 - weight) * current_row + weight * row_avg return smoothed_depth def pred_pano_depth( model, image: Image.Image, img_name: str, scale=1.0, resize_to=1920, remove_pano_depth_nan=True, last_layer_mask=None, last_layer_depth=None, verbose=False, ) -> dict: r""" Predict panorama depth using the MoGe model. Args: model (MoGeModel): The MoGe depth model instance. image (Image.Image): Input panorama image. img_name (str): Name of the image for saving outputs. scale (float): Scale factor for resizing the image. resize_to (int): Target size for resizing the image. remove_pano_depth_nan (bool): Whether to remove NaN values from the predicted depth. last_layer_mask (np.ndarray, optional): Mask from the last layer for inpainting. last_layer_depth (dict, optional): Last layer depth information containing distance maps and masks. verbose (bool): Whether to print verbose information. Returns: dict: A dictionary containing the predicted depth maps and masks. """ if verbose: print("\t - Predicting pano depth with moge") # Process input image image_origin = np.array(image) height_origin, width_origin = image_origin.shape[:2] image, height, width = image_origin, height_origin, width_origin # Resize if needed if resize_to is not None: _height, _width = min( resize_to, int(resize_to * height_origin / width_origin) ), min(resize_to, int(resize_to * width_origin / height_origin)) if _height < height_origin: if verbose: print( f"\t - Resizing image from {width_origin}x{height_origin} \ to {_width}x{_height} for pano depth prediction" ) image = cv2.resize(image_origin, (_width, _height), cv2.INTER_AREA) height, width = _height, _width # Split panorama into multiple views splitted_extrinsics, splitted_intriniscs = get_panorama_cameras() splitted_resolution = 512 splitted_images = split_panorama_image( image, splitted_extrinsics, splitted_intriniscs, splitted_resolution ) # Handle inpainting masks if provided splitted_inpaint_masks = None if last_layer_mask is not None and last_layer_depth is not None: splitted_inpaint_masks = split_panorama_image( last_layer_mask, splitted_extrinsics, splitted_intriniscs, splitted_resolution, ) # infer moge depth num_splitted_images = len(splitted_images) splitted_distance_maps = [None] * num_splitted_images splitted_masks = [None] * num_splitted_images indices_to_process_model = [] skipped_count = 0 # Determine which images need processing for i in range(num_splitted_images): if splitted_inpaint_masks is not None and splitted_inpaint_masks[i].sum() == 0: # Use depth from the previous layer for non-inpainted (masked) regions splitted_distance_maps[i] = last_layer_depth["splitted_distance_maps"][i] splitted_masks[i] = last_layer_depth["splitted_masks"][i] skipped_count += 1 else: indices_to_process_model.append(i) pred_count = 0 # Process images that require model inference in batches inference_batch_size = 1 for i in range(0, len(indices_to_process_model), inference_batch_size): batch_indices = indices_to_process_model[i : i + inference_batch_size] if not batch_indices: continue # Prepare batch current_batch_images = [splitted_images[k] for k in batch_indices] current_batch_intrinsics = [splitted_intriniscs[k] for k in batch_indices] # Convert to tensor and normalize image_tensor = torch.tensor( np.stack(current_batch_images) / 255, dtype=torch.float32, device=next(model.parameters()).device, ).permute(0, 3, 1, 2) # Calculate field of view fov_x, _ = np.rad2deg( # fov_y is not used by model.infer utils3d.numpy.intrinsics_to_fov(np.array(current_batch_intrinsics)) ) fov_x_tensor = torch.tensor( fov_x, dtype=torch.float32, device=next(model.parameters()).device ) # Run inference output = model.infer(image_tensor, fov_x=fov_x_tensor, apply_mask=False) batch_distance_maps = output["points"].norm(dim=-1).cpu().numpy() batch_masks = output["mask"].cpu().numpy() # Store results for batch_idx, original_idx in enumerate(batch_indices): splitted_distance_maps[original_idx] = batch_distance_maps[batch_idx] splitted_masks[original_idx] = batch_masks[batch_idx] pred_count += 1 if verbose: # Print processing statistics if ( pred_count + skipped_count ) == 0: # Avoid division by zero if num_splitted_images is 0 skip_ratio_info = "N/A (no images to process)" else: skip_ratio_info = f"{skipped_count / (pred_count + skipped_count):.2%}" print( f"\t 🔍 Predicted {pred_count} splitted images, \ skipped {skipped_count} splitted images. Skip ratio: {skip_ratio_info}" ) # merge moge depth merging_width, merging_height = width, height panorama_depth, panorama_mask = merge_panorama_depth( merging_width, merging_height, splitted_distance_maps, splitted_masks, splitted_extrinsics, splitted_intriniscs, ) # Post-process depth map panorama_depth = panorama_depth.astype(np.float32) # Align the depth of the bottom 0.03 area on both sides of the dano depth if remove_pano_depth_nan: # for depth inpainting, remove nan panorama_depth[~panorama_mask] = 1.0 * np.nanquantile( panorama_depth, 0.999 ) # sky depth panorama_depth = cv2.resize( panorama_depth, (width_origin, height_origin), cv2.INTER_LINEAR ) panorama_mask = ( cv2.resize( panorama_mask.astype(np.uint8), (width_origin, height_origin), cv2.INTER_NEAREST, ) > 0 ) # Smooth the depth of the South Pole (bottom area) to solve the problem of left and right inconsistency if img_name in ["background", "full_img"]: if verbose: print("\t - Smoothing south pole depth for consistency") panorama_depth = smooth_south_pole_depth( panorama_depth, smooth_height_ratio=0.05 ) rays = torch.from_numpy( spherical_uv_to_directions( utils3d.numpy.image_uv(width=width_origin, height=height_origin) ) ).to(next(model.parameters()).device) panorama_depth = ( torch.from_numpy(panorama_depth).to(next(model.parameters()).device) * scale ) return { "type": "", "rgb": torch.from_numpy(image_origin).to(next(model.parameters()).device), "distance": panorama_depth, "rays": rays, "mask": panorama_mask, "splitted_masks": splitted_masks, "splitted_distance_maps": splitted_distance_maps, }