"""Utilities for converting panoramic RGBD predictions into textured 3D meshes."""
import random
from collections import defaultdict
from typing import Literal, Optional

import cv2
import matplotlib
import numpy as np
import open3d as o3d
import torch
import torch.nn.functional as F
def spherical_uv_to_directions(uv: np.ndarray):
    r"""
    Map spherical UV coordinates onto unit 3D direction vectors.

    Args:
        uv (np.ndarray): UV coordinates in the range [0, 1]. Shape: (H, W, 2).

    Returns:
        np.ndarray: Direction vectors for each UV sample. Shape: (H, W, 3).
    """
    # Azimuth runs backwards across u so the panorama wraps left-to-right;
    # the polar angle spans v from the north pole (0) to the south pole (pi).
    theta = (1 - uv[..., 0]) * (2 * np.pi)
    phi = uv[..., 1] * np.pi
    sin_phi = np.sin(phi)
    x = sin_phi * np.cos(theta)
    y = sin_phi * np.sin(theta)
    z = np.cos(phi)
    return np.stack([x, y, z], axis=-1)
def depth_match(init_pred: dict, bg_pred: dict, mask: np.ndarray, quantile: float = 0.3) -> dict:
    r"""
    Rescale the background distance map onto the scale of the initial one.

    Args:
        init_pred (dict): Initial depth prediction containing a "distance" key.
        bg_pred (dict): Background depth prediction containing a "distance" key.
            Its "distance" tensor is scaled in place.
        mask (np.ndarray): Binary mask selecting the pixels used for matching.
        quantile (float): Quantile bounding the near-depth region whose medians
            are compared to estimate the scale.

    Returns:
        dict: ``bg_pred`` with its "distance" entry multiplied by the scale.
    """
    keep = mask > 0
    init_vals = init_pred["distance"][keep]
    bg_vals = bg_pred["distance"][keep]
    # Restrict both distributions to their nearest portion; the ratio of the
    # medians of those regions is a robust scale estimate.
    near_init = init_vals[init_vals < torch.quantile(init_vals, quantile)]
    near_bg = bg_vals[bg_vals < torch.quantile(bg_vals, quantile)]
    scale = near_init.median() / near_bg.median()
    bg_pred["distance"] *= scale
    return bg_pred
