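"""Virtual Try-On demo (Gradio app).

Segments a person photo and a clothing photo with a DeepLabV3/MobileNetV2
semantic-segmentation pipeline, then overlays the resized clothing cut-out
onto the person image. This is a naive compositing demo, not a dedicated
virtual try-on model.
"""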
import gradio as gr
from transformers import pipeline
from PIL import Image
import numpy as np
import cv2
# Load the semantic-segmentation pipeline once at startup. The pipeline wraps
# its own processor and model, so separate AutoImageProcessor /
# MobileNetV2ForSemanticSegmentation loads are not needed.
pipe = pipeline("image-segmentation", model="google/deeplabv3_mobilenet_v2_1.0_513")
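# Sketch of the pipeline's output shape (labels come from the checkpoint's
# id2label mapping; "score" is typically None for semantic segmentation):
#   pipe(image) -> [{"score": None, "label": "person", "mask": <PIL.Image 'L'>}, ...]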
def VirtualTryOn(user_image, clothing_image):
    # Gradio delivers PIL images here (type="pil"), but accept NumPy arrays defensively.
    if isinstance(user_image, np.ndarray):
        user_image = Image.fromarray(user_image)
    if isinstance(clothing_image, np.ndarray):
        clothing_image = Image.fromarray(clothing_image)
    if not (isinstance(user_image, Image.Image) and isinstance(clothing_image, Image.Image)):
        raise ValueError("Both inputs must be PIL images")

    # Force three channels so the array shapes below are predictable.
    user_image = user_image.convert("RGB")
    clothing_image = clothing_image.convert("RGB")

    # Prefer the "person" segment when the label set offers one (this PASCAL VOC
    # checkpoint includes a "person" class); otherwise fall back to the first
    # returned segment.
    user_segments = pipe(user_image)
    user_mask = next(
        (seg["mask"] for seg in user_segments if seg["label"] == "person"),
        user_segments[0]["mask"],
    )
    # Clothing is not a VOC class, so keep the first-segment heuristic here.
    clothing_mask = pipe(clothing_image)[0]["mask"]

    user_image_array = np.array(user_image)
    clothing_image_array = np.array(clothing_image)
    user_mask_array = np.array(user_mask)
    clothing_mask_array = np.array(clothing_mask)

    # Zero out everything outside each mask.
    user_isolated = cv2.bitwise_and(user_image_array, user_image_array, mask=user_mask_array)
    clothing_isolated = cv2.bitwise_and(clothing_image_array, clothing_image_array, mask=clothing_mask_array)

    # Resize the clothing cut-out to the person image and blend with a saturating add.
    user_height, user_width, _ = user_isolated.shape
    clothing_resized = cv2.resize(clothing_isolated, (user_width, user_height))
    combined_image = cv2.add(user_isolated, clothing_resized)
    return Image.fromarray(combined_image)
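# Quick local check, independent of the UI (file names are hypothetical):
#   from PIL import Image
#   result = VirtualTryOn(Image.open("person.jpg"), Image.open("shirt.jpg"))
#   result.save("tryon_result.png")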
css = """
.gradio-container {
background-color: #f5f5f5;
font-family: 'Arial', sans-serif;
padding: 20px;
border-radius: 15px;
box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);
width: 90vw;
max-width: 1200px;
margin: auto;
}
h1 {
color: #333333;
text-align: center;
font-size: 2.5rem;
margin-bottom: 20px;
}
#images-container {
display: flex;
justify-content: space-around;
align-items: center;
gap: 20px;
padding: 15px;
border: 2px solid #cccccc;
border-radius: 15px;
background-color: #ffffff;
box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
}
.image-container {
display: flex;
flex-direction: column;
align-items: center;
gap: 10px;
}
.image-container label {
font-weight: bold;
color: #555555;
}
.image-box {
width: 220px;
height: 300px;
border: 3px dashed #aaaaaa;
border-radius: 10px;
display: flex;
justify-content: center;
align-items: center;
background-color: #f9f9f9;
}
button {
font-size: 1.2rem;
padding: 10px 20px;
border-radius: 10px;
border: none;
cursor: pointer;
transition: all 0.3s ease;
}
#try-on-button {
background-color: #4CAF50;
color: white;
}
#try-on-button:hover {
background-color: #45a049;
}
#clear-button {
background-color: #FF5722;
color: white;
}
#clear-button:hover {
background-color: #e64a19;
}
"""
with gr.Blocks(css=css) as iface:
    gr.Markdown("<h1>Virtual Try-On Application</h1>")
    with gr.Row(elem_id="images-container"):
        with gr.Column(elem_id="user-image-container", elem_classes="image-container"):
            gr.Markdown("**Upload Person Image**")
            user_image = gr.Image(type="pil", label="Person Image", elem_id="user-image", elem_classes="image-box")
        with gr.Column(elem_id="clothing-image-container", elem_classes="image-container"):
            gr.Markdown("**Upload Clothing Image**")
            clothing_image = gr.Image(type="pil", label="Clothing Image", elem_id="clothing-image", elem_classes="image-box")
        with gr.Column(elem_id="output-image-container", elem_classes="image-container"):
            gr.Markdown("**Try-On Result**")
            output = gr.Image(type="pil", label="Result", elem_id="output", elem_classes="image-box")
    with gr.Row():
        with gr.Column():
            try_on_button = gr.Button("Try On", elem_id="try-on-button")
        with gr.Column():
            clear_button = gr.Button("Clear", elem_id="clear-button")
    try_on_button.click(fn=VirtualTryOn, inputs=[user_image, clothing_image], outputs=output)
    clear_button.click(fn=lambda: (None, None, None), inputs=[], outputs=[user_image, clothing_image, output])

iface.launch()