# EgoHackZero
# Try to add a segmentation step.
# 6225fda
from typing import Optional

import numpy as np
import torch
import cv2
from PIL import Image
from transformers import pipeline
import gradio as gr
# ===== Device Setup =====
# `device` is the torch device used for MiDaS; `device_index` is the integer
# handed to the transformers pipeline below (0 with CUDA, -1 without).
if torch.cuda.is_available():
    device = torch.device("cuda")
    device_index = 0
else:
    device = torch.device("cpu")
    device_index = -1

# ===== MiDaS Depth Estimation Setup =====
# Fetch the DPT_Large model through torch.hub, move it to the selected device
# and switch it to eval mode.
midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
midas.to(device)
midas.eval()
# The hub's "transforms" entry point supplies the preprocessing callables;
# the DPT variant is the one used for prediction.
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.dpt_transform

# ===== Segmentation Setup =====
# Half precision is requested only when running on CUDA.
segmenter = pipeline(
    "image-segmentation",
    model="nvidia/segformer-b0-finetuned-ade-512-512",
    device=device_index,
    torch_dtype=torch.float32 if device.type != "cuda" else torch.float16,
)
# ===== Utility Functions =====
def resize_image(img: Image.Image, max_size: int = 512) -> Image.Image:
    """Shrink *img* so that its longer side is at most *max_size* pixels.

    The aspect ratio is preserved (up to integer truncation). Images that
    already fit are returned unchanged, as the same object.

    Args:
        img: Image to resize; only its ``size`` and ``resize`` are used.
        max_size: Upper bound, in pixels, for the longer side.

    Returns:
        The original image if it already fits, otherwise a LANCZOS-resampled
        copy.

    Raises:
        ValueError: If *max_size* is not positive.
    """
    if max_size <= 0:
        raise ValueError(f"max_size must be positive, got {max_size}")

    width, height = img.size
    longest = max(width, height)
    if longest <= max_size:
        return img

    ratio = max_size / longest
    # int() truncation can reach 0 on the short side of a very elongated
    # image (e.g. 2000x2), and a zero-sized target is not a valid resize
    # size, so keep at least one pixel per side.
    new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
    return img.resize(new_size, Image.LANCZOS)
# ===== Depth Prediction =====
def predict_depth(image: Image.Image) -> Image.Image:
    """Estimate a depth map with MiDaS and return it as an 8-bit image.

    Args:
        image: Input picture. A PIL image in any mode is accepted; any other
            array-like input is wrapped with ``Image.fromarray`` first.

    Returns:
        A single-channel ``uint8`` PIL image with the same height and width
        as the input, holding the prediction min-max scaled to 0-255. A
        constant prediction yields an all-zero image.
    """
    # Always end up with a 3-channel PIL image. The previous check was
    # inverted: it called .convert() on non-PIL input and left PIL images in
    # whatever mode they arrived in (e.g. RGBA or L).
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))
    img_np = np.array(image.convert("RGB"))

    # Convert to the format expected by MiDaS, adding a batch dimension only
    # if the transform has not already done so.
    input_tensor = transform(img_np).to(device)
    input_batch = input_tensor.unsqueeze(0) if input_tensor.ndim == 3 else input_tensor

    with torch.no_grad():
        prediction = midas(input_batch)
        # Resample the prediction back to the input's (height, width).
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img_np.shape[:2],
            mode="bicubic",
            align_corners=False,
        )

    # Index the single batch/channel entry rather than squeeze(), which would
    # also drop a height or width of 1 and break the 2-D image below.
    depth_map = prediction[0, 0].cpu().numpy()

    # Min-max normalise to 0-255, guarding against a flat map where
    # max == min would otherwise divide by zero and produce NaNs.
    depth_min = depth_map.min()
    depth_range = depth_map.max() - depth_min
    if depth_range > 0:
        scaled = (depth_map - depth_min) / depth_range
    else:
        scaled = np.zeros_like(depth_map)
    return Image.fromarray((scaled * 255).astype(np.uint8))
# ===== Segmentation =====
def segment_image(img: Image.Image) -> Image.Image:
    """Overlay the segmenter's masks on a (possibly downsized) copy of *img*.

    The image is converted to RGB, passed through ``resize_image`` and then
    through the module-level ``segmenter``. Each returned mask is tinted with
    a freshly drawn random colour, blended 60% image / 40% colour, so the
    output colours differ between calls.

    Args:
        img: Image to segment.

    Returns:
        The resized RGB image with every mask region tinted.
    """
    rgb = img.convert('RGB')
    small = resize_image(rgb)
    segments = segmenter(small)

    canvas = np.array(small, dtype=np.uint8)
    for segment in segments:
        region = np.array(segment["mask"], dtype=bool)
        tint = np.random.randint(50, 255, 3, dtype=np.uint8)
        blended = canvas[region] * 0.6 + tint * 0.4
        canvas[region] = blended.astype(np.uint8)
    return Image.fromarray(canvas)
# ===== Gradio App =====
def predict_fn(input_img: Optional[Image.Image]) -> Optional[Image.Image]:
    """Run depth estimation, then segmentation on the resulting depth map.

    Args:
        input_img: Image supplied by the UI, or ``None`` when the user
            submits without uploading anything.

    Returns:
        The segmentation overlay computed on the depth map, or ``None`` when
        no input image was provided.
    """
    # An empty image input arrives as None; return early instead of failing
    # inside predict_depth.
    if input_img is None:
        return None

    # 1. Compute depth map
    depth_img = predict_depth(input_img)
    # 2. Segment the depth map
    return segment_image(depth_img)
# Gradio UI: one PIL image in, one PIL image out, handled by predict_fn.
iface = gr.Interface(
    title="Depth-then-Segmentation Pipeline",
    description="Upload an image. First computes a depth map via MiDaS, then applies SegFormer segmentation on the depth map.",
    fn=predict_fn,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Segmented Depth Overlay"),
)

if __name__ == "__main__":
    # Start the web app only when this file is run as a script.
    iface.launch()