def _fill_small_boundary_spikes(
    mesh: o3d.geometry.TriangleMesh,
    max_bridge_dist: float,
    repeat_times: int = 3,
    max_connection_step: int = 8,
) -> o3d.geometry.TriangleMesh:
    r"""
    Fill small boundary spikes in a mesh by creating triangles between boundary vertices.

    Works in two passes per iteration: (1) bridge single-vertex "spikes" where a
    boundary vertex has exactly two boundary neighbors that are close enough to
    connect directly, and (2) greedily fan-fill longer boundary chains by
    connecting an anchor vertex to a vertex several steps ahead along the chain.

    Args:
        mesh (o3d.geometry.TriangleMesh): The input mesh to process.
        max_bridge_dist (float): Maximum distance allowed for bridging boundary vertices.
        repeat_times (int): Number of iterations to repeat the filling process.
        max_connection_step (int): Maximum number of steps to connect boundary vertices.

    Returns:
        o3d.geometry.TriangleMesh: The mesh with small boundary spikes filled.
    """
    for iteration in range(repeat_times):
        # Nothing to do on an empty/degenerate mesh.
        if not mesh.has_triangles() or not mesh.has_vertices():
            return mesh
        vertices = np.asarray(mesh.vertices)
        triangles = np.asarray(mesh.triangles)
        # 1. Identify boundary edges: an edge belonging to exactly one triangle
        # lies on the mesh boundary.
        edge_to_triangle_count = defaultdict(int)
        for tri_idx, tri in enumerate(triangles):
            for i in range(3):
                v1_idx, v2_idx = tri[i], tri[(i + 1) % 3]
                # Sort so (a, b) and (b, a) map to the same key.
                edge = tuple(sorted((v1_idx, v2_idx)))
                edge_to_triangle_count[edge] += 1
        boundary_edges = [edge for edge,
                          count in edge_to_triangle_count.items() if count == 1]
        if not boundary_edges:
            return mesh
        # 2. Create an adjacency list for boundary vertices using only boundary edges
        boundary_adj = defaultdict(list)
        for v1_idx, v2_idx in boundary_edges:
            boundary_adj[v1_idx].append(v2_idx)
            boundary_adj[v2_idx].append(v1_idx)
        # 3. Process boundary vertices: close single-vertex spikes by adding the
        # triangle (neighbor_a, neighbor_b, spike_vertex).
        new_triangles_list = []
        # Tracks edges created (or implicitly covered) by triangles added this
        # iteration, so the two passes do not double-fill the same gap.
        edge_added = defaultdict(bool)
        new_triangles_added_count = 0
        for v_curr_idx, neighbors in boundary_adj.items():
            if len(neighbors) != 2:  # Only process vertices with exactly 2 boundary neighbors
                continue
            v_a_idx, v_b_idx = neighbors[0], neighbors[1]
            # Skip if these vertices already form a triangle
            potential_edge = tuple(sorted((v_a_idx, v_b_idx)))
            if edge_to_triangle_count[potential_edge] > 0 or edge_added[potential_edge]:
                continue
            # Calculate distances
            v_curr_coord = vertices[v_curr_idx]
            v_a_coord = vertices[v_a_idx]
            v_b_coord = vertices[v_b_idx]
            dist_a_b = np.linalg.norm(v_a_coord - v_b_coord)
            # Skip if distance exceeds threshold
            if dist_a_b > max_bridge_dist:
                continue
            # Create simple triangle (v_a, v_b, v_curr)
            new_triangles_list.append([v_a_idx, v_b_idx, v_curr_idx])
            new_triangles_added_count += 1
            edge_added[potential_edge] = True
            # Mark edges as processed
            edge_added[tuple(sorted((v_curr_idx, v_a_idx)))] = True
            edge_added[tuple(sorted((v_curr_idx, v_b_idx)))] = True
        # 4. Now process multi-step connections for better smoothing
        # First build boundary chains (maximal paths of degree-2 boundary
        # vertices) for multi-step connections.
        boundary_loops = []
        visited_vertices = set()
        # Find boundary vertices with exactly 2 neighbors (part of continuous chains)
        chain_starts = [v for v in boundary_adj if len(
            boundary_adj[v]) == 2 and v not in visited_vertices]
        for start_vertex in chain_starts:
            if start_vertex in visited_vertices:
                continue
            chain = []
            curr_vertex = start_vertex
            # Follow the chain in one direction
            while curr_vertex not in visited_vertices:
                visited_vertices.add(curr_vertex)
                chain.append(curr_vertex)
                next_candidates = [
                    n for n in boundary_adj[curr_vertex] if n not in visited_vertices]
                if not next_candidates:
                    break
                curr_vertex = next_candidates[0]
            if len(chain) >= 3:
                boundary_loops.append(chain)
        # Process each boundary chain for multi-step smoothing
        for chain in boundary_loops:
            chain_length = len(chain)
            # Skip very small chains
            if chain_length < 3:
                continue
            # Compute multi-step connections
            max_step = min(max_connection_step, chain_length - 1)
            for i in range(chain_length):
                anchor_idx = chain[i]
                anchor_coord = vertices[anchor_idx]
                # Try the shortest bridge first; longer steps get a
                # proportionally larger distance budget.
                for step in range(3, max_step + 1):
                    if i + step >= chain_length:
                        break
                    far_idx = chain[i + step]
                    far_coord = vertices[far_idx]
                    # Check distance criteria
                    dist_anchor_far = np.linalg.norm(anchor_coord - far_coord)
                    if dist_anchor_far > max_bridge_dist * step:
                        continue
                    # Check if anchor and far are already connected
                    edge_anchor_far = tuple(sorted((anchor_idx, far_idx)))
                    if edge_to_triangle_count[edge_anchor_far] > 0 or edge_added[edge_anchor_far]:
                        continue
                    # Create fan triangles: a fan rooted at the anchor covering
                    # the intermediate chain vertices up to far_idx.
                    fan_valid = True
                    fan_triangles = []
                    prev_mid_idx = anchor_idx
                    for j in range(1, step):
                        mid_idx = chain[i + j]
                        if prev_mid_idx != anchor_idx:
                            tri_edge1 = tuple(sorted((anchor_idx, mid_idx)))
                            tri_edge2 = tuple(sorted((prev_mid_idx, mid_idx)))
                            # Check if edges already exist (not created by our fan)
                            if (edge_to_triangle_count[tri_edge1] > 0 and not edge_added[tri_edge1]) or \
                               (edge_to_triangle_count[tri_edge2] > 0 and not edge_added[tri_edge2]):
                                fan_valid = False
                                break
                            fan_triangles.append(
                                [anchor_idx, prev_mid_idx, mid_idx])
                        prev_mid_idx = mid_idx
                    # Add final triangle to connect to far_idx
                    if fan_valid:
                        fan_triangles.append(
                            [anchor_idx, prev_mid_idx, far_idx])
                    # Add all fan triangles if valid
                    if fan_valid and fan_triangles:
                        for triangle in fan_triangles:
                            v_a, v_b, v_c = triangle
                            edge_ab = tuple(sorted((v_a, v_b)))
                            edge_bc = tuple(sorted((v_b, v_c)))
                            edge_ac = tuple(sorted((v_a, v_c)))
                            new_triangles_list.append(triangle)
                            new_triangles_added_count += 1
                            edge_added[edge_ab] = True
                            edge_added[edge_bc] = True
                            edge_added[edge_ac] = True
                        # Once we've added a fan, move to the next anchor
                        break
        # Converged: no triangle was added this iteration, stop early.
        if new_triangles_added_count == 0:
            break
        # Update the mesh with new triangles
        if new_triangles_list:
            all_triangles_np = np.vstack(
                (triangles, np.array(new_triangles_list, dtype=np.int32)))
            final_mesh = o3d.geometry.TriangleMesh()
            final_mesh.vertices = o3d.utility.Vector3dVector(vertices)
            final_mesh.triangles = o3d.utility.Vector3iVector(all_triangles_np)
            if mesh.has_vertex_colors():
                final_mesh.vertex_colors = mesh.vertex_colors
            # Clean up the mesh
            final_mesh.remove_degenerate_triangles()
            final_mesh.remove_unreferenced_vertices()
            mesh = final_mesh
    return mesh
