# Commit metadata: gaur3009 — "Create app.py" (3e98665, verified)
import torch
import cv2
import numpy as np
import gradio as gr
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET # Import U²-Net model
# Load the U²-Net model once at import time, with weights mapped to the CPU.
model_path = "u2net_model/u2net.pth"
model = U2NET(3, 1)
# NOTE(review): torch.load unpickles the checkpoint file; only point
# model_path at a trusted, locally controlled file.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# Checkpoints saved from a wrapped model (e.g. nn.DataParallel) carry a
# "module." prefix on every key. Strip it only at the START of a key so that
# names merely containing the substring elsewhere are left untouched.
_MODULE_PREFIX = "module."
state_dict = {
    (k[len(_MODULE_PREFIX):] if k.startswith(_MODULE_PREFIX) else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()  # inference mode (affects layers such as dropout / batch-norm)
# U²-Net preprocessing, built once at import time instead of on every call.
# NOTE(review): no Normalize step is applied; confirm this matches how the
# loaded weights were trained.
_DRESS_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((320, 320)),
])


def segment_dress(image_np, threshold=0.5):
    """Detect the dress with U²-Net and return a binary mask.

    Args:
        image_np: Input image as a NumPy array accepted by
            ``PIL.Image.fromarray``; it is converted to RGB before inference.
        threshold: Values of the model's output map strictly greater than this
            are treated as dress pixels. Defaults to ``0.5`` (the previous
            hard-coded value).

    Returns:
        ``uint8`` array with the input's height and width, holding 255 for
        dress pixels and 0 elsewhere.
    """
    image = Image.fromarray(image_np).convert("RGB")
    # Add a batch dimension -> (1, 3, 320, 320).
    input_tensor = _DRESS_TRANSFORM(image).unsqueeze(0)

    with torch.no_grad():
        # Presumably the model returns a sequence of side outputs; take the
        # first one, first batch item, and squeeze it to a 2-D map.
        # NOTE(review): thresholding assumes this map holds probabilities in
        # [0, 1]; if U2NET returns raw logits, apply torch.sigmoid first.
        output = model(input_tensor)[0][0].squeeze().cpu().numpy()

    dress_mask = (output > threshold).astype(np.uint8) * 255
    # Nearest-neighbour interpolation keeps the upscaled mask strictly 0/255.
    height, width = image_np.shape[:2]
    return cv2.resize(dress_mask, (width, height), interpolation=cv2.INTER_NEAREST)
def remove_background(image_np):
    """Keep the detected dress and paint every other pixel white.

    Args:
        image_np: Uploaded image as a NumPy array, either ``(H, W, C)`` colour
            or ``(H, W)`` grayscale, or ``None`` when the form is submitted
            without an upload.

    Returns:
        ``PIL.Image.Image`` with the input's shape, in which pixels outside
        the dress mask are set to 255 in every channel.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable
            message instead of failing inside PIL.
    """
    if image_np is None:
        raise gr.Error("Please upload a dress image first.")

    mask = segment_dress(image_np)  # (H, W) uint8, 255 = dress, 0 = background

    keep = mask > 128
    if image_np.ndim == 3:
        # Add a channel axis so the mask broadcasts across colour channels.
        # A 2-D grayscale image already matches the mask's shape; expanding
        # it there would mis-broadcast.
        keep = keep[..., None]

    white_bg = np.full_like(image_np, 255)
    segmented_dress = np.where(keep, image_np, white_bg)
    return Image.fromarray(segmented_dress)
# Gradio Interface
def _build_demo():
    """Assemble the Gradio UI that wraps ``remove_background``."""
    upload_component = gr.Image(type="numpy", label="Upload Dress Image")
    result_component = gr.Image(type="pil", label="Dress with White Background")
    return gr.Interface(
        fn=remove_background,
        inputs=upload_component,
        outputs=result_component,
        title="Dress Segmentation & Background Removal",
        description="Upload a dress image, and this AI model will detect the dress and replace the background with white.",
    )


demo = _build_demo()

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    demo.launch()