|
import io |
|
|
|
import cv2 |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import requests |
|
import torch |
|
from PIL import Image |
|
from transformers import AutoImageProcessor, EomtForUniversalSegmentation |
|
|
|
|
|
print("Loading model...")

# Panoptic-segmentation checkpoint: EoMT (large, 640px) trained on COCO panoptic.
model_id = "tue-mps/coco_panoptic_eomt_large_640"

# Processor handles resizing/normalization on the way in and mask
# post-processing on the way out; model produces the raw predictions.
processor = AutoImageProcessor.from_pretrained(model_id)

model = EomtForUniversalSegmentation.from_pretrained(model_id)


# Prefer GPU when available; inference inputs are moved to this same device
# in run_inference(), so model and tensors always agree.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)

print(f"Model loaded successfully on {device}!")
|
|
|
|
|
def run_inference(image):
    """Run panoptic segmentation on a PIL image.

    Preprocesses the image, runs the model without autograd, and
    post-processes the outputs back to the image's original resolution.
    Returns the first (only) prediction dict, which contains the
    "segmentation" mask among other keys.
    """
    model_inputs = processor(images=image, return_tensors="pt")

    # Move every input tensor onto the same device as the model.
    model_inputs = {name: value.to(device) for name, value in model_inputs.items()}

    with torch.inference_mode():
        outputs = model(**model_inputs)

    # Resize predictions back to the original (height, width).
    original_size = [(image.height, image.width)]
    predictions = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=original_size
    )

    return predictions[0]
|
|
|
|
|
def tensor_to_numpy(tensor):
    """Convert a torch tensor to a numpy array, handling CUDA tensors.

    Uses .detach() so tensors attached to an autograd graph convert cleanly
    (plain .numpy() raises on tensors that require grad), and .cpu() to move
    CUDA tensors to host memory first. Non-tensor inputs are returned
    unchanged.
    """
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    return tensor
|
|
|
|
|
def visualize_mask(image, segmentation_mask):
    """Render the segmentation mask alone, color-coded with the tab20 palette.

    Falls back to a gray placeholder with a warning label when the mask has
    no positive segment ids. Returns the rendered figure as a PIL Image.
    """
    figure, axis = plt.subplots(1, 1, figsize=(12, 8))

    mask = tensor_to_numpy(segmentation_mask)

    if mask.max() > 0:
        axis.imshow(mask, cmap="tab20")
    else:
        # Nothing detected: show an empty canvas with a notice.
        axis.imshow(np.zeros_like(mask), cmap="gray")
        axis.text(
            0.5,
            0.5,
            "No segments detected",
            transform=axis.transAxes,
            ha="center",
            va="center",
            fontsize=16,
            color="red",
            weight="bold",
        )

    axis.axis("off")
    plt.tight_layout()

    # Rasterize the figure into an in-memory PNG and hand it back as PIL.
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format="png", bbox_inches="tight", dpi=150)
    png_buffer.seek(0)
    plt.close()

    return Image.open(png_buffer)
|
|
|
|
|
def visualize_overlay(image, segmentation_mask):
    """Blend the color-coded segmentation mask over the original image.

    The mask is drawn semi-transparently (alpha=0.6) on top of the photo.
    Returns the rendered figure as a PIL Image.
    """
    figure, axis = plt.subplots(1, 1, figsize=(12, 8))

    axis.imshow(image)
    axis.imshow(tensor_to_numpy(segmentation_mask), cmap="tab20", alpha=0.6)

    axis.axis("off")
    plt.tight_layout()

    # Rasterize the figure into an in-memory PNG and hand it back as PIL.
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format="png", bbox_inches="tight", dpi=150)
    png_buffer.seek(0)
    plt.close()

    return Image.open(png_buffer)
|
|
|
|
|
def visualize_contours(image, segmentation_mask):
    """Trace each segment's outline on top of the original image.

    Every non-background segment is converted to a binary mask, its external
    contours are found with OpenCV, and the outlines are plotted in a
    per-segment color from the tab20 palette. Returns a PIL Image.
    """
    figure, axis = plt.subplots(1, 1, figsize=(12, 8))

    axis.imshow(image)

    # NOTE(review): the uint8 cast assumes segment ids fit in 0-255 — TODO confirm.
    mask = tensor_to_numpy(segmentation_mask).astype(np.uint8)

    segment_ids = np.unique(mask)

    # One palette entry per segment id (including background at index 0).
    palette = plt.cm.tab20(np.linspace(0, 1, len(segment_ids)))

    for idx, seg_id in enumerate(segment_ids):
        if seg_id == 0:
            # Skip background.
            continue

        binary = (mask == seg_id).astype(np.uint8)

        outlines, _ = cv2.findContours(
            binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        for outline in outlines:
            # Degenerate contours (< 3 points) are not drawable shapes.
            if len(outline) <= 2:
                continue
            points = outline.reshape(-1, 2)
            axis.plot(
                points[:, 0],
                points[:, 1],
                color=palette[idx % len(palette)],
                linewidth=2,
                alpha=0.8,
            )

    axis.axis("off")
    plt.tight_layout()

    # Rasterize the figure into an in-memory PNG and hand it back as PIL.
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format="png", bbox_inches="tight", dpi=150)
    png_buffer.seek(0)
    plt.close()

    return Image.open(png_buffer)