def pano_sheet_warping(
    rgb: torch.Tensor,  # (H, W, 3) RGB image, values [0, 1]
    distance: torch.Tensor,  # (H, W) Distance map
    rays: torch.Tensor,  # (H, W, 3) Ray directions (unit vectors ideally)
    # (H, W) Optional boolean mask
    excluded_region_mask: Optional[torch.Tensor] = None,
    max_size: int = 4096,  # Max dimension for resizing
    device: Literal["cuda", "cpu"] = "cuda",  # Computation device
    # Max distance to bridge boundary vertices
    connect_boundary_max_dist: Optional[float] = 0.5,
    connect_boundary_repeat_times: int = 2
) -> o3d.geometry.TriangleMesh:
    r"""
    Converts panoramic RGBD data (image, distance, rays) into an Open3D mesh.

    Each pixel becomes a vertex at ``distance * ray`` (camera at the origin);
    adjacent pixels are stitched into quads (two triangles each), with
    horizontal wrap-around to close the panorama seam.

    Args:
        rgb: Input RGB image tensor (H, W, 3), float values in [0, 1].
        distance: Input distance map tensor (H, W).
        rays: Input ray directions tensor (H, W, 3). Assumed to originate from (0,0,0).
        excluded_region_mask: Optional boolean mask tensor (H, W). True values indicate regions to potentially exclude.
        max_size: Maximum size (height or width) to resize inputs to.
        device: The torch device ('cuda' or 'cpu') to use for computations.
        connect_boundary_max_dist: Max vertex distance bridged when filling
            boundary spikes; None or <= 0 disables the fill pass.
        connect_boundary_repeat_times: Iterations of the boundary fill pass.

    Returns:
        An Open3D TriangleMesh object.
    """
    assert rgb.ndim == 3 and rgb.shape[2] == 3, "Image must be HxWx3"
    assert distance.ndim == 2, "Distance must be HxW"
    assert rays.ndim == 3 and rays.shape[2] == 3, "Rays must be HxWx3"
    assert (
        rgb.shape[:2] == distance.shape[:2] == rays.shape[:2]
    ), "Input shapes must match"
    mask = excluded_region_mask
    if mask is not None:
        assert (
            mask.ndim == 2 and mask.shape[:2] == rgb.shape[:2]
        ), "Mask shape must match"
        assert mask.dtype == torch.bool, "Mask must be a boolean tensor"
    rgb = rgb.to(device)
    distance = distance.to(device)
    rays = rays.to(device)
    if mask is not None:
        mask = mask.to(device)
    H, W = distance.shape
    # Downscale uniformly so the longest side fits max_size.
    if max(H, W) > max_size:
        scale = max_size / max(H, W)
    else:
        scale = 1.0
    # --- Resize Inputs ---
    rgb_nchw = rgb.permute(2, 0, 1).unsqueeze(0)
    distance_nchw = distance.unsqueeze(0).unsqueeze(0)
    rays_nchw = rays.permute(2, 0, 1).unsqueeze(0)
    rgb_resized = (
        F.interpolate(
            rgb_nchw,
            scale_factor=scale,
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
        )
        .squeeze(0)
        .permute(1, 2, 0)
    )
    distance_resized = (
        F.interpolate(
            distance_nchw,
            scale_factor=scale,
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
        )
        .squeeze(0)
        .squeeze(0)
    )
    rays_resized_nchw = F.interpolate(
        rays_nchw,
        scale_factor=scale,
        mode="bilinear",
        align_corners=False,
        recompute_scale_factor=False,
    )
    # IMPORTANT: Renormalize ray directions after interpolation
    # (bilinear blending of unit vectors shortens them).
    rays_resized = rays_resized_nchw.squeeze(0).permute(1, 2, 0)
    rays_norm = torch.linalg.norm(rays_resized, dim=-1, keepdim=True)
    rays_resized = rays_resized / (rays_norm + 1e-8)
    if mask is not None:
        mask_resized = (
            F.interpolate(
                # Needs float for interpolation
                mask.unsqueeze(0).unsqueeze(0).float(),
                scale_factor=scale,
                mode="nearest",  # nearest keeps mask boundaries sharp
                recompute_scale_factor=False,
            )
            .squeeze(0)
            .squeeze(0)
        )
        mask_resized = mask_resized > 0.5  # Convert back to boolean
    else:
        mask_resized = None
    H_new, W_new = distance_resized.shape  # Get new dimensions
    # --- Calculate 3D Vertices ---
    # Vertex position = origin + distance * ray_direction
    # Assuming origin is (0, 0, 0)
    distance_flat = distance_resized.reshape(-1, 1)  # (H*W, 1)
    rays_flat = rays_resized.reshape(-1, 3)  # (H*W, 3)
    vertices = distance_flat * rays_flat  # (H*W, 3)
    vertex_colors = rgb_resized.reshape(-1, 3)  # (H*W, 3)
    # --- Generate Mesh Faces (Triangles from Quads) ---
    # Vectorized approach for generating faces, including seam connection
    # Rows for the top of quads
    row_indices = torch.arange(0, H_new - 1, device=device)
    # Columns for the left of quads (includes last col for wrapping)
    col_indices = torch.arange(0, W_new, device=device)
    # Create 2D grids of row and column coordinates for quad corners
    # These represent the (row, col) of the top-left vertex of each quad
    # Shape: (H_new-1, W_new)
    quad_row_coords = row_indices.view(-1, 1).expand(-1, W_new)
    quad_col_coords = col_indices.view(
        1, -1).expand(H_new-1, -1)  # Shape: (H_new-1, W_new)
    # Top-left vertex indices
    tl_row, tl_col = quad_row_coords, quad_col_coords
    # Top-right vertex indices (with wrap-around to close the panorama seam)
    tr_row, tr_col = quad_row_coords, (quad_col_coords + 1) % W_new
    # Bottom-left vertex indices
    bl_row, bl_col = (quad_row_coords + 1), quad_col_coords
    # Bottom-right vertex indices (with wrap-around)
    br_row, br_col = (quad_row_coords + 1), (quad_col_coords + 1) % W_new
    # Convert 2D (row, col) coordinates to 1D vertex indices
    tl = tl_row * W_new + tl_col
    tr = tr_row * W_new + tr_col
    bl = bl_row * W_new + bl_col
    br = br_row * W_new + br_col
    # Apply mask if provided
    if mask_resized is not None:
        # Get mask values for each corner of the quads
        mask_tl_vals = mask_resized[tl_row, tl_col]
        mask_tr_vals = mask_resized[tr_row, tr_col]
        mask_bl_vals = mask_resized[bl_row, bl_col]
        mask_br_vals = mask_resized[br_row, br_col]
        # A quad is kept if none of its vertices are masked
        # Shape: (H_new-1, W_new)
        quad_keep_mask = ~(mask_tl_vals | mask_tr_vals |
                           mask_bl_vals | mask_br_vals)
        # Filter vertex indices based on the keep mask
        tl = tl[quad_keep_mask]  # Result is flattened
        tr = tr[quad_keep_mask]
        bl = bl[quad_keep_mask]
        br = br[quad_keep_mask]
    else:
        # If no mask, flatten all potential quads' vertex indices
        tl = tl.flatten()
        tr = tr.flatten()
        bl = bl.flatten()
        br = br.flatten()
    # Create triangles (two per quad)
    # Using the same winding order as before: (tl, tr, bl) and (tr, br, bl)
    tri1 = torch.stack([tl, tr, bl], dim=1)
    tri2 = torch.stack([tr, br, bl], dim=1)
    faces = torch.cat([tri1, tri2], dim=0)
    mesh_o3d = o3d.geometry.TriangleMesh()
    mesh_o3d.vertices = o3d.utility.Vector3dVector(vertices.cpu().numpy())
    mesh_o3d.triangles = o3d.utility.Vector3iVector(faces.cpu().numpy())
    mesh_o3d.vertex_colors = o3d.utility.Vector3dVector(
        vertex_colors.cpu().numpy())
    mesh_o3d.remove_unreferenced_vertices()
    mesh_o3d.remove_degenerate_triangles()
    if connect_boundary_max_dist is not None and connect_boundary_max_dist > 0:
        mesh_o3d = _fill_small_boundary_spikes(
            mesh_o3d, connect_boundary_max_dist, connect_boundary_repeat_times)
    # Recompute normals after potential modification, if mesh still valid
    if mesh_o3d.has_triangles() and mesh_o3d.has_vertices():
        mesh_o3d.compute_vertex_normals()
        # Also computes triangle normals if vertex normals are computed
        mesh_o3d.compute_triangle_normals()
    return mesh_o3d
