"""Gradio demo: mask one MediaPipe segmentation category, then inpaint it with SDXL.

The user uploads an image and picks one of the segmenter's labels. The pixels of
that label are turned into a (slightly dilated) black/white mask, and the mask is
passed to an SDXL inpainting pipeline together with the user's prompt.
"""

import gradio as gr
import mediapipe as mp
import numpy as np
import torch
from diffusers import AutoPipelineForInpainting, DPMSolverMultistepScheduler
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from PIL import Image
from scipy.ndimage import binary_dilation

BG_COLOR = np.array([0, 0, 0], dtype=np.uint8)  # black: region left untouched
MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8)  # white: region to repaint
# Grow the selected region a few pixels so inpainting also covers its fuzzy edges.
MASK_DILATION_ITERATIONS = 4

# --- Segmentation model -------------------------------------------------------
segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
base_options = python.BaseOptions(model_asset_path=segment_model)
options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)
segmenter = vision.ImageSegmenter.create_from_options(options)
# Label names reported by the model; a label's list index is its value in the category mask.
MASK_CATEGORY = segmenter.labels

# --- Inpainting pipeline ------------------------------------------------------
base_model = "SG161222/RealVisXL_V4.0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is meant for the GPU; many CPU kernels lack float16 support,
# so fall back to float32 when running without CUDA.
dtype = torch.float16 if device.type == "cuda" else torch.float32
pipeline = AutoPipelineForInpainting.from_pretrained(
    base_model, torch_dtype=dtype, use_safetensors=True
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline.to(device)
# Seeded once at start-up and shared by every request: results are reproducible
# per app session, but successive generations advance the same RNG state.
generator = torch.Generator(device).manual_seed(0)


def image_to_image(input_image, mask_image, prompt, negative_prompt, guidance_scale, num_inference_steps):
    """Repaint the white region of ``mask_image`` in ``input_image`` following ``prompt``.

    Args:
        input_image: PIL image to edit (any mode; converted to RGB).
        mask_image: PIL mask; white pixels are regenerated, black pixels kept.
        prompt: Text describing the desired content of the masked region.
        negative_prompt: Text describing content to avoid.
        guidance_scale: Classifier-free guidance strength (only active when > 1).
        num_inference_steps: Number of denoising steps.

    Returns:
        The first generated PIL image.

    Raises:
        gr.Error: If either image is missing.
    """
    if input_image is None or mask_image is None:
        raise gr.Error("Please upload an input image before generating.")

    output_image = pipeline(
        generator=generator,
        prompt=prompt,
        negative_prompt=negative_prompt,
        # RGBA / palette uploads would otherwise give the VAE a non-3-channel input.
        image=input_image.convert("RGB"),
        mask_image=mask_image,
        guidance_scale=float(guidance_scale),
        # Sliders can deliver floats; the scheduler needs an integer step count.
        num_inference_steps=int(num_inference_steps),
    ).images[0]
    return output_image


def segment_image(input_image, category):
    """Build a black/white inpainting mask for one segmentation category.

    Args:
        input_image: PIL image to segment (any mode; converted to RGB).
        category: One of ``MASK_CATEGORY``; its pixels become the white region.

    Returns:
        An RGB PIL image, same size as the input, white where ``category`` was
        detected (dilated by ``MASK_DILATION_ITERATIONS`` pixels) and black elsewhere.

    Raises:
        gr.Error: If no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    # MediaPipe's SRGB format expects exactly three uint8 channels.
    rgb_pixels = np.asarray(input_image.convert("RGB"))
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_pixels)
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask.numpy_view()

    target_mask = category_mask == MASK_CATEGORY.index(category)
    dilated_mask = binary_dilation(target_mask, iterations=MASK_DILATION_ITERATIONS)

    # Broadcast the (H, W, 1) boolean mask against the per-channel colours -> (H, W, 3).
    mask_pixels = np.where(dilated_mask[..., np.newaxis], MASK_COLOR, BG_COLOR).astype(np.uint8)
    return Image.fromarray(mask_pixels)


with gr.Blocks() as grApp:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=1, label="Prompt")
            negative_prompt = gr.Textbox(lines=2, label="Negative Prompt")
            category = gr.Dropdown(label="Mask Category", choices=MASK_CATEGORY, value=MASK_CATEGORY[1])
            # Classifier-free guidance (and therefore the negative prompt) only takes
            # effect for values above 1.0, so the range starts there.
            guidance_scale = gr.Slider(minimum=1, maximum=20, value=7.5, step=0.5, label="Guidance Scale")
            num_inference_steps = gr.Slider(
                minimum=10, maximum=100, value=25, step=1, label="Number of Inference Steps"
            )
            input_image = gr.Image(label="Input Image", type="pil")
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            mask_image = gr.Image(label="Mask Image", type="pil")
        with gr.Column():
            output_image = gr.Image(label="Output Image", type="pil")

    # First compute and display the mask, then feed it into the inpainting step.
    generate_btn.click(
        fn=segment_image,
        inputs=[input_image, category],
        outputs=[mask_image],
    ).then(
        fn=image_to_image,
        inputs=[input_image, mask_image, prompt, negative_prompt, guidance_scale, num_inference_steps],
        outputs=[output_image],
    )

grApp.launch()