# eomt / app.py
# Athrael's picture
# Upload 3 files
# b34d493 verified
import io
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, EomtForUniversalSegmentation
# Load the processor and model once, at import time, so the functions below
# can reuse the same objects instead of reloading them on every call.
print("Loading model...")
model_id = "tue-mps/coco_panoptic_eomt_large_640"
# Both the image processor and the model weights come from the same checkpoint id.
processor = AutoImageProcessor.from_pretrained(model_id)
model = EomtForUniversalSegmentation.from_pretrained(model_id)
# Check for CUDA availability and move the model to the GPU if one is present;
# otherwise everything stays on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Model loaded successfully on {device}!")
def run_inference(image):
    """Segment ``image`` with the globally loaded model.

    The image is preprocessed into tensors, moved to the model's device,
    run through the model with autograd disabled, and post-processed back
    to the image's own (height, width).

    Returns:
        The first (and only) entry of the processor's panoptic
        post-processing output.
    """
    batch = processor(images=image, return_tensors="pt")
    # Every input tensor has to live on the same device as the model.
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.inference_mode():
        model_outputs = model(**batch)

    # One (height, width) pair per image in the batch.
    original_size = (image.height, image.width)
    results = processor.post_process_panoptic_segmentation(
        model_outputs, target_sizes=[original_size]
    )
    return results[0]
def tensor_to_numpy(tensor):
    """Return ``tensor`` as a NumPy array when it is a ``torch.Tensor``.

    The tensor is detached from the autograd graph and moved to host memory
    before conversion, so CUDA tensors and tensors that require grad are both
    handled. Any non-tensor input is returned unchanged.
    """
    if isinstance(tensor, torch.Tensor):
        # .numpy() raises on a tensor that requires grad, hence detach().
        return tensor.detach().cpu().numpy()
    return tensor