def get_no_fg_img(no_fg1_img, no_fg2_img, full_img):
    r"""Select the best available foreground-removed image.

    Layer 2 takes priority over layer 1; the full image is the fallback
    when neither foreground-removed variant exists.

    Args:
        no_fg1_img: Image with foreground layer 1 removed, or None.
        no_fg2_img: Image with foreground layer 2 removed, or None.
        full_img: Original full image.

    Returns:
        tuple: (selected image, status string), where the status is one of
        "both_fg1_fg2", "only_fg1", "only_fg2", or "no_fg".
    """
    if no_fg2_img is not None:
        status = "both_fg1_fg2" if no_fg1_img is not None else "only_fg2"
        return no_fg2_img, status
    if no_fg1_img is not None:
        return no_fg1_img, "only_fg1"
    return full_img, "no_fg"
def get_fg_mask(fg1_mask, fg2_mask):
    r"""
    Merge the two per-layer foreground masks into a single binary mask.

    Args:
        fg1_mask: Foreground mask for layer 1, or None.
        fg2_mask: Foreground mask for layer 2, or None.

    Returns:
        Combined mask as a uint8 0/1 array, or None when both inputs are None.
    """
    available = [m for m in (fg1_mask, fg2_mask) if m is not None]
    if not available:
        return None
    if len(available) == 2:
        combined = np.logical_or(available[0], available[1])
    else:
        combined = available[0]
    # Binarize: any nonzero value becomes 1.
    return combined.astype(np.bool_).astype(np.uint8)
