import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from skimage.restoration import denoise_tv_chambolle  # noqa: F401  (currently unused)
from torchvision import transforms  # noqa: F401  (currently unused)
from transformers import SamModel, SamProcessor

# Load the SAM model once at import time so every request reuses it.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")


def segment_dress(image):
    """Segment the dress in *image* with SAM, prompted at the image centre.

    Args:
        image: RGB ``PIL.Image``.

    Returns:
        A 2-D ``uint8`` mask (255 = dress, 0 = background) with the same
        height and width as *image*, or ``None`` if SAM produced no usable
        mask.
    """
    width, height = image.size
    # One positive prompt point at the image centre; this assumes the dress
    # occupies the middle of the photo.
    input_points = [[[width // 2, height // 2]]]
    inputs = processor(image, input_points=input_points, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    if not masks:
        return None

    # masks[0][0] is a stack of candidate masks for our single prompt point,
    # not a single mask; keep the candidate SAM scores highest.
    candidates = masks[0][0]
    if candidates.numel() == 0:
        return None
    best = int(outputs.iou_scores[0, 0].argmax())
    mask = candidates[best].numpy().astype(np.uint8) * 255
    return mask if mask.any() else None


def warp_design(design, mask, warp_scale):
    """Tile *design* over the dress region at a size controlled by *warp_scale*.

    No thin-plate-spline warp is applied; the design is resized and
    repeated, then cut out by *mask*.

    Args:
        design: ``uint8`` H' x W' x C design image.
        mask: 2-D array, non-zero on the dress.
        warp_scale: Size of one design tile as a percentage of the dress
            bounding box (100 = one copy exactly covers the box). Values
            below 1 are clamped to 1.

    Returns:
        A ``uint8`` image with the same height/width as *mask*, containing
        the tiled design inside the mask and zeros elsewhere.
    """
    h, w = mask.shape[:2]
    binary = np.where(np.asarray(mask) > 0, 255, 0).astype(np.uint8)
    x, y, box_w, box_h = cv2.boundingRect(binary)
    if box_w == 0 or box_h == 0:
        return np.zeros((h, w) + design.shape[2:], dtype=np.uint8)

    scale = max(float(warp_scale), 1.0) / 100.0
    tile_w = max(1, round(box_w * scale))
    tile_h = max(1, round(box_h * scale))
    tile = cv2.resize(design, (tile_w, tile_h), interpolation=cv2.INTER_AREA)

    # Repeat the tile over the whole frame, phase-shifted so a tile's
    # top-left corner lands exactly on the dress box's top-left corner.
    off_y = (-y) % tile_h
    off_x = (-x) % tile_w
    reps = ((h + off_y) // tile_h + 1, (w + off_x) // tile_w + 1) + (1,) * (tile.ndim - 2)
    tiled = np.tile(tile, reps)
    canvas = np.ascontiguousarray(tiled[off_y:off_y + h, off_x:off_x + w])
    return cv2.bitwise_and(canvas, canvas, mask=binary)


def blend_images(base, overlay, mask):
    """Poisson-blend *overlay* into *base* inside *mask* (seamless cloning).

    Args:
        base: ``uint8`` H x W x 3 destination image.
        overlay: ``uint8`` H x W x 3 source image, pixel-aligned with *base*.
        mask: 2-D array, non-zero where *overlay* should be cloned.

    Returns:
        The blended ``uint8`` image, or a copy of *base* if the mask is empty.
    """
    clone_mask = np.where(np.asarray(mask) > 0, 255, 0).astype(np.uint8)
    # Poisson blending takes its boundary values from destination pixels
    # surrounding the region, so keep the outermost ring of the mask clear.
    clone_mask[[0, -1], :] = 0
    clone_mask[:, [0, -1]] = 0
    if not clone_mask.any():
        return base.copy()

    # seamlessClone places the mask's bounding box centred on `center`, an
    # (x, y) point; using the box centre keeps overlay and base aligned.
    x, y, box_w, box_h = cv2.boundingRect(clone_mask)
    center = (x + box_w // 2, y + box_h // 2)
    return cv2.seamlessClone(overlay, base, clone_mask, center, cv2.NORMAL_CLONE)


def apply_design(image_path, design_path, warp_scale):
    """Segment the dress, fit the design to it, and blend the result.

    Args:
        image_path: Path to the dress photo (any PIL-readable format).
        design_path: Path to the design/pattern image (OpenCV-readable).
        warp_scale: Design tile size in percent; see ``warp_design``.

    Returns:
        The blended RGB ``PIL.Image``.

    Raises:
        gr.Error: If an input is missing or unreadable, or segmentation fails.
    """
    if image_path is None or design_path is None:
        raise gr.Error("Please upload both a dress image and a design image.")

    with Image.open(image_path) as img:
        image = img.convert("RGB")

    design_bgr = cv2.imread(design_path, cv2.IMREAD_COLOR)
    if design_bgr is None:
        raise gr.Error(f"Could not read design image: {design_path}")
    # cv2.imread yields BGR while the PIL image is RGB; match the channel order.
    design = cv2.cvtColor(design_bgr, cv2.COLOR_BGR2RGB)

    mask = segment_dress(image)
    if mask is None:
        raise gr.Error("Segmentation Failed!")

    base = np.array(image)
    warped_design = warp_design(design, mask, warp_scale)
    blended = blend_images(base, warped_design, mask)
    return Image.fromarray(blended)


def main(image, design, warp_scale):
    """Gradio callback: forward the UI inputs to ``apply_design``."""
    return apply_design(image, design, warp_scale)


# Gradio UI
demo = gr.Interface(
    fn=main,
    inputs=[
        gr.Image(type="filepath", label="Upload Dress Image"),
        gr.Image(type="filepath", label="Upload Design Image"),
        gr.Slider(1, 100, value=50, label="Warp Scale (%)"),
    ],
    outputs=gr.Image(label="Warped Design on Dress"),
    title="AI-Powered Dress Designer",
    description="Upload a dress image and a design pattern. The AI will warp and blend the design onto the dress while preserving natural folds!",
)

if __name__ == "__main__":
    demo.launch()