def visualize_mask(image, segmentation_mask):
    """Render the segmentation mask on its own as a color-coded image.

    Args:
        image: Not used here; accepted so all visualizers share a signature.
        segmentation_mask: Per-pixel segment ids, as a tensor or array.

    Returns:
        A PIL image of the rendered figure.
    """
    mask_np = tensor_to_numpy(segmentation_mask)

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    try:
        # NOTE(review): a mask with no id above 0 is treated as "empty";
        # confirm this matches the id convention of the post-processor.
        if mask_np.max() > 0:
            ax.imshow(mask_np, cmap="tab20")
        else:
            # If no segments, show a blank image with text
            ax.imshow(np.zeros_like(mask_np), cmap="gray")
            ax.text(
                0.5,
                0.5,
                "No segments detected",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=16,
                color="red",
                weight="bold",
            )
        ax.axis("off")
        fig.tight_layout()

        # Render this specific figure (not pyplot's implicit "current" one).
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it even when rendering or saving raised.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
def visualize_overlay(image, segmentation_mask):
    """Draw the segmentation mask semi-transparently over the original image.

    Args:
        image: The original image, drawn as the background layer.
        segmentation_mask: Per-pixel segment ids, as a tensor or array.

    Returns:
        A PIL image of the rendered figure.
    """
    mask_np = tensor_to_numpy(segmentation_mask)

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    try:
        # Original image first, then the mask on top with transparency.
        ax.imshow(image)
        ax.imshow(mask_np, cmap="tab20", alpha=0.6)
        ax.axis("off")
        fig.tight_layout()

        # Render this specific figure (not pyplot's implicit "current" one).
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it even when rendering or saving raised.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
def visualize_contours(image, segmentation_mask):
    """Outline every segment on top of the original image.

    Args:
        image: The original image, drawn as the background layer.
        segmentation_mask: Per-pixel segment ids, as a tensor or array.

    Returns:
        A PIL image of the rendered figure.
    """
    # Keep the ids in their original dtype: casting the whole mask to uint8
    # would wrap ids above 255 onto other segments and merge their outlines.
    mask_np = tensor_to_numpy(segmentation_mask)
    unique_segments = np.unique(mask_np)
    # One evenly spaced tab20 color per id found in the mask.
    colors = plt.cm.tab20(np.linspace(0, 1, len(unique_segments)))

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    try:
        ax.imshow(image)

        for i, segment_id in enumerate(unique_segments):
            # NOTE(review): id 0 is skipped as background; confirm this
            # matches the id convention of the post-processor.
            if segment_id == 0:
                continue

            # Only the per-segment binary mask needs to be uint8 for OpenCV.
            binary_mask = (mask_np == segment_id).astype(np.uint8)
            contours, _ = cv2.findContours(
                binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )

            for contour in contours:
                if len(contour) <= 2:  # Too few points to form an outline
                    continue
                points = contour.reshape(-1, 2)
                # Repeat the first point so the plotted outline is closed.
                points = np.vstack([points, points[:1]])
                ax.plot(
                    points[:, 0],
                    points[:, 1],
                    color=colors[i],
                    linewidth=2,
                    alpha=0.8,
                )

        ax.axis("off")
        fig.tight_layout()

        # Render this specific figure (not pyplot's implicit "current" one).
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it even when rendering or saving raised.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
def visualize_instance_masks(image, segmentation_mask):
    """Show up to nine of the largest segments as separate masks in a 3x3 grid.

    Args:
        image: Not used here; accepted so all visualizers share a signature.
        segmentation_mask: Per-pixel segment ids, as a tensor or array.

    Returns:
        A PIL image of the rendered figure.
    """
    mask_np = tensor_to_numpy(segmentation_mask)
    unique_segments, counts = np.unique(mask_np, return_counts=True)

    # NOTE(review): id 0 is excluded as background; confirm this matches the
    # id convention of the post-processor.
    foreground = unique_segments != 0
    fg_segments = unique_segments[foreground]
    fg_counts = counts[foreground]

    # argsort is ascending, so the last nine positions are the nine largest
    # segments. With no foreground this simply yields empty arrays.
    top_order = np.argsort(fg_counts)[-9:]
    top_segments = fg_segments[top_order]
    top_counts = fg_counts[top_order]

    fig, axes = plt.subplots(3, 3, figsize=(15, 15))
    try:
        axes = axes.flatten()

        for ax, segment_id, count in zip(axes, top_segments, top_counts):
            binary_mask = (mask_np == segment_id).astype(float)
            ax.imshow(binary_mask, cmap="Blues")
            ax.set_title(
                f"Segment {segment_id}\nPixels: {count}", fontsize=10, weight="bold"
            )
            ax.axis("off")

        # Hide the axes of any grid cells that were not filled.
        for ax in axes[len(top_segments):]:
            ax.axis("off")

        if len(top_segments) == 0:
            # Nothing to show at all: put a notice in the first cell.
            axes[0].text(
                0.5,
                0.5,
                "No segments\ndetected",
                transform=axes[0].transAxes,
                ha="center",
                va="center",
                fontsize=12,
                color="red",
                weight="bold",
            )

        fig.tight_layout()

        # Render this specific figure (not pyplot's implicit "current" one).
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it even when rendering or saving raised.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
def visualize_edges(image, segmentation_mask):
    """Highlight segment boundaries in yellow on top of the original image.

    A pixel counts as a boundary when its segment id differs from the pixel
    to its left or the pixel above it. Comparing ids directly finds every
    boundary regardless of how close the two ids are numerically, which an
    intensity-based edge detector run on rescaled ids does not guarantee.

    Args:
        image: The original image, drawn as the background layer.
        segmentation_mask: Per-pixel segment ids, as a tensor or array.

    Returns:
        A PIL image of the rendered figure.
    """
    mask_np = tensor_to_numpy(segmentation_mask)

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    try:
        ax.imshow(image)

        # NOTE(review): a mask with no id above 0 is treated as "empty";
        # confirm this matches the id convention of the post-processor.
        if mask_np.max() > 0:
            edges = np.zeros(mask_np.shape, dtype=bool)
            # Horizontal neighbours: compare each pixel with the one to its left.
            edges[:, 1:] |= mask_np[:, 1:] != mask_np[:, :-1]
            # Vertical neighbours: compare each pixel with the one above it.
            edges[1:, :] |= mask_np[1:, :] != mask_np[:-1, :]

            # RGBA overlay: transparent everywhere except the boundary pixels.
            edge_overlay = np.zeros((*edges.shape, 4))
            edge_overlay[edges] = [1, 1, 0, 1]  # Yellow with full alpha
            ax.imshow(edge_overlay)
        else:
            # If no segments, just show the original image with a notice
            ax.text(
                0.5,
                0.5,
                "No segments detected",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=16,
                color="red",
                weight="bold",
            )

        ax.axis("off")
        fig.tight_layout()

        # Render this specific figure (not pyplot's implicit "current" one).
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it even when rendering or saving raised.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
def visualize_segment_isolation(image, segmentation_mask):
    """Show the largest non-background segment on its own.

    When the mask contains no id other than 0, a "No segments detected"
    notice is drawn instead of presenting the background as a segment.

    Args:
        image: Not used here; accepted so all visualizers share a signature.
        segmentation_mask: Per-pixel segment ids, as a tensor or array.

    Returns:
        A PIL image of the rendered figure.
    """
    mask_np = tensor_to_numpy(segmentation_mask)
    unique_segments, counts = np.unique(mask_np, return_counts=True)

    # NOTE(review): id 0 is excluded as background; confirm this matches the
    # id convention of the post-processor.
    foreground = unique_segments != 0

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    try:
        if np.any(foreground):
            fg_segments = unique_segments[foreground]
            fg_counts = counts[foreground]
            winner = np.argmax(fg_counts)
            largest_segment = fg_segments[winner]
            largest_count = fg_counts[winner]

            isolated_mask = (mask_np == largest_segment).astype(float)
            ax.imshow(isolated_mask, cmap="Reds")
            ax.set_title(
                f"Largest Segment (ID: {largest_segment}, Pixels: {largest_count})",
                fontsize=14,
                weight="bold",
                pad=20,
            )
        else:
            # Only background present: blank canvas plus a notice.
            ax.imshow(np.zeros_like(mask_np, dtype=float), cmap="gray")
            ax.text(
                0.5,
                0.5,
                "No segments detected",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=16,
                color="red",
                weight="bold",
            )

        ax.axis("off")
        fig.tight_layout()

        # Render this specific figure (not pyplot's implicit "current" one).
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it even when rendering or saving raised.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
def visualize_heatmap(image, segmentation_mask):
    """Show a heatmap of the gradient magnitude of the segment-id mask.

    The gradient is non-zero only where neighbouring pixels carry different
    ids, so bright areas mark segment boundaries. The magnitude is computed
    on the raw ids, so it also scales with the numeric difference between
    the two ids on either side of a boundary.

    Args:
        image: Not used here; accepted so all visualizers share a signature.
        segmentation_mask: Per-pixel segment ids, as a tensor or array.

    Returns:
        A PIL image of the rendered figure.
    """
    mask_np = tensor_to_numpy(segmentation_mask)

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    try:
        # NOTE(review): a mask with no id above 0 is treated as "empty";
        # confirm this matches the id convention of the post-processor.
        if mask_np.max() > 0:
            # Per-axis finite differences, combined into one magnitude.
            grad_rows, grad_cols = np.gradient(mask_np.astype(float))
            gradient_magnitude = np.hypot(grad_rows, grad_cols)

            im = ax.imshow(gradient_magnitude, cmap="hot")
            ax.set_title(
                "Boundary Density Heatmap", fontsize=14, weight="bold", pad=20
            )
            fig.colorbar(
                im, ax=ax, fraction=0.046, pad=0.04, label="Gradient Magnitude"
            )
        else:
            # If no segments, show a blank heatmap
            ax.imshow(np.zeros_like(mask_np), cmap="hot")
            ax.text(
                0.5,
                0.5,
                "No segments detected",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=16,
                color="red",
                weight="bold",
            )

        ax.axis("off")
        fig.tight_layout()

        # Render this specific figure (not pyplot's implicit "current" one).
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it even when rendering or saving raised.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
def create_visualization(image, viz_type):
    """Run segmentation on ``image`` and render it in the requested style.

    Returns None when no image is given. Any exception raised during
    inference or rendering is printed and turned into an image that carries
    the error text, so the caller always receives something displayable.
    """
    if image is None:
        return None

    try:
        segmentation_mask = run_inference(image)["segmentation"]

        # Label/renderer pairs, compared in order against the selection.
        renderers = (
            ("Mask", visualize_mask),
            ("Overlay", visualize_overlay),
            ("Contours", visualize_contours),
            ("Instance Masks", visualize_instance_masks),
            ("Edge Detection", visualize_edges),
            ("Segment Isolation", visualize_segment_isolation),
            ("Heatmap", visualize_heatmap),
        )
        # An unrecognised selection falls back to the plain mask view.
        render = next(
            (fn for label, fn in renderers if viz_type == label),
            visualize_mask,
        )
        return render(image, segmentation_mask)
    except Exception as e:
        print(f"Error in visualization: {e}")

        # Build a simple figure that only shows the error message.
        text_style = {
            "ha": "center",
            "va": "center",
            "fontsize": 12,
            "color": "red",
            "weight": "bold",
        }
        fig, ax = plt.subplots(1, 1, figsize=(12, 8))
        ax.text(
            0.5,
            0.5,
            f"Error during processing:\n{str(e)}",
            transform=ax.transAxes,
            **text_style,
        )
        ax.axis("off")

        png_bytes = io.BytesIO()
        plt.savefig(png_bytes, format="png", bbox_inches="tight", dpi=150)
        png_bytes.seek(0)
        plt.close()
        return Image.open(png_bytes)
