File size: 2,015 Bytes
3e98665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import cv2
import numpy as np
import gradio as gr
from PIL import Image
from torchvision import transforms
from cloth_segmentation.networks.u2net import U2NET  # Import U²-Net model

# Load U²-Net model once at import time; every request below reuses this instance.
model_path = "u2net_model/u2net.pth"
model = U2NET(3, 1)  # presumably (in_ch, out_ch) = RGB in, one mask channel out — confirm against U2NET signature
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
# Checkpoints saved from a DataParallel-wrapped model typically prefix every key
# with "module.". Strip only that *leading* prefix: str.replace would also rewrite
# a "module." substring appearing anywhere else inside a parameter name.
state_dict = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()  # inference mode (e.g. disables dropout / uses running BN stats)

# Preprocessing is identical for every request, so build the pipeline once at import.
# The ImageNet mean/std normalisation mirrors the reference U²-Net inference code,
# which feeds the network normalised RGB rather than raw [0, 1] pixels.
_U2NET_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((320, 320)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def segment_dress(image_np, threshold=0.5):
    """Detects dress using U²-Net and creates a binary mask.

    Args:
        image_np: Input image as a NumPy array; it is converted to RGB via PIL
            before inference.
        threshold: Foreground score (in [0, 1]) above which a pixel is treated
            as dress. Defaults to 0.5, the previously hard-coded value.

    Returns:
        uint8 array with the same height and width as ``image_np``: 255 where
        the dress was detected, 0 elsewhere.
    """
    image = Image.fromarray(image_np).convert("RGB")
    input_tensor = _U2NET_TRANSFORM(image).unsqueeze(0)

    # U²-Net inference; element 0 of the returned tuple is presumably the fused
    # output map (d0). [0] then drops the batch dim and squeeze() the channel dim.
    with torch.no_grad():
        pred = model(input_tensor)[0][0].squeeze()

    # NOTE(review): the reference U²-Net applies sigmoid inside forward(), but some
    # ports (e.g. cloth-segmentation's) return raw logits. Sigmoid outputs never
    # leave [0, 1], so values outside that range can only be logits — squash them
    # before thresholding. Confirm which variant cloth_segmentation ships.
    if pred.min() < 0 or pred.max() > 1:
        pred = torch.sigmoid(pred)
    prob = pred.float().cpu().numpy()

    # Upsample the soft map first, then threshold: thresholding at 320x320 and
    # upscaling with nearest-neighbour produces blocky, stair-stepped mask edges.
    height, width = image_np.shape[:2]
    prob = cv2.resize(prob, (width, height), interpolation=cv2.INTER_LINEAR)
    return (prob > threshold).astype(np.uint8) * 255

def remove_background(image_np):
    """Removes background and replaces it with white while keeping the dress.

    Args:
        image_np: Input image as a NumPy array (as delivered by the
            ``gr.Image(type="numpy")`` input), or ``None`` when no image was given.

    Returns:
        A ``PIL.Image`` in RGB mode where every pixel outside the dress mask is
        white, or ``None`` if ``image_np`` is ``None``.
    """
    # An empty Gradio image input arrives as None; Image.fromarray(None) would raise.
    if image_np is None:
        return None

    # Generate dress mask (uint8, 255 = dress, same H x W as the input).
    mask = segment_dress(image_np)

    # Normalise to 3-channel RGB so grayscale (2-D) or 4-channel input composites
    # correctly; a 2-D array would broadcast wrongly against mask[..., None].
    rgb = np.asarray(Image.fromarray(image_np).convert("RGB"))
    white_bg = np.full_like(rgb, 255)
    segmented_dress = np.where(mask[..., None] > 128, rgb, white_bg)

    return Image.fromarray(segmented_dress)

# Gradio Interface: a NumPy image goes in, remove_background runs on it,
# and the resulting PIL image is shown.
upload_input = gr.Image(type="numpy", label="Upload Dress Image")
result_output = gr.Image(type="pil", label="Dress with White Background")

demo = gr.Interface(
    fn=remove_background,
    inputs=upload_input,
    outputs=result_output,
    title="Dress Segmentation & Background Removal",
    description=(
        "Upload a dress image, and this AI model will detect the dress "
        "and replace the background with white."
    ),
)

# Start the web UI only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()