def get_bg_mask(sky_mask, fg_mask, kernel_scale, dilation_kernel_size: int = 3):
    r"""
    Compute the background mask: pixels that are neither sky nor (dilated)
    foreground.

    Args:
        sky_mask: Sky mask (0/1 array).
        fg_mask: Foreground mask (0/1 uint8 array), or None.
        kernel_scale: Multiplier applied to the dilation kernel size.
        dilation_kernel_size: Base size of the square dilation kernel.

    Returns:
        Background mask where 1 indicates background pixels.
    """
    if fg_mask is None:
        # No foreground available: background is simply the non-sky region.
        return 1 - sky_mask
    # Grow the foreground slightly so its fringe pixels are excluded too.
    side = dilation_kernel_size * kernel_scale
    grown_fg = cv2.dilate(fg_mask, np.ones((side, side), np.uint8), iterations=1)
    return np.logical_and(1 - grown_fg, 1 - sky_mask).astype(np.uint8)
def get_filtered_mask(disparity, beta=100, alpha_threshold=0.3, device="cuda"):
    """
    Build a visibility mask that hides depth-discontinuity edges.

    Sobel gradients of the disparity map are folded into a soft weight
    ``exp(-beta * |grad|)``, which is then thresholded: smooth regions stay
    visible (1), sharp depth edges are masked out (0).

    Args:
        disparity: Disparity map in BHWC format, shape [b, h, w, 1].
        beta: Exponential decay factor applied to the Sobel magnitude.
        alpha_threshold: Visibility cutoff for the soft weight.
        device: Device for the Sobel kernels, either 'cuda' or 'cpu'.

    Returns:
        Visibility mask in BHWC format, shape [b, h, w, 1], values {0.0, 1.0}.
    """
    disp_nchw = disparity.permute(0, 3, 1, 2)  # [b, 1, h, w]
    # Replicate-pad so gradients near the border are well defined.
    padded = F.pad(disp_nchw, (2, 2, 2, 2), mode="replicate")  # [b, 1, h+4, w+4]
    sobel_x_kernel = torch.tensor(
        [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]],
        device=device).view(1, 1, 3, 3)
    sobel_y_kernel = torch.tensor(
        [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]],
        device=device).view(1, 1, 3, 3)
    grad_x = F.conv2d(padded, sobel_x_kernel, padding=(1, 1))  # [b, 1, h+4, w+4]
    grad_y = F.conv2d(padded, sobel_y_kernel, padding=(1, 1))  # [b, 1, h+4, w+4]
    magnitude = torch.sqrt(grad_x**2 + grad_y**2)
    # Crop the padding back off to recover the original spatial size.
    magnitude = magnitude[:, :, 2:-2, 2:-2]  # [b, 1, h, w]
    alpha = torch.exp(-1.0 * beta * magnitude)
    visible = torch.greater(alpha, alpha_threshold).float()
    vis_mask = visible.permute(0, 2, 3, 1)  # back to BHWC
    assert vis_mask.shape == disparity.shape  # Ensure output shape matches input
    return vis_mask