|
|
|
|
|
def visualize_instance_masks(image, segmentation_mask):
    """Show the nine largest non-background segments as binary masks in a 3x3 grid.

    Segments are ranked by pixel count; unused grid cells are blanked, and a
    warning is shown when no segments exist. Returns a PIL Image.
    """
    mask_np = tensor_to_numpy(segmentation_mask)
    unique_segments, counts = np.unique(mask_np, return_counts=True)

    non_bg_indices = unique_segments != 0
    if np.any(non_bg_indices):
        # Sort by pixel count ONCE and reuse the order for both arrays
        # (the original recomputed the argsort for ids and counts).
        order = np.argsort(counts[non_bg_indices])[-9:]
        top_segments = unique_segments[non_bg_indices][order]
        top_counts = counts[non_bg_indices][order]
    else:
        top_segments = []
        top_counts = []

    fig, axes = plt.subplots(3, 3, figsize=(15, 15))
    axes = axes.flatten()

    for i, (segment_id, count) in enumerate(zip(top_segments, top_counts)):
        binary_mask = (mask_np == segment_id).astype(float)
        axes[i].imshow(binary_mask, cmap="Blues")
        axes[i].set_title(
            f"Segment {segment_id}\nPixels: {count}", fontsize=10, weight="bold"
        )
        axes[i].axis("off")

    # Blank out unused grid cells; show a notice when nothing was detected.
    for i in range(len(top_segments), 9):
        axes[i].axis("off")
        if i == 0 and len(top_segments) == 0:
            axes[i].text(
                0.5,
                0.5,
                "No segments\ndetected",
                transform=axes[i].transAxes,
                ha="center",
                va="center",
                fontsize=12,
                color="red",
                weight="bold",
            )

    plt.tight_layout()

    # Rasterize the figure into an in-memory PNG and hand it back as PIL.
    buf = io.BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    buf.seek(0)
    plt.close()

    return Image.open(buf)
|
|
|
|
|
def visualize_edges(image, segmentation_mask):
    """Highlight segmentation boundaries in yellow over the original image.

    Runs Canny edge detection on the (normalized) mask and overlays the
    detected edges as an RGBA layer. Returns a PIL Image.
    """
    mask = tensor_to_numpy(segmentation_mask)

    figure, axis = plt.subplots(1, 1, figsize=(12, 8))

    axis.imshow(image)

    if mask.max() > 0:
        # Scale ids into 0-255 so the fixed Canny thresholds behave consistently.
        normalized = (mask * 255 / mask.max()).astype(np.uint8)
        edges = cv2.Canny(normalized, 50, 150)

        # RGBA overlay: opaque yellow on edge pixels, fully transparent elsewhere.
        overlay = np.zeros((*edges.shape, 4))
        overlay[edges > 0] = [1, 1, 0, 1]

        axis.imshow(overlay)
    else:
        axis.text(
            0.5,
            0.5,
            "No segments detected",
            transform=axis.transAxes,
            ha="center",
            va="center",
            fontsize=16,
            color="red",
            weight="bold",
        )

    axis.axis("off")
    plt.tight_layout()

    # Rasterize the figure into an in-memory PNG and hand it back as PIL.
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format="png", bbox_inches="tight", dpi=150)
    png_buffer.seek(0)
    plt.close()

    return Image.open(png_buffer)
