import numpy as np
import torch
from PIL import Image
from transformers import pipeline
import gradio as gr

# ===== Device Setup =====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_index = 0 if torch.cuda.is_available() else -1

# ===== MiDaS Depth Estimation Setup =====
# Load the MiDaS DPT-Large model and its matching input transform from torch.hub.
midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
midas.to(device).eval()
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.dpt_transform

# ===== Segmentation Setup =====
# SegFormer-B0 fine-tuned on ADE20K; use fp16 on GPU to save memory.
segmenter = pipeline(
    "image-segmentation",
    model="nvidia/segformer-b0-finetuned-ade-512-512",
    device=device_index,
    torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
)

# ===== Utility Functions =====
def resize_image(img: Image.Image, max_size: int = 512) -> Image.Image:
    """Downscale so the longer side is at most max_size, preserving aspect ratio."""
    width, height = img.size
    if max(width, height) > max_size:
        ratio = max_size / max(width, height)
        new_size = (int(width * ratio), int(height * ratio))
        return img.resize(new_size, Image.LANCZOS)
    return img

# ===== Depth Prediction =====
def predict_depth(image: Image.Image) -> Image.Image:
    # Normalize the input to an RGB PIL image. (The original check was
    # inverted: .convert() only exists on PIL images, so non-PIL inputs
    # must go through Image.fromarray first.)
    if isinstance(image, Image.Image):
        img = image.convert("RGB")
    else:
        img = Image.fromarray(np.asarray(image)).convert("RGB")
    img_np = np.array(img)

    # Convert to the format expected by MiDaS; dpt_transform already adds a
    # batch dimension, so only unsqueeze if it is missing.
    input_tensor = transform(img_np).to(device)
    input_batch = input_tensor.unsqueeze(0) if input_tensor.ndim == 3 else input_tensor

    # Predict depth and upsample back to the original resolution.
    with torch.no_grad():
        prediction = midas(input_batch)
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img_np.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    # Normalize to 0-255; guard against a constant depth map to avoid
    # dividing by zero.
    depth_map = prediction.cpu().numpy()
    depth_range = depth_map.max() - depth_map.min()
    if depth_range > 0:
        depth_map = (depth_map - depth_map.min()) / depth_range
    else:
        depth_map = np.zeros_like(depth_map)
    depth_map = (depth_map * 255).astype(np.uint8)
    return Image.fromarray(depth_map)

# ===== Segmentation =====
def segment_image(img: Image.Image) -> Image.Image:
    img = img.convert("RGB")
    img_resized = resize_image(img)
    results = segmenter(img_resized)

    # Blend a random color into each predicted mask region (60% image, 40% color).
    overlay = np.array(img_resized, dtype=np.uint8)
    for res in results:
        mask = np.array(res["mask"], dtype=bool)
        color = np.random.randint(50, 255, 3, dtype=np.uint8)
        overlay[mask] = (overlay[mask] * 0.6 + color * 0.4).astype(np.uint8)
    return Image.fromarray(overlay)

# ===== Gradio App =====
def predict_fn(input_img: Image.Image) -> Image.Image:
    # 1. Compute depth map
    depth_img = predict_depth(input_img)
    # 2. Segment the depth map
    seg_img = segment_image(depth_img)
    return seg_img

iface = gr.Interface(
    fn=predict_fn,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Segmented Depth Overlay"),
    title="Depth-then-Segmentation Pipeline",
    description=(
        "Upload an image. First computes a depth map via MiDaS, "
        "then applies SegFormer segmentation on the depth map."
    ),
)

if __name__ == "__main__":
    iface.launch()
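
# ===== Optional: Headless Smoke Test =====
# A minimal sketch for exercising the pipeline without launching the Gradio UI,
# e.g. from an interactive session or a separate script that imports this
# module (launch() blocks when the file is run directly, so call this via
# import). The file names "sample.jpg" and "segmented_depth.png" are
# illustrative placeholders, not part of the original app.
def smoke_test(input_path: str = "sample.jpg",
               output_path: str = "segmented_depth.png") -> None:
    img = Image.open(input_path)        # placeholder path; adjust for your setup
    result = predict_fn(img)            # depth map -> segmentation overlay
    result.save(output_path)
    print(f"Saved overlay to {output_path}")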