def sheet_warping(
    predictions, excluded_region_mask=None,
    connect_boundary_max_dist=0.5,
    connect_boundary_repeat_times=2,
    max_size=4096,
) -> o3d.geometry.TriangleMesh:
    r"""
    Build a 3D sheet mesh from panoramic depth predictions.

    Args:
        predictions: Dictionary containing:
            - "rgb": RGB image tensor of shape (H, W, 3), values in [0, 255].
            - "distance": Distance map tensor of shape (H, W).
            - "rays": Ray directions tensor of shape (H, W, 3).
        excluded_region_mask: Optional boolean mask tensor of shape (H, W).
        connect_boundary_max_dist: Maximum distance to bridge boundary vertices.
        connect_boundary_repeat_times: Number of iterations of the boundary
            connection pass.
        max_size: Maximum size (height or width) to resize inputs to.

    Returns:
        An Open3D TriangleMesh object.
    """
    # Colors arrive in [0, 255]; the warper expects [0, 1].
    normalized_rgb = predictions["rgb"] / 255.0
    return pano_sheet_warping(
        normalized_rgb,
        predictions["distance"],
        predictions["rays"],
        excluded_region_mask,
        connect_boundary_max_dist=connect_boundary_max_dist,
        connect_boundary_repeat_times=connect_boundary_repeat_times,
        max_size=max_size
    )
