Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import tempfile | |
import os | |
from utils.ui_utils import * | |
# Used as both the height and the width of the three image components, and as
# the value of the hidden `canvas_size` Number component.
CANVAS_SIZE = 400
# Value of the hidden `gen_size` Number component, which is passed to `resize`.
DEFAULT_GEN_SIZE = 512
def create_interface():
    """Build the Inpaint4Drag Gradio UI and wire up its events.

    The layout has three columns (region drawing canvas, control-point image,
    result image) followed by a slider controlling how much the inpainting
    mask is expanded. Event handlers are attached through ``setup_events``.

    Returns:
        The ``gr.Blocks`` application, not yet launched.
    """
    with gr.Blocks() as app:
        # Add main title and project link
        gr.Markdown("# A simplified implementation of Inpaint4Drag")
        gr.Markdown("## [Visit our project page for more examples](https://visual-ai.github.io/inpaint4drag)")

        # State variables. The two sizes are hidden Number components so they
        # can be passed as event inputs like any other component.
        state = {
            'canvas_size': gr.Number(value=CANVAS_SIZE, visible=False, precision=0),
            'gen_size': gr.Number(value=DEFAULT_GEN_SIZE, visible=False, precision=0),
            'points_list': gr.State(value=[]),
            'inpaint_mask': gr.State(value=None)
        }

        with gr.Tab(label='Inpaint4Drag'):
            with gr.Row():
                # Draw Region Column
                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">1. Draw Regions</p>""")
                    # Use ImageEditor for newer Gradio versions, fallback to Image with brush
                    try:
                        canvas = gr.ImageEditor(
                            label=" ",
                            height=CANVAS_SIZE,
                            width=CANVAS_SIZE,
                            brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed")
                        )
                    except (AttributeError, TypeError):
                        # Fallback for older Gradio versions: a missing
                        # ImageEditor/Brush class raises AttributeError and an
                        # unsupported keyword raises TypeError. Anything else
                        # (including KeyboardInterrupt) is allowed to propagate.
                        canvas = gr.Image(
                            type="numpy",
                            label=" ",
                            height=CANVAS_SIZE,
                            width=CANVAS_SIZE,
                            sources=["upload", "webcam", "clipboard"]
                        )
                    with gr.Row():
                        fit_btn = gr.Button("Resize Image")

                # Control Points Column
                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">2. Control Points</p>""")
                    input_img = gr.Image(
                        type="numpy",
                        label=" ",
                        height=CANVAS_SIZE,
                        width=CANVAS_SIZE,
                        interactive=True
                    )
                    with gr.Row():
                        undo_btn = gr.Button("Undo Point")
                        clear_btn = gr.Button("Clear Points")

                # Results Column
                with gr.Column():
                    gr.Markdown("""<p style="text-align: center; font-size: 20px">Results</p>""")
                    output_img = gr.Image(
                        type="numpy",
                        label=" ",
                        height=CANVAS_SIZE,
                        width=CANVAS_SIZE,
                        interactive=False
                    )
                    with gr.Row():
                        run_btn = gr.Button("Inpaint")
                        reset_btn = gr.Button("Reset All")

            # Generation Parameters
            with gr.Row():
                inpaint_ks = gr.Slider(
                    minimum=0,
                    maximum=25,
                    value=5,
                    step=1,
                    label='How much to expand inpainting mask',
                    interactive=True
                )

        # Event listeners must be registered while the Blocks context is open.
        setup_events(
            components={
                'canvas': canvas,
                'input_img': input_img,
                'output_img': output_img,
                'inpaint_ks': inpaint_ks,
            },
            state=state,
            buttons={
                'fit': fit_btn,
                'undo': undo_btn,
                'clear': clear_btn,
                'run': run_btn,
                'reset': reset_btn
            }
        )
    return app
def setup_events(components, state, buttons):
    """Attach all event listeners for the interface.

    Args:
        components: Mapping with the keys 'canvas', 'input_img', 'output_img'
            and 'inpaint_ks'.
        state: Mapping with the keys 'canvas_size', 'gen_size', 'points_list'
            and 'inpaint_mask'.
        buttons: Mapping with the keys 'fit', 'undo', 'clear', 'run' and
            'reset'.

    The handler callables (``clear_all``, ``clear_point``, ``resize``, ...)
    are not defined in this function; presumably they come from the
    ``utils.ui_utils`` star import.
    """
    canvas = components['canvas']
    input_img = components['input_img']
    output_img = components['output_img']
    inpaint_ks = components['inpaint_ks']
    points_list = state['points_list']
    inpaint_mask = state['inpaint_mask']

    def point_inputs():
        # Fresh list per registration, like the literals it replaces.
        return [canvas, points_list, inpaint_ks]

    def then_preview(event):
        # Follow `event` with a refresh of the preview image and mask.
        return event.then(
            preview_out_image,
            point_inputs(),
            [output_img, inpaint_mask],
        )

    # Reset and clear events: the reset button and the canvas' own clear
    # action both restore everything through the same handler.
    for register_reset in (buttons['reset'].click, canvas.clear):
        register_reset(
            clear_all,
            [state['canvas_size']],
            [canvas, input_img, output_img, points_list, inpaint_ks, inpaint_mask],
        )

    # Image manipulation events: drop the points, then resize.
    points_cleared = buttons['fit'].click(clear_point, point_inputs(), [input_img])
    points_cleared.then(
        resize,
        [canvas, state['gen_size'], state['canvas_size']],
        [canvas, input_img, output_img],
    )

    # Canvas interaction events (handle both ImageEditor and Image events)
    if hasattr(canvas, 'change'):
        on_canvas_update = canvas.change
    else:
        on_canvas_update = canvas.edit
    then_preview(
        on_canvas_update(visualize_user_drag, [canvas, points_list], [input_img])
    )
    then_preview(
        inpaint_ks.change(visualize_user_drag, [canvas, points_list], [input_img])
    )

    # Input image events: selecting on the image adds a control point.
    then_preview(input_img.select(add_point, point_inputs(), [input_img]))

    # Point manipulation events
    then_preview(buttons['undo'].click(undo_point, point_inputs(), [input_img]))
    then_preview(buttons['clear'].click(clear_point, point_inputs(), [input_img]))

    # Processing events: refresh the preview first, then inpaint its result.
    preview_ready = buttons['run'].click(
        preview_out_image,
        point_inputs(),
        [output_img, inpaint_mask],
    )
    preview_ready.then(inpaint, [output_img, inpaint_mask], [output_img])
def main():
    """Create the interface, enable its queue and launch it with defaults."""
    demo = create_interface()
    # HF Space compatible launch
    queued_demo = demo.queue()
    queued_demo.launch()
# Launch the app only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()