import gradio as gr
import numpy as np
import mediapipe as mp
import torch
from PIL import Image
from diffusers import AutoPipelineForInpainting, DPMSolverMultistepScheduler
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from scipy.ndimage import binary_dilation
BG_COLOR = np.array([0, 0, 0], dtype=np.uint8)  # black
MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8)  # white

segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
base_options = python.BaseOptions(model_asset_path=segment_model)
options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)
segmenter = vision.ImageSegmenter.create_from_options(options)
MASK_CATEGORY = segmenter.labels
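
# For the selfie_multiclass_256x256 model, these labels are expected to be
# ["background", "hair", "body-skin", "face-skin", "clothes", "others"],
# so the dropdown default below (index 1) selects "hair".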

base_model = "SG161222/RealVisXL_V4.0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipeline = AutoPipelineForInpainting.from_pretrained(
    base_model, torch_dtype=torch.float16, use_safetensors=True
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline.to(device)
generator = torch.Generator(device=device).manual_seed(0)
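
# DPMSolverMultistepScheduler (DPM-Solver++) typically gives usable results in
# 20-30 steps, which is why the UI defaults to 25 below. Note that the float16
# weights assume a CUDA device; on CPU, load with torch_dtype=torch.float32.
# The fixed seed makes repeated runs with the same inputs reproducible.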

def image_to_image(input_image, mask_image, prompt, negative_prompt, guidance_scale, num_inference_steps):
    # Inpaint the masked region of the input image according to the prompt.
    output_image = pipeline(
        generator=generator,
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=input_image,
        mask_image=mask_image,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
    ).images[0]
    return output_image
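
# A minimal sketch of calling the function outside the UI, assuming photo.png
# and mask.png exist on disk (hypothetical filenames; Gradio passes PIL images
# in the same way):
#
#   result = image_to_image(
#       Image.open("photo.png").convert("RGB"),
#       Image.open("mask.png").convert("L"),
#       prompt="a red leather jacket",
#       negative_prompt="blurry, low quality",
#       guidance_scale=7.5,
#       num_inference_steps=25,
#   )
#   result.save("output.png")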

def segment_image(input_image, category):
    # Run MediaPipe segmentation and build a binary inpainting mask for the
    # selected category.
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image.convert("RGB")))
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    category_mask_np = category_mask.numpy_view()
    target_mask = category_mask_np == MASK_CATEGORY.index(category)
    # Generate solid color images for showing the output segmentation mask.
    image_data = image.numpy_view()
    fg_image = np.zeros(image_data.shape, dtype=np.uint8)
    fg_image[:] = MASK_COLOR
    bg_image = np.zeros(image_data.shape, dtype=np.uint8)
    bg_image[:] = BG_COLOR
    # Dilate the mask a few pixels so the repainted area blends past the
    # segmentation boundary.
    dilated_mask = binary_dilation(target_mask, iterations=4)
    condition = np.stack((dilated_mask,) * 3, axis=-1)
    output_image = np.where(condition, fg_image, bg_image)
    output_image = Image.fromarray(output_image)
    return output_image
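
# In diffusers inpainting pipelines, white mask pixels mark the region to be
# repainted and black pixels are preserved, which is why the selected category
# is drawn with MASK_COLOR and everything else with BG_COLOR.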

with gr.Blocks() as grApp:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=1, label="Prompt")
            negative_prompt = gr.Textbox(lines=2, label="Negative Prompt")
            category = gr.Dropdown(label="Mask Category", choices=MASK_CATEGORY, value=MASK_CATEGORY[1])
            # Classifier-free guidance; SDXL models typically work well around 5-10.
            guidance_scale = gr.Slider(minimum=1, maximum=20, value=7.5, label="Guidance Scale")
            num_inference_steps = gr.Slider(minimum=10, maximum=100, value=25, step=1, label="Number of Inference Steps")
            input_image = gr.Image(label="Input Image", type="pil")
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            mask_image = gr.Image(label="Mask Image", type="pil")
        with gr.Column():
            output_image = gr.Image(label="Output Image", type="pil")

    # Chain the two steps: segment first, then inpaint with the fresh mask.
    generate_btn.click(
        fn=segment_image,
        inputs=[input_image, category],
        outputs=[mask_image],
    ).then(
        fn=image_to_image,
        inputs=[input_image, mask_image, prompt, negative_prompt, guidance_scale, num_inference_steps],
        outputs=[output_image],
    )

grApp.launch()
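
# To try this locally (assuming the script is saved as app.py and the MediaPipe
# checkpoint has been downloaded to checkpoints/selfie_multiclass_256x256.tflite):
#   python app.py
# then open the local URL that Gradio prints.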