|
import torch |
|
import numpy as np |
|
import gradio as gr |
|
from segment_anything import sam_model_registry, SamPredictor |
|
from diffusers import StableDiffusionInpaintPipeline |
|
from PIL import Image |
|
|
|
|
|
# Run the models on the GPU when one is present. Stable Diffusion inpainting
# on the CPU takes far longer per request than on a CUDA device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Segment Anything with the ViT-B backbone, loaded from a local checkpoint file.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth").to(DEVICE)

# Predictor wrapper: embed an image once, then query masks with point prompts.
predictor = SamPredictor(sam)

# Stable Diffusion 2 inpainting pipeline, used to repaint the masked region.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting"
).to(DEVICE)
|
|
|
# Basic colour names with representative sRGB values. The Stable Diffusion
# text encoder understands words like "red" far better than hex codes such as
# "#ff0000", so picked colours are mapped to the nearest entry here.
_COLOR_NAMES = {
    "black": (0, 0, 0),
    "white": (255, 255, 255),
    "gray": (128, 128, 128),
    "beige": (245, 245, 220),
    "brown": (139, 69, 19),
    "red": (255, 0, 0),
    "maroon": (128, 0, 0),
    "orange": (255, 165, 0),
    "yellow": (255, 255, 0),
    "green": (0, 128, 0),
    "olive": (128, 128, 0),
    "teal": (0, 128, 128),
    "light blue": (135, 206, 235),
    "blue": (0, 0, 255),
    "navy blue": (0, 0, 128),
    "purple": (128, 0, 128),
    "magenta": (255, 0, 255),
    "pink": (255, 192, 203),
}


def _parse_rgb(color):
    """Parse ``#rgb``, ``#rrggbb[aa]`` or ``rgb[a](r, g, b[, a])`` into ``(r, g, b)``.

    Returns None when *color* matches neither format.
    """
    text = str(color).strip().lower()
    try:
        if text.startswith("#"):
            digits = text[1:]
            if len(digits) in (3, 4):  # short form: "#f00" -> "ff0000"
                digits = "".join(ch * 2 for ch in digits)
            if len(digits) in (6, 8):  # ignore a trailing alpha byte
                return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
        elif text.startswith("rgb") and text.endswith(")") and "(" in text:
            parts = text[text.index("(") + 1:-1].split(",")
            if len(parts) >= 3:
                return tuple(int(round(float(p))) for p in parts[:3])
    except (ValueError, OverflowError):
        return None
    return None


def _color_to_words(color):
    """Describe *color* for the text prompt.

    Hex/rgb values map to the nearest basic colour name; anything else (for
    example a plain word such as "emerald") is passed through unchanged.
    """
    rgb = _parse_rgb(color)
    if rgb is None:
        return str(color).strip()
    return min(
        _COLOR_NAMES,
        key=lambda name: sum((a - b) ** 2 for a, b in zip(_COLOR_NAMES[name], rgb)),
    )


def _to_rgb(image):
    """Return *image* as an H x W x 3 array, the layout SAM's ``set_image`` expects."""
    array = np.asarray(image)
    if array.ndim == 2:  # grayscale upload: replicate the single channel
        array = np.stack([array] * 3, axis=-1)
    elif array.ndim == 3 and array.shape[-1] == 4:  # RGBA upload: drop alpha
        array = array[..., :3]
    return np.ascontiguousarray(array)


def change_dress(image, color, point=None):
    """Segment the dress using SAM and recolor it using Stable Diffusion.

    Args:
        image: Uploaded photo as a numpy array (or anything ``np.asarray``
            accepts). Grayscale and RGBA inputs are converted to RGB.
        color: Colour from the Gradio colour picker (e.g. ``"#ff0000"`` or
            ``"rgba(255, 0, 0, 1)"``) or a plain colour word.
        point: Optional ``(x, y)`` pixel on the dress used as SAM's
            foreground prompt. Defaults to the centre of the image.

    Returns:
        A PIL image with the same size as the input in which only the
        segmented region has been regenerated.

    Raises:
        gr.Error: If no image or no colour was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not color:
        raise gr.Error("Please choose a dress color.")

    rgb = _to_rgb(image)
    height, width = rgb.shape[:2]
    if point is None:
        # A fixed pixel coordinate can fall outside small images. The centre is
        # always in bounds and usually lands on the torso in a portrait shot.
        point = (width // 2, height // 2)

    predictor.set_image(rgb)
    masks, _, _ = predictor.predict(
        point_coords=np.array([point]),  # SAM takes (x, y) pixel coordinates
        point_labels=np.array([1]),  # 1 marks a foreground point
        multimask_output=False,
    )
    mask = Image.fromarray(masks[0].astype(np.uint8) * 255)

    source = Image.fromarray(rgb)
    generated = pipe(
        prompt=f"A dress of {_color_to_words(color)}",
        image=source,
        mask_image=mask,
    ).images[0]

    # Without explicit height/width the pipeline renders at its own default
    # resolution, so bring the output back to the input size.
    if generated.size != source.size:
        generated = generated.resize(source.size)
    # The pipeline re-encodes the whole frame. Paste the generated pixels only
    # inside the mask so everything except the dress stays untouched.
    return Image.composite(generated, source, mask)
|
|
|
|
|
def _build_interface():
    """Assemble the Gradio UI: a photo and a colour picker in, the edited photo out."""
    photo_input = gr.Image(type="numpy")
    color_input = gr.ColorPicker(label="Choose dress color")
    result_output = gr.Image()
    return gr.Interface(
        fn=change_dress,
        inputs=[photo_input, color_input],
        outputs=result_output,
        title="Fast Visual Try-On",
        description="Segment the dress and change its color in under 20 seconds using AI.",
    )


interface = _build_interface()

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    interface.launch()