Spaces:
Runtime error
Runtime error
import io
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from transformers import pipeline
# Segmentation pipeline for the `mattmdjaga/segformer_b2_clothes` checkpoint.
# Built once at import time; segment_image() reuses it for every request.
pipe = pipeline(task="image-segmentation", model="mattmdjaga/segformer_b2_clothes")
# Class index -> class name, used to turn the pipeline's label strings into
# integer indices and back again for display.
# NOTE(review): presumably mirrors the checkpoint's id2label config; keep the
# two in sync if the model is changed.
label_dict = dict(
    enumerate(
        [
            "Background",
            "Hat",
            "Hair",
            "Sunglasses",
            "Upper-clothes",
            "Skirt",
            "Pants",
            "Dress",
            "Belt",
            "Left-shoe",
            "Right-shoe",
            "Face",
            "Left-leg",
            "Right-leg",
            "Left-arm",
            "Right-arm",
            "Bag",
            "Scarf",
        ]
    )
)
def _combine_masks(result, size):
    """Merge the pipeline's per-label masks into a single class-index map.

    Args:
        result: Pipeline output; a list of dicts, each with a ``"label"``
            name and a ``"mask"`` PIL image.
        size: ``(width, height)`` of the map to build (PIL ``.size`` order).

    Returns:
        A ``uint8`` array of shape ``(height, width)`` holding the
        ``label_dict`` index of each pixel's class. Pixels covered by no
        recognised mask stay 0 ("Background").
    """
    # Reverse lookup built once, instead of scanning label_dict for every mask.
    index_of_label = {name: idx for idx, name in label_dict.items()}
    width, height = size
    segmentation_map = np.zeros((height, width), dtype=np.uint8)
    for entry in result:
        class_idx = index_of_label.get(entry["label"])
        if class_idx is None:
            # Labels missing from label_dict are reported and skipped rather
            # than aborting the whole request.
            print(f"Skipping unknown label: {entry['label']!r}")
            continue
        mask = np.array(entry["mask"])  # PIL image -> NumPy array
        # Where masks overlap, the later entry wins.
        segmentation_map[mask > 0] = class_idx
    return segmentation_map


def _render_map(segmentation_map, unique_classes):
    """Render *segmentation_map* with a colorbar naming the detected classes.

    Args:
        segmentation_map: 2-D array of ``label_dict`` class indices.
        unique_classes: Sorted class indices present in the map. Only these
            get colorbar ticks and labels.

    Returns:
        The rendered figure as a PIL image (PNG), fully loaded into memory.
    """
    cmap = plt.get_cmap("tab20")
    fig = plt.figure(figsize=(8, 8))
    try:
        ax = fig.add_subplot()
        # Pin the colour range so class index i always gets the i-th tab20
        # colour, regardless of which classes appear in this image.
        # Assumes at most cmap.N (20) classes; label_dict has 18.
        im = ax.imshow(segmentation_map, cmap=cmap, vmin=-0.5, vmax=cmap.N - 0.5)
        # Ticks (and names) only for the classes actually detected.
        cbar = fig.colorbar(im, ax=ax, ticks=unique_classes)
        cbar.ax.set_yticklabels([label_dict[int(i)] for i in unique_classes])
        ax.set_title("Segmented Image with Detected Classes")
        ax.axis("off")
        # Render to memory instead of a fixed file on disk, so concurrent
        # requests cannot overwrite each other's output.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting failed.
        plt.close(fig)
    buffer.seek(0)
    rendered = Image.open(buffer)
    rendered.load()  # decode now instead of lazily from the buffer
    return rendered


def segment_image(image):
    """Segment clothing/body parts in *image* and return a colour-mapped render.

    Args:
        image: Input PIL image (the Gradio input is ``gr.Image(type="pil")``),
            or ``None`` when the user submits without an image.

    Returns:
        A PIL image of the segmentation map with a colorbar naming the
        detected classes, or ``None`` if *image* is ``None``.
    """
    if image is None:
        return None

    result = pipe(image)
    # Size the map from the first mask. If the pipeline returned no masks,
    # fall back to the input's own size (everything then stays Background).
    size = result[0]["mask"].size if result else image.size
    segmentation_map = _combine_masks(result, size)

    unique_classes = np.unique(segmentation_map)
    print("Detected Classes:")
    for class_idx in unique_classes:
        print(f"- {label_dict[int(class_idx)]}")

    return _render_map(segmentation_map, unique_classes)
# Gradio interface.
# Example images are checked relative to the working directory, and only files
# that exist are passed on.
# NOTE(review): this Space reports "Runtime error". The cause isn't visible
# here, but a missing example file is one plausible culprit. Confirm that
# 1.jpg-3.jpg are committed alongside this script.
example_images = [name for name in ("1.jpg", "2.jpg", "3.jpg") if Path(name).is_file()]

interface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),  # input: an uploaded image, delivered as PIL
    outputs=gr.Image(type="pil"),  # output: the colour-mapped segmentation
    examples=example_images or None,
    title="Clothes Segmentation with Colormap",
    description="Upload an image, and the segmentation model will produce an output with a colormap applied to the segmented classes.",
)

# Launch the app.
interface.launch()