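"""Gradio demo: segment a region of a portrait with MediaPipe's multiclass
selfie segmenter, then repaint that region with an SDXL inpainting pipeline
(RealVisXL V4.0) driven by a text prompt."""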
import gradio as gr
import numpy as np
import mediapipe as mp
import torch
from PIL import Image
from diffusers import AutoPipelineForInpainting, DPMSolverMultistepScheduler
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from scipy.ndimage import binary_dilation
BG_COLOR = np.array([0, 0, 0], dtype=np.uint8) # black
MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8) # white
segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
base_options = python.BaseOptions(model_asset_path=segment_model)
options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)
segmenter = vision.ImageSegmenter.create_from_options(options)
MASK_CATEGORY = segmenter.labels
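# Labels reported by the multiclass selfie model (expected to be: background,
# hair, body-skin, face-skin, clothes, others).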
base_model = "SG161222/RealVisXL_V4.0"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# float16 weights need a GPU; fall back to float32 when running on CPU.
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32

pipeline = AutoPipelineForInpainting.from_pretrained(
    base_model, torch_dtype=torch_dtype, use_safetensors=True
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline.to(device)
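# Fixed seed so repeated runs with the same inputs are reproducible.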
generator = torch.Generator(device).manual_seed(0)
def image_to_image(input_image, mask_image, prompt, negative_prompt, guidance_scale, num_inference_steps):
# Generate the output image
output_image = pipeline(
generator=generator,
prompt=prompt,
negative_prompt=negative_prompt,
image=input_image,
mask_image=mask_image,
guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
).images[0]
return output_image
def segment_image(input_image, category):
    # MediaPipe expects a 3-channel RGB array; uploaded images may carry an alpha channel.
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image.convert("RGB")))
segmentation_result = segmenter.segment(image)
category_mask = segmentation_result.category_mask
category_mask_np = category_mask.numpy_view()
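    # Boolean mask of the pixels whose predicted class matches the chosen category.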
target_mask = category_mask_np == MASK_CATEGORY.index(category)
# Generate solid color images for showing the output segmentation mask.
image_data = image.numpy_view()
fg_image = np.zeros(image_data.shape, dtype=np.uint8)
fg_image[:] = MASK_COLOR
bg_image = np.zeros(image_data.shape, dtype=np.uint8)
bg_image[:] = BG_COLOR
    # Grow the mask slightly so the inpainted area fully covers the region's edges.
    dilated_mask = binary_dilation(target_mask, iterations=4)
    condition = np.stack((dilated_mask,) * 3, axis=-1)
output_image = np.where(condition, fg_image, bg_image)
output_image = Image.fromarray(output_image)
return output_image
with gr.Blocks() as grApp:
with gr.Row():
with gr.Column():
prompt = gr.Textbox(lines=1, label="Prompt")
negative_prompt = gr.Textbox(lines=2, label="Negative Prompt")
category = gr.Dropdown(label='Mask Category', choices=MASK_CATEGORY, value=MASK_CATEGORY[1])
            # Classifier-free guidance for SDXL typically sits around 5-10; a 0-1 range would effectively disable it.
            guidance_scale = gr.Slider(minimum=1, maximum=20, value=7.5, label="Guidance Scale")
num_inference_steps = gr.Slider(minimum=10, maximum=100, value=25, label="Number of Inference Steps")
input_image = gr.Image(label="Input Image", type="pil")
generate_btn = gr.Button("Generate Image")
with gr.Column():
mask_image = gr.Image(label="Mask Image", type="pil")
with gr.Column():
output_image = gr.Image(label="Output Image", type="pil")
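    # Chain the two steps: show the segmentation mask first, then feed it to the inpainter.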
generate_btn.click(
fn=segment_image,
inputs=[input_image, category],
outputs=[mask_image],
).then(
fn=image_to_image,
inputs=[input_image, mask_image, prompt, negative_prompt, guidance_scale, num_inference_steps],
outputs=[output_image],
)
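# Serves the UI locally; gr.Blocks.launch(share=True) would create a public link.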
grApp.launch()