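"""Gradio app: segment a region of a portrait with MediaPipe's multiclass
selfie segmenter, then inpaint that region from a text prompt with
RealVisXL (an SDXL inpainting pipeline loaded via diffusers).

Expects the segmentation model at checkpoints/selfie_multiclass_256x256.tflite
and uses CUDA when available.
"""
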
import gradio as gr
import numpy as np
import mediapipe as mp
import torch

from PIL import Image
from diffusers import AutoPipelineForInpainting, DPMSolverMultistepScheduler
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from scipy.ndimage import binary_dilation

BG_COLOR = np.array([0, 0, 0], dtype=np.uint8) # black
MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8) # white

segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
base_options = python.BaseOptions(model_asset_path=segment_model)
options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)
segmenter = vision.ImageSegmenter.create_from_options(options)
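# Class names reported by the segmenter; per the MediaPipe model card,
# selfie_multiclass_256x256 distinguishes background, hair, body-skin,
# face-skin, clothes, and others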
MASK_CATEGORY = segmenter.labels

base_model = "SG161222/RealVisXL_V4.0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

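# Load the SDXL inpainting pipeline in half precision and swap in the
# DPM-Solver++ multistep scheduler, which samples well in fewer steps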
pipeline = AutoPipelineForInpainting.from_pretrained(
    base_model, torch_dtype=torch.float16, use_safetensors=True
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline.to(device)
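# Fixed seed so repeated runs with the same inputs are reproducible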
generator = torch.Generator(device).manual_seed(0)

def image_to_image(input_image, mask_image, prompt, negative_prompt, guidance_scale, num_inference_steps):
    # Generate the output image
    output_image = pipeline(
        generator=generator,
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=input_image,
        mask_image=mask_image,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    ).images[0]
    
    return output_image

def segment_image(input_image, category):
    # MediaPipe expects a 3-channel SRGB frame, so force the input to RGB
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image.convert("RGB")))
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    category_mask_np = category_mask.numpy_view()
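    # category_mask assigns one class index per pixel; matching it against the
    # chosen category's index gives a boolean mask of the target region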
    target_mask = category_mask_np == MASK_CATEGORY.index(category)

    # Generate solid color images for showing the output segmentation mask.
    image_data = image.numpy_view()
    fg_image = np.zeros(image_data.shape, dtype=np.uint8)
    fg_image[:] = MASK_COLOR
    bg_image = np.zeros(image_data.shape, dtype=np.uint8)
    bg_image[:] = BG_COLOR

    # Dilate the mask a few pixels so the inpainted area fully covers the
    # segment boundary, then broadcast the boolean mask to three channels
    dilated_mask = binary_dilation(target_mask, iterations=4)
    condition = np.stack((dilated_mask,) * 3, axis=-1)

    output_image = np.where(condition, fg_image, bg_image)
    output_image = Image.fromarray(output_image)
    return output_image

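# UI layout: prompt and settings on the left, the generated mask in the
# middle column, the inpainted result on the right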
with gr.Blocks() as grApp:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=1, label="Prompt")
            negative_prompt = gr.Textbox(lines=2, label="Negative Prompt")
            category = gr.Dropdown(label='Mask Category', choices=MASK_CATEGORY, value=MASK_CATEGORY[1])
            # Typical SDXL classifier-free guidance values fall around 5-10
            guidance_scale = gr.Slider(minimum=1, maximum=20, value=7.5, label="Guidance Scale")
            num_inference_steps = gr.Slider(minimum=10, maximum=100, value=25, label="Number of Inference Steps")
            input_image = gr.Image(label="Input Image", type="pil")
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            mask_image = gr.Image(label="Mask Image", type="pil")
        with gr.Column():
            output_image = gr.Image(label="Output Image", type="pil")
    
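    # Two-step chain: clicking first writes the segmentation mask into
    # mask_image, then the mask and prompts feed the inpainting call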
    generate_btn.click(
        fn=segment_image,
        inputs=[input_image, category],
        outputs=[mask_image],
    ).then(
        fn=image_to_image,
        inputs=[input_image, mask_image, prompt, negative_prompt, guidance_scale, num_inference_steps],
        outputs=[output_image],
    )

grApp.launch()