def seed_all(seed: int = 0):
    r"""
    Seed every random number generator used by the pipeline.

    Covers Python's ``random``, NumPy, torch CPU, and all CUDA devices.

    Args:
        seed (int): Seed value applied to every generator.
    """
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
def colorize_depth_maps(
    depth: np.ndarray,
    mask: np.ndarray = None,
    normalize: bool = True,
    cmap: str = 'Spectral'
) -> np.ndarray:
    r"""
    Render a depth map as an RGB visualization via a matplotlib colormap.

    Args:
        depth (np.ndarray): Depth map to colorize, shape (H, W).
        mask (np.ndarray, optional): Validity mask applied to the depth map, shape (H, W).
        normalize (bool): Whether to normalize disparity to [0, 1] using robust quantiles.
        cmap (str): Name of the matplotlib colormap to use.

    Returns:
        np.ndarray: Colorized depth map, uint8, shape (H, W, 3).
    """
    # Invalid pixels (non-positive depth, or outside the mask) become NaN so
    # they fall through every quantile computation below.
    if mask is None:
        valid = depth > 0
    else:
        valid = (depth > 0) & mask
    depth = np.where(valid, depth, np.nan)
    # Work in disparity space: nearer surfaces get larger values.
    disp = 1 / depth
    # With an explicit mask, clamp invalid entries to a low-end quantile so
    # they do not become extreme outliers.
    if mask is not None:
        disp[~valid] = np.nanquantile(disp, 0.001)
    if normalize:
        # Robust range: ignore the extreme 0.1% low / 1% high tails.
        lo = np.nanquantile(disp, 0.001)
        hi = np.nanquantile(disp, 0.99)
        disp = (disp - lo) / (hi - lo)
    # Invert so near = warm colors; keep RGB, drop alpha; NaN renders as black.
    rgba = matplotlib.colormaps[cmap](1.0 - disp)
    colored = np.nan_to_num(rgba[..., :3], nan=0)
    # Convert to uint8 and ensure contiguous memory layout.
    return np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))