|
|
|
|
|
def visualize_segment_isolation(image, segmentation_mask):
    """Render only the largest segment (by pixel count) as an isolated mask.

    Prefers the largest non-background segment; if only background exists,
    falls back to the largest segment overall. Returns a PIL Image.
    """
    mask = tensor_to_numpy(segmentation_mask)
    segment_ids, pixel_counts = np.unique(mask, return_counts=True)

    foreground = segment_ids != 0
    if np.any(foreground):
        winner = np.argmax(pixel_counts[foreground])
        largest_segment = segment_ids[foreground][winner]
        largest_count = pixel_counts[foreground][winner]
    else:
        winner = np.argmax(pixel_counts)
        largest_segment = segment_ids[winner]
        largest_count = pixel_counts[winner]

    figure, axis = plt.subplots(1, 1, figsize=(12, 8))

    isolated_mask = (mask == largest_segment).astype(float)

    if isolated_mask.max() > 0:
        axis.imshow(isolated_mask, cmap="Reds")
        axis.set_title(
            f"Largest Segment (ID: {largest_segment}, Pixels: {largest_count})",
            fontsize=14,
            weight="bold",
            pad=20,
        )
    else:
        # Nothing matched: show an empty canvas with a notice.
        axis.imshow(np.zeros_like(isolated_mask), cmap="gray")
        axis.text(
            0.5,
            0.5,
            "No segments detected",
            transform=axis.transAxes,
            ha="center",
            va="center",
            fontsize=16,
            color="red",
            weight="bold",
        )

    axis.axis("off")
    plt.tight_layout()

    # Rasterize the figure into an in-memory PNG and hand it back as PIL.
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format="png", bbox_inches="tight", dpi=150)
    png_buffer.seek(0)
    plt.close()

    return Image.open(png_buffer)
|
|
|
|
|
def visualize_heatmap(image, segmentation_mask):
    """Visualize boundary density as the gradient magnitude of the mask.

    High gradient magnitude marks transitions between segment ids, so the
    heatmap lights up along segment boundaries. Returns a PIL Image.
    """
    mask = tensor_to_numpy(segmentation_mask)

    figure, axis = plt.subplots(1, 1, figsize=(12, 8))

    if mask.max() > 0:
        # np.gradient returns per-axis derivatives; combine into a magnitude.
        grad_rows, grad_cols = np.gradient(mask.astype(float))
        magnitude = np.sqrt(grad_rows ** 2 + grad_cols ** 2)

        heat = axis.imshow(magnitude, cmap="hot")
        axis.set_title("Boundary Density Heatmap", fontsize=14, weight="bold", pad=20)
        axis.axis("off")
        plt.colorbar(heat, ax=axis, fraction=0.046, pad=0.04, label="Gradient Magnitude")
    else:
        # Nothing detected: show an empty canvas with a notice.
        axis.imshow(np.zeros_like(mask), cmap="hot")
        axis.text(
            0.5,
            0.5,
            "No segments detected",
            transform=axis.transAxes,
            ha="center",
            va="center",
            fontsize=16,
            color="red",
            weight="bold",
        )
        axis.axis("off")

    plt.tight_layout()

    # Rasterize the figure into an in-memory PNG and hand it back as PIL.
    png_buffer = io.BytesIO()
    plt.savefig(png_buffer, format="png", bbox_inches="tight", dpi=150)
    png_buffer.seek(0)
    plt.close()

    return Image.open(png_buffer)
|
|
|
|
|
def create_visualization(image, viz_type):
    """Run segmentation on the image and render the requested visualization.

    Dispatches to one of the visualize_* renderers by name, defaulting to
    the plain mask view for unknown types. On any failure, returns an error
    panel image instead of raising (this is the Gradio click handler, so an
    exception would surface as an opaque UI error). Returns None when no
    image was supplied.
    """
    if image is None:
        return None

    # Name -> renderer dispatch table; unknown names fall back to the mask view.
    renderers = {
        "Mask": visualize_mask,
        "Overlay": visualize_overlay,
        "Contours": visualize_contours,
        "Instance Masks": visualize_instance_masks,
        "Edge Detection": visualize_edges,
        "Segment Isolation": visualize_segment_isolation,
        "Heatmap": visualize_heatmap,
    }

    try:
        prediction = run_inference(image)
        segmentation_mask = prediction["segmentation"]
        renderer = renderers.get(viz_type, visualize_mask)
        return renderer(image, segmentation_mask)
    except Exception as e:
        print(f"Error in visualization: {e}")

        # Render the error text into an image so the UI shows something useful.
        fig, ax = plt.subplots(1, 1, figsize=(12, 8))
        ax.text(
            0.5,
            0.5,
            f"Error during processing:\n{str(e)}",
            transform=ax.transAxes,
            ha="center",
            va="center",
            fontsize=12,
            color="red",
            weight="bold",
        )
        ax.axis("off")

        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight", dpi=150)
        buf.seek(0)
        plt.close()

        return Image.open(buf)
