"""Gradio front-end for a simplified Inpaint4Drag demo."""
import gradio as gr
import tempfile
import os
# Expected to provide the event handlers used below: clear_all, clear_point,
# resize, visualize_user_drag, preview_out_image, add_point, undo_point, inpaint.
from utils.ui_utils import *

CANVAS_SIZE = 400  # pixel height/width of every image panel; also fed to handlers as canvas_size
DEFAULT_GEN_SIZE = 512  # initial value of the hidden gen_size input passed to resize()


def _create_canvas():
    """Create the region-drawing canvas component.

    Prefers ``gr.ImageEditor`` with a fixed white brush. If this Gradio build
    lacks ``ImageEditor``/``Brush`` (AttributeError) or rejects one of the
    keyword arguments (TypeError), falls back to a plain ``gr.Image``.
    Any other error propagates instead of being silently swallowed.
    """
    try:
        return gr.ImageEditor(
            label=" ",
            height=CANVAS_SIZE,
            width=CANVAS_SIZE,
            brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
        )
    except (AttributeError, TypeError):
        # Fallback for older Gradio versions.
        # NOTE(review): this path yields a plain numpy image rather than an
        # editor value with drawn layers; confirm the utils.ui_utils handlers
        # accept that before relying on it.
        return gr.Image(
            type="numpy",
            label=" ",
            height=CANVAS_SIZE,
            width=CANVAS_SIZE,
            sources=["upload", "webcam", "clipboard"],
        )


def create_interface():
    """Build the Inpaint4Drag Gradio UI and wire up its event handlers.

    Returns:
        gr.Blocks: the assembled app (``main`` launches it via ``queue().launch()``).
    """
    with gr.Blocks() as app:
        # Main title and project link.
        gr.Markdown("# A simplified implementation of Inpaint4Drag")
        gr.Markdown("## [Visit our project page for more examples](https://visual-ai.github.io/inpaint4drag)")

        # Hidden inputs and per-session state shared by the event handlers.
        state = {
            'canvas_size': gr.Number(value=CANVAS_SIZE, visible=False, precision=0),
            'gen_size': gr.Number(value=DEFAULT_GEN_SIZE, visible=False, precision=0),
            'points_list': gr.State(value=[]),
            'inpaint_mask': gr.State(value=None),
        }

        with gr.Tab(label='Inpaint4Drag'):
            with gr.Row():
                # Column 1: draw the region(s) to be dragged.
                with gr.Column():
                    gr.Markdown("\n1. Draw Regions\n")
                    canvas = _create_canvas()
                    with gr.Row():
                        fit_btn = gr.Button("Resize Image")

                # Column 2: place control points on the image.
                with gr.Column():
                    gr.Markdown("2. Control Points\n")
                    input_img = gr.Image(
                        type="numpy",
                        label=" ",
                        height=CANVAS_SIZE,
                        width=CANVAS_SIZE,
                        interactive=True,
                    )
                    with gr.Row():
                        undo_btn = gr.Button("Undo Point")
                        clear_btn = gr.Button("Clear Points")

                # Column 3: preview / inpainting result.
                with gr.Column():
                    gr.Markdown("Results\n")
                    output_img = gr.Image(
                        type="numpy",
                        label=" ",
                        height=CANVAS_SIZE,
                        width=CANVAS_SIZE,
                        interactive=False,
                    )
                    with gr.Row():
                        run_btn = gr.Button("Inpaint")
                        reset_btn = gr.Button("Reset All")

            # Generation parameters.
            with gr.Row():
                inpaint_ks = gr.Slider(
                    minimum=0,
                    maximum=25,
                    value=5,
                    step=1,
                    label='How much to expand inpainting mask',
                    interactive=True,
                )

        # Listeners are attached while still inside the Blocks context.
        setup_events(
            components={
                'canvas': canvas,
                'input_img': input_img,
                'output_img': output_img,
                'inpaint_ks': inpaint_ks,
            },
            state=state,
            buttons={
                'fit': fit_btn,
                'undo': undo_btn,
                'clear': clear_btn,
                'run': run_btn,
                'reset': reset_btn,
            },
        )
    return app


def setup_events(components, state, buttons):
    """Register every Gradio event listener for the Inpaint4Drag UI.

    Args:
        components: dict with keys 'canvas', 'input_img', 'output_img', 'inpaint_ks'.
        state: dict with keys 'canvas_size', 'gen_size', 'points_list', 'inpaint_mask'.
        buttons: dict with keys 'fit', 'undo', 'clear', 'run', 'reset'.

    ``create_interface`` calls this inside its ``gr.Blocks`` context, which is
    where Gradio expects listeners to be attached.
    """
    canvas = components['canvas']
    input_img = components['input_img']
    output_img = components['output_img']
    inpaint_ks = components['inpaint_ks']
    points_list = state['points_list']
    inpaint_mask = state['inpaint_mask']

    # Argument lists shared by several listeners (copied per use with list()).
    # NOTE(review): points_list is never an output below, so add_point /
    # undo_point / clear_point presumably mutate it in place — confirm in
    # utils.ui_utils.
    drag_inputs = [canvas, points_list, inpaint_ks]
    preview_outputs = [output_img, inpaint_mask]
    reset_outputs = [canvas, input_img, output_img, points_list, inpaint_ks, inpaint_mask]

    def then_preview(dependency):
        """Chain a preview_out_image refresh after *dependency* completes."""
        return dependency.then(preview_out_image, list(drag_inputs), list(preview_outputs))

    # Reset everything, from the "Reset All" button or when the canvas is cleared.
    for trigger in (buttons['reset'].click, canvas.clear):
        trigger(clear_all, [state['canvas_size']], list(reset_outputs))

    # "Resize Image": drop the current points, then resize the canvas image.
    buttons['fit'].click(
        clear_point, list(drag_inputs), [input_img]
    ).then(
        resize,
        [canvas, state['gen_size'], state['canvas_size']],
        [canvas, input_img, output_img],
    )

    # Redraw the point overlay, then the preview, whenever the drawing or the
    # mask-expansion slider changes.
    for trigger in (canvas.change, inpaint_ks.change):
        then_preview(trigger(visualize_user_drag, [canvas, points_list], [input_img]))

    # Clicking on the control-point image adds a point.
    then_preview(input_img.select(add_point, list(drag_inputs), [input_img]))

    # Point manipulation buttons.
    then_preview(buttons['undo'].click(undo_point, list(drag_inputs), [input_img]))
    then_preview(buttons['clear'].click(clear_point, list(drag_inputs), [input_img]))

    # "Inpaint": refresh the preview/mask, then run inpainting on that preview.
    buttons['run'].click(
        preview_out_image, list(drag_inputs), list(preview_outputs)
    ).then(
        inpaint, [output_img, inpaint_mask], [output_img]
    )


def main():
    """Build the interface and launch it with request queuing enabled."""
    app = create_interface()
    # HF Space compatible launch.
    app.queue().launch()


if __name__ == '__main__':
    main()