"""Gradio demo: segment a dress with U²-Net and put it on a white background."""

from typing import Optional

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from cloth_segmentation.networks.u2net import U2NET  # U²-Net model definition

# --- Model loading (runs once, at import time) -------------------------------
model_path = "u2net_model/u2net.pth"
model = U2NET(3, 1)

# NOTE(review): torch.load unpickles the checkpoint; only load files you trust.
# If the installed torch supports it, consider passing weights_only=True.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))

# Checkpoints saved from a DataParallel wrapper prefix every key with
# 'module.'. Strip it only when it is a *prefix*, so that a key that merely
# contains the substring 'module.' elsewhere is left untouched.
_DATA_PARALLEL_PREFIX = 'module.'
state_dict = {
    (k[len(_DATA_PARALLEL_PREFIX):] if k.startswith(_DATA_PARALLEL_PREFIX) else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()

# Built once instead of on every request: PIL image -> float tensor -> 320x320.
# NOTE(review): no mean/std normalisation is applied here — confirm that this
# matches how the checkpoint was trained.
_PREPROCESS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((320, 320)),
])


def segment_dress(image_np: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Detect the dress with U²-Net and return a binary mask.

    Args:
        image_np: Input image as an array accepted by ``Image.fromarray``.
            It is converted to RGB before being fed to the network.
        threshold: Cut-off applied to the network output; values strictly
            greater than it are marked as dress. Defaults to 0.5.

    Returns:
        A ``uint8`` array with the same height and width as ``image_np``,
        holding 255 where the dress was detected and 0 elsewhere.
    """
    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = _PREPROCESS(image).unsqueeze(0)  # add the batch dimension

    # Inference only: no gradients needed.
    # Takes the first element of the model's return value, then the first
    # item of the batch. NOTE(review): thresholding at 0.5 presumes that map
    # is already in [0, 1] (e.g. sigmoid applied inside the model) — confirm
    # against the U2NET implementation.
    with torch.no_grad():
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()

    dress_mask = (output > threshold).astype(np.uint8) * 255

    # cv2.resize expects (width, height). Nearest-neighbour keeps the mask
    # strictly binary (no interpolated in-between values).
    height, width = image_np.shape[:2]
    dress_mask = cv2.resize(
        dress_mask, (width, height), interpolation=cv2.INTER_NEAREST
    )
    return dress_mask


def remove_background(image_np: Optional[np.ndarray]) -> Optional[Image.Image]:
    """Replace everything that is not the dress with white.

    Args:
        image_np: Input image array, or ``None`` when no image was supplied
            (e.g. the form was submitted empty).

    Returns:
        A PIL image in which non-dress pixels are set to 255, or ``None``
        when ``image_np`` is ``None``.
    """
    if image_np is None:
        # Nothing uploaded: return nothing rather than crash inside PIL.
        return None

    mask = segment_dress(image_np)

    # Work on a copy so the caller's array is never modified. Indexing with
    # the 2-D boolean mask selects whole pixels, so this works for both
    # H x W (grayscale) and H x W x C inputs.
    segmented_dress = image_np.copy()
    segmented_dress[mask <= 128] = 255  # white background
    return Image.fromarray(segmented_dress)


# Gradio interface
demo = gr.Interface(
    fn=remove_background,
    inputs=gr.Image(type="numpy", label="Upload Dress Image"),
    outputs=gr.Image(type="pil", label="Dress with White Background"),
    title="Dress Segmentation & Background Removal",
    description="Upload a dress image, and this AI model will detect the dress and replace the background with white.",
)

if __name__ == "__main__":
    demo.launch()