# Hugging Face Spaces page header (scraped): Space status "Runtime error"
import torch | |
import cv2 | |
import numpy as np | |
import gradio as gr | |
from PIL import Image | |
from torchvision import transforms | |
from cloth_segmentation.networks.u2net import U2NET # Import U²-Net model | |
# Load the U²-Net model once at import time; every request reuses this instance.
model_path = "u2net_model/u2net.pth"
model = U2NET(3, 1)
# weights_only=True limits unpickling to tensors and plain containers, so loading
# a checkpoint cannot execute arbitrary pickled code. It also pins the behaviour
# across torch versions whose default for this flag differs.
state_dict = torch.load(
    model_path, map_location=torch.device("cpu"), weights_only=True
)
# Checkpoints saved from nn.DataParallel prefix every key with "module.".
# Strip only that leading prefix: str.replace would also rewrite a "module."
# occurring later inside a key.
_DATAPARALLEL_PREFIX = "module."
state_dict = {
    (k[len(_DATAPARALLEL_PREFIX):] if k.startswith(_DATAPARALLEL_PREFIX) else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
# Switch to evaluation mode (changes the behaviour of layers such as dropout /
# batch-norm) since this module only runs inference.
model.eval()
# Preprocessing for U²-Net, built once instead of on every request. Resizing the
# PIL image *before* ToTensor uses PIL's antialiased resampling on every
# torchvision version. Resizing the tensor afterwards could skip antialiasing
# (and warn) on older releases, aliasing large photos down to 320x320.
_U2NET_INPUT_SIZE = (320, 320)
_U2NET_TRANSFORM = transforms.Compose([
    transforms.Resize(_U2NET_INPUT_SIZE),
    transforms.ToTensor(),
])


def segment_dress(image_np, threshold=0.5):
    """Detect the dress with U²-Net and return a binary mask.

    Args:
        image_np: Image as a NumPy array accepted by ``PIL.Image.fromarray``
            (e.g. HxWx3 uint8); it is converted to RGB before inference.
        threshold: Cut-off applied to the network output map. Values strictly
            greater than it are marked as dress. Defaults to 0.5.

    Returns:
        A uint8 array of shape (H, W), the same height and width as
        ``image_np``, containing 255 for dress pixels and 0 elsewhere.
    """
    rgb_image = Image.fromarray(image_np).convert("RGB")
    # ToTensor yields a float tensor in [0, 1]; add a batch dimension of 1.
    input_tensor = _U2NET_TRANSFORM(rgb_image).unsqueeze(0)

    with torch.no_grad():
        # NOTE(review): assumes forward() returns a sequence of (N, 1, H, W)
        # maps; [0][0] takes the first map for the single batch item.
        # NOTE(review): confirm whether this U2NET variant applies a sigmoid in
        # forward(). If it returns raw logits, `threshold` is a logit cut-off,
        # not a probability.
        # NOTE(review): no mean/std normalization is applied here. Confirm this
        # matches the preprocessing the checkpoint was trained with.
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()

    dress_mask = (output > threshold).astype(np.uint8) * 255
    height, width = image_np.shape[:2]
    # cv2.resize takes (width, height); nearest-neighbour keeps the mask binary.
    return cv2.resize(dress_mask, (width, height), interpolation=cv2.INTER_NEAREST)
def remove_background(image_np):
    """Keep the detected dress and paint every other pixel white.

    Args:
        image_np: Input image as a NumPy array, either HxW (grayscale) or
            HxWxC. This is the value of the Gradio ``type="numpy"`` image
            input, which can be ``None`` when nothing was uploaded.

    Returns:
        A ``PIL.Image`` with the same shape as the input in which pixels
        outside the dress mask are set to 255 (white).

    Raises:
        gr.Error: If no image was provided.
    """
    if image_np is None:
        # Surface a readable message in the UI instead of crashing inside
        # Image.fromarray(None).
        raise gr.Error("Please upload an image first.")

    mask = segment_dress(image_np)  # (H, W) uint8, 0 or 255
    keep = mask > 128
    if image_np.ndim == 3:
        # Add a channel axis so the (H, W) mask broadcasts over HxWxC. A 2-D
        # grayscale image must use the mask as-is; (H, W, 1) vs (H, W) would
        # broadcast incorrectly or raise.
        keep = keep[..., None]

    white_bg = np.full_like(image_np, 255)
    segmented_dress = np.where(keep, image_np, white_bg)
    return Image.fromarray(segmented_dress)
# Gradio UI: a single image in, a single image out, handled by remove_background.
_dress_input = gr.Image(type="numpy", label="Upload Dress Image")
_dress_output = gr.Image(type="pil", label="Dress with White Background")

demo = gr.Interface(
    remove_background,
    _dress_input,
    _dress_output,
    title="Dress Segmentation & Background Removal",
    description=(
        "Upload a dress image, and this AI model will detect the dress and "
        "replace the background with white."
    ),
)


def _main():
    """Start the Gradio server for ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _main()