def load_sample_image(img_path):
    """Download the image at URL ``img_path`` and return it as a PIL image.

    This is best-effort: on any failure (network error, bad HTTP status,
    undecodable data) the error is printed and None is returned.
    """
    try:
        # Without a timeout a stalled server would block the caller forever.
        response = requests.get(img_path, timeout=10)
        response.raise_for_status()
        # Decode from the fully downloaded body rather than the raw socket
        # stream, so the connection is released and the image does not depend
        # on it afterwards.
        image = Image.open(io.BytesIO(response.content))
        # Force decoding now so corrupt data is reported here, not later.
        image.load()
        return image
    except Exception as e:
        print(f"Error loading image: {e}")
        return None
# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks app for the segmentation visualizer.

    The layout has an input column (image upload, visualization selector,
    process button, help text), an output column with the rendered result,
    and a gallery of sample images that load into the input when clicked.

    Emoji in user-facing strings are written as ``\\N{...}`` escapes so the
    source stays ASCII and cannot be corrupted by a wrong file encoding.
    """
    with gr.Blocks(
        title="Panoptic Segmentation Visualizer", theme=gr.themes.Soft()
    ) as demo:
        gr.Markdown(
            "\n"
            "# \N{ARTIST PALETTE} Panoptic Segmentation Visualizer\n"
            "Upload an image and select a visualization type to see different ways of viewing the panoptic segmentation results.\n"
            "The model used is `tue-mps/coco_panoptic_eomt_large_640`.\n"
        )

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(label="Upload Image", type="pil", height=400)
                viz_type = gr.Radio(
                    choices=[
                        "Mask",
                        "Overlay",
                        "Contours",
                        "Instance Masks",
                        "Edge Detection",
                        "Segment Isolation",
                        "Heatmap",
                    ],
                    label="Visualization Type",
                    value="Mask",
                    info="Choose how to visualize the segmentation results",
                )
                process_btn = gr.Button(
                    "\N{ROCKET} Process Image", variant="primary", size="lg"
                )
                gr.Markdown(
                    "\n"
                    "### Visualization Types:\n"
                    "- **Mask**: Segmentation mask with color-coded segments\n"
                    "- **Overlay**: Transparent segmentation overlay on original image\n"
                    "- **Contours**: Segment boundaries outlined on original image\n"
                    "- **Instance Masks**: Individual instance masks in a grid (top 9 by size)\n"
                    "- **Edge Detection**: Segmentation boundaries highlighted in yellow\n"
                    "- **Segment Isolation**: Shows the largest segment isolated from the rest\n"
                    "- **Heatmap**: Boundary density visualization with color mapping\n"
                )
            with gr.Column(scale=2):
                output_image = gr.Image(
                    label="Segmentation Result", type="pil", height=600
                )

        # Run inference + rendering when the button is pressed.
        process_btn.click(
            fn=create_visualization,
            inputs=[image_input, viz_type],
            outputs=output_image,
        )

        # Sample images with thumbnails
        gr.Markdown("### \N{CAMERA WITH FLASH} Try with sample images:")
        # (url, caption) pairs; gallery order matches this list's order.
        sample_images = [
            ("http://images.cocodataset.org/val2017/000000039769.jpg", "Cats on Couch"),
            ("http://images.cocodataset.org/val2017/000000397133.jpg", "Street Scene"),
            ("http://images.cocodataset.org/val2017/000000037777.jpg", "Living Room"),
            (
                "http://images.cocodataset.org/val2017/000000174482.jpg",
                "Person with Laptop",
            ),
            ("http://images.cocodataset.org/val2017/000000000785.jpg", "Dining Table"),
        ]

        def create_thumbnail_gallery():
            """Download each sample and return (thumbnail, caption) pairs.

            Samples that fail to load are skipped.
            """
            gallery_images = []
            for img_url, img_name in sample_images:
                try:
                    img = load_sample_image(img_url)
                    if img:
                        # Resize to thumbnail while maintaining aspect ratio
                        img.thumbnail((200, 200), Image.Resampling.LANCZOS)
                        gallery_images.append((img, img_name))
                except Exception as e:
                    print(f"Failed to load {img_name}: {e}")
                    continue
            return gallery_images

        with gr.Row():
            # The thumbnails are downloaded once, while the UI is being built.
            thumbnail_gallery = gr.Gallery(
                value=create_thumbnail_gallery(),
                label="Sample Images",
                show_label=True,
                elem_id="thumbnail_gallery",
                columns=5,
                rows=1,
                object_fit="contain",
                height=200,
                allow_preview=False,
            )

        def select_from_gallery(evt: gr.SelectData):
            """Load the full-size sample that matches the clicked thumbnail."""
            selected_idx = evt.index
            if selected_idx < len(sample_images):
                img_url, _ = sample_images[selected_idx]
                return load_sample_image(img_url)
            return None

        thumbnail_gallery.select(select_from_gallery, outputs=image_input)

    return demo
if __name__ == "__main__":
    # Build the UI, then serve it on every network interface at port 7860
    # with a public share link requested.
    demo = create_interface()
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
        share=True,
    )