# Source: new_mmm/app.py by gaur3009 (commit ef61ae7, "Update app.py", 2.75 kB).
import torch
import cv2
import numpy as np
import gradio as gr
from PIL import Image
from torchvision import transforms
from skimage.restoration import denoise_tv_chambolle
from transformers import SamModel, SamProcessor
# Load the SAM checkpoint and its processor once, at import time. The model runs on
# CUDA when torch reports a GPU, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_SAM_CHECKPOINT = "facebook/sam-vit-huge"
model = SamModel.from_pretrained(_SAM_CHECKPOINT).to(DEVICE)
processor = SamProcessor.from_pretrained(_SAM_CHECKPOINT)
def segment_dress(image):
    """Segment the garment at the centre of *image* using SAM.

    A single prompt point is placed at the image centre. SAM predicts several
    candidate masks for that prompt; the candidate with the highest predicted
    IoU (``outputs.iou_scores``) is returned. Returning all candidates would
    yield a 3-D array, which downstream code cannot use.

    Args:
        image: PIL image (the caller passes an RGB image).

    Returns:
        A 2-D boolean numpy array with the same rows/cols as *image*, or
        ``None`` if SAM produced no masks or the selected mask is empty.
    """
    width, height = image.size
    # Prompt point in (x, y) order, nested in SamProcessor's list format.
    input_points = [[[width // 2, height // 2]]]
    inputs = processor(image, input_points=input_points, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )
    if not masks:
        return None
    # masks[0] holds this image's candidates (prompt batch x candidates x H x W);
    # flatten the leading axes so they line up one-to-one with the IoU scores.
    first = masks[0]
    candidates = first.reshape(-1, *first.shape[-2:])
    scores = outputs.iou_scores[0].reshape(-1).cpu()
    best = candidates[int(scores.argmax())].numpy()
    # An all-False mask cannot be blended; report it as a segmentation failure.
    return best if best.any() else None
def warp_design(design, mask, warp_scale):
    """Stretch *design* to the mask's size and keep only the pixels inside the mask.

    No geometric warp (e.g. thin-plate spline) is applied. The design is
    resized to the mask's dimensions and cut out by the mask.

    Args:
        design: Image array with shape (rows, cols, channels). Presumably uint8,
            since it is loaded with ``cv2.imread`` by the caller.
        mask: 2-D mask, or an image-style (rows, cols, channels) mask, of any
            numeric or bool dtype. Non-zero pixels (in any channel) are kept.
        warp_scale: Slider value in percent (0-100). It scales the 0/255 mask
            before ``cv2.bitwise_and``, which copies wherever the mask is
            non-zero. In practice it is therefore only a gate: values that
            scale 255 below 1 (e.g. 0) give an all-black result, and anything
            larger keeps the full masked design.

    Returns:
        The resized design with every pixel outside the mask set to zero.
    """
    mask = np.asarray(mask)
    # Collapse an image-style mask to one channel: a pixel is "in" if any channel is set.
    binary = mask.any(axis=2) if mask.ndim == 3 else mask != 0
    h, w = binary.shape[:2]
    design_resized = cv2.resize(design, (w, h))
    # Binarise before scaling: a uint8 0/255 mask multiplied by 255 would wrap around.
    # Clip so out-of-range slider values never hit an undefined float->uint8 cast.
    scaled = binary.astype(np.float64) * 255 * (warp_scale / 100)
    scaled_mask = np.clip(scaled, 0, 255).astype(np.uint8)
    return cv2.bitwise_and(design_resized, design_resized, mask=scaled_mask)
def blend_images(base, overlay, mask):
    """Blend *overlay* into *base* within *mask* using OpenCV seamless cloning.

    The overlay is cloned in place: the masked region of *overlay* lands at
    the same pixel coordinates in *base*. ``cv2.seamlessClone`` centres the
    mask's bounding box on the point it is given. That point must be (x, y),
    and must be the centre of that bounding box, not of the whole image.

    Args:
        base: Destination image array (rows, cols, 3). Assumed 8-bit, as
            ``seamlessClone`` requires.
        overlay: Source image array with the same size as *base*.
        mask: Mask of the region to clone. It may be bool or numeric, 2-D or
            (rows, cols, channels); non-zero pixels count as inside.

    Returns:
        The blended image array, or a copy of *base* when the mask is empty.
    """
    mask = np.asarray(mask)
    if mask.ndim == 3:
        mask = mask.any(axis=2)
    # OpenCV rejects bool arrays; seamlessClone needs an 8-bit mask.
    mask_u8 = np.where(mask != 0, 255, 0).astype(np.uint8)
    # seamlessClone zeroes a 1-pixel frame of the mask before taking its bounding
    # box. Do the same so the centre computed here matches the region it clones.
    mask_u8[[0, -1], :] = 0
    mask_u8[:, [0, -1]] = 0
    x, y, w, h = cv2.boundingRect(mask_u8)
    if w == 0 or h == 0:
        # Nothing to clone; seamlessClone would return without producing output.
        return base.copy()
    center = (x + w // 2, y + h // 2)
    return cv2.seamlessClone(overlay, base, mask_u8, center, cv2.NORMAL_CLONE)
def apply_design(image_path, design_path, warp_scale):
    """Segment the dress, cut the design to the dress mask, and blend it onto the photo.

    Args:
        image_path: File path of the dress photo (from a Gradio filepath input).
        design_path: File path of the design/pattern image.
        warp_scale: Slider value in percent, forwarded to ``warp_design``.

    Returns:
        A PIL RGB image with the design blended onto the dress.

    Raises:
        gr.Error: If an upload is missing, the design cannot be decoded, or
            segmentation fails. Gradio shows the message to the user; a string
            cannot be rendered by the ``gr.Image`` output.
    """
    if image_path is None or design_path is None:
        raise gr.Error("Please upload both a dress image and a design image.")
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    design = cv2.imread(design_path, cv2.IMREAD_COLOR)
    if design is None:
        raise gr.Error("Could not read the design image.")
    # cv2.imread yields BGR, but the dress photo (from PIL) is RGB. Convert so the
    # design's colours are not swapped after blending.
    design = cv2.cvtColor(design, cv2.COLOR_BGR2RGB)

    mask = segment_dress(image)
    if mask is None:
        raise gr.Error("Segmentation Failed!")
    mask = np.asarray(mask)
    if mask.ndim == 3:
        # SAM post-processing can yield (num_candidates, rows, cols); keep the
        # first candidate. Otherwise treat the mask as image-style (rows, cols, channels).
        if mask.shape[1:] == (image.height, image.width):
            mask = mask[0]
        else:
            mask = mask.any(axis=2)
    mask = mask != 0

    # warp_design scales a 0/1 mask; seamlessClone needs an 8-bit 0/255 mask.
    warped_design = warp_design(design, mask, warp_scale)
    blended = blend_images(np.array(image), warped_design, mask.astype(np.uint8) * 255)
    return Image.fromarray(blended)
def main(image, design, warp_scale):
    """Gradio callback: pass the uploaded file paths and slider value to the pipeline."""
    return apply_design(image_path=image, design_path=design, warp_scale=warp_scale)
# Gradio UI: two uploads (delivered to `main` as file paths) and a percentage slider.
ui_inputs = [
    gr.Image(type="filepath", label="Upload Dress Image"),
    gr.Image(type="filepath", label="Upload Design Image"),
    gr.Slider(0, 100, value=50, label="Warp Scale (%)"),
]
ui_output = gr.Image(label="Warped Design on Dress")
demo = gr.Interface(
    fn=main,
    inputs=ui_inputs,
    outputs=ui_output,
    title="AI-Powered Dress Designer",
    description="Upload a dress image and a design pattern. The AI will warp and blend the design onto the dress while preserving natural folds!",
)
demo.launch()