|
|
|
|
|
def load_sample_image(img_path):
    """Download a sample image from a URL.

    Returns a PIL Image on success, or None on any failure (logged).

    Fixes vs. the original: a timeout so a dead host cannot hang the UI
    indefinitely, and decoding from response.content instead of the raw
    socket — requests does not undo Content-Encoding (e.g. gzip) on
    response.raw, which can hand PIL undecodable bytes.
    """
    try:
        response = requests.get(img_path, timeout=10)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content))
    except Exception as e:
        print(f"Error loading image: {e}")
        return None
|
|
|
|
|
|
|
def create_interface():
    """Build and return the Gradio Blocks UI for the segmentation demo."""
    with gr.Blocks(
        title="Panoptic Segmentation Visualizer", theme=gr.themes.Soft()
    ) as demo:
        gr.Markdown("""
        # ๐จ Panoptic Segmentation Visualizer

        Upload an image and select a visualization type to see different ways of viewing the panoptic segmentation results.
        The model used is `tue-mps/coco_panoptic_eomt_large_640`.
        """)

        with gr.Row():
            # Left column: input image, visualization selector, run button.
            with gr.Column(scale=1):
                image_input = gr.Image(label="Upload Image", type="pil", height=400)

                viz_type = gr.Radio(
                    choices=[
                        "Mask",
                        "Overlay",
                        "Contours",
                        "Instance Masks",
                        "Edge Detection",
                        "Segment Isolation",
                        "Heatmap",
                    ],
                    label="Visualization Type",
                    value="Mask",
                    info="Choose how to visualize the segmentation results",
                )

                process_btn = gr.Button(
                    "๐ Process Image", variant="primary", size="lg"
                )

                gr.Markdown("""
                ### Visualization Types:
                - **Mask**: Segmentation mask with color-coded segments
                - **Overlay**: Transparent segmentation overlay on original image
                - **Contours**: Segment boundaries outlined on original image
                - **Instance Masks**: Individual instance masks in a grid (top 9 by size)
                - **Edge Detection**: Segmentation boundaries highlighted in yellow
                - **Segment Isolation**: Shows the largest segment isolated from the rest
                - **Heatmap**: Boundary density visualization with color mapping
                """)

            # Right column: rendered segmentation output.
            with gr.Column(scale=2):
                output_image = gr.Image(
                    label="Segmentation Result", type="pil", height=600
                )

        # Wire the button to the inference + rendering pipeline.
        process_btn.click(
            fn=create_visualization,
            inputs=[image_input, viz_type],
            outputs=output_image,
        )

        gr.Markdown("### ๐ธ Try with sample images:")

        # (url, display name) pairs from the COCO val2017 set.
        sample_images = [
            ("http://images.cocodataset.org/val2017/000000039769.jpg", "Cats on Couch"),
            ("http://images.cocodataset.org/val2017/000000397133.jpg", "Street Scene"),
            ("http://images.cocodataset.org/val2017/000000037777.jpg", "Living Room"),
            (
                "http://images.cocodataset.org/val2017/000000174482.jpg",
                "Person with Laptop",
            ),
            ("http://images.cocodataset.org/val2017/000000000785.jpg", "Dining Table"),
        ]

        def create_thumbnail_gallery():
            """Create a gallery of clickable thumbnails.

            Downloads each sample at app start-up; any failed download is
            logged and skipped so the gallery still renders.
            """
            gallery_images = []
            for img_url, img_name in sample_images:
                try:
                    img = load_sample_image(img_url)
                    if img:
                        # thumbnail() resizes in place, preserving aspect ratio.
                        img.thumbnail((200, 200), Image.Resampling.LANCZOS)
                        gallery_images.append((img, img_name))
                except Exception as e:
                    print(f"Failed to load {img_name}: {e}")
                    continue
            return gallery_images

        with gr.Row():
            thumbnail_gallery = gr.Gallery(
                value=create_thumbnail_gallery(),
                label="Sample Images",
                show_label=True,
                elem_id="thumbnail_gallery",
                columns=5,
                rows=1,
                object_fit="contain",
                height=200,
                allow_preview=False,
            )

        def select_from_gallery(evt: gr.SelectData):
            """Handle gallery selection.

            Maps the clicked thumbnail index back to its source URL and
            re-downloads the full-size image for the input slot.
            """
            selected_idx = evt.index
            if selected_idx < len(sample_images):
                img_url, _ = sample_images[selected_idx]
                return load_sample_image(img_url)
            return None

        thumbnail_gallery.select(select_from_gallery, outputs=image_input)

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Build the Gradio UI and serve it on all interfaces with a public
    # share link (port 7860 is Gradio's conventional default).
    demo = create_interface()
    demo.launch(share=True, server_name="0.0.0.0", server_port=7860)
|
|