"""Gradio demo: turn a hand-drawn sketch plus a text prompt into an image.

Uses ``Pix2Pix_Turbo("sketch_to_image_stochastic")`` for generation. The
frame-preview feature (``image_to_sketch_gif`` + Apply / Frame Selector) is
currently disabled: its event chains below are commented out and the related
components are created with ``visible=False``.
"""

import base64
import os
import pdb
import random
import sys
import time
from io import BytesIO

import gradio as gr
import numpy as np
import spaces
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision import transforms

from src.img2skt import image_to_sketch_gif
from src.model import make_1step_sched
from src.pix2pix_turbo import Pix2Pix_Turbo

model = Pix2Pix_Turbo("sketch_to_image_stochastic")

# Prompt templates; "{prompt}" is replaced with the user's text in ``run``.
style_list = [
    {
        "name": "No Style",
        "prompt": "{prompt}",
    },
    {
        "name": "Cinematic",
        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
    },
    {
        "name": "3D Model",
        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
    },
    {
        "name": "Anime",
        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
    },
    {
        "name": "Digital Art",
        "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
    },
    {
        "name": "Pixel art",
        "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
    },
    {
        "name": "Fantasy art",
        "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
    },
    {
        "name": "Neonpunk",
        "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
    },
    {
        "name": "Manga",
        "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
    },
]

# Style name -> prompt template.
styles = {k["name"]: k["prompt"] for k in style_list}
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Manga"
MAX_SEED = np.iinfo(np.int32).max

HEIGHT = 512  # Display height
WIDTH = 512  # Display width
PROC_WIDTH = 512  # Processing width
PROC_HEIGHT = 512  # Processing height
ITER_DELAY = 1.0


def create_white_background(width, height):
    """Return a new white RGB PIL image of size (width, height)."""
    return Image.new("RGB", (width, height), color="white")


# Shared blank canvas used as the sketchpad background and as a fallback input.
white_background = create_white_background(WIDTH, HEIGHT)


def make_button_and_slider_unclickable():
    """Return (Button, Slider) updates that disable both components.

    Only referenced by the commented-out event chains, where the outputs are
    ``[apply_button, frame_selector]``.
    """
    return (
        gr.Button(interactive=False),
        gr.Slider(
            interactive=False,
        ),
    )


def make_button_and_slider_clickable():
    """Return (Button, Slider) updates that re-enable both components.

    Only referenced by the commented-out event chains, where the outputs are
    ``[apply_button, frame_selector]``.
    """
    return (
        gr.Button(interactive=True),
        gr.Slider(
            interactive=True,
        ),
    )


@spaces.GPU(duration=45)
def run(image, prompt, prompt_template, style_name, seed, val_r):
    """Generate an image from the current sketch and prompt.

    Args:
        image: Sketchpad value (dict with "background", "layers",
            "composite"); only "composite" is used.
        prompt: User text inserted into ``prompt_template``.
        prompt_template: Template containing the "{prompt}" placeholder.
        style_name: Selected style name (not used; the template already
            reflects the style).
        seed: Random seed; may arrive as a string from the seed Textbox.
        val_r: Sketch guidance strength, passed to the model as ``r``.

    Returns:
        A PIL image resized to (WIDTH, HEIGHT).
    """
    composite = image.get("composite") if image else None
    if composite is None:
        # No composite yet (e.g. right after clear_image_editor set it to
        # None): treat it as an empty white canvas instead of crashing.
        composite = white_background
    image = composite
    if image.size != (PROC_WIDTH, PROC_HEIGHT):
        image = image.resize((PROC_WIDTH, PROC_HEIGHT))
    prompt = prompt_template.replace("{prompt}", prompt)

    # Invert so the black strokes on the white canvas become bright, then
    # threshold into a binary mask (to_tensor scales to [0, 1]).
    image = image.convert("RGB")
    image = Image.fromarray(255 - np.array(image))
    image_t = TF.to_tensor(image) > 0.5

    with torch.no_grad():
        c_t = image_t.unsqueeze(0).cuda().float()
        torch.manual_seed(int(seed))
        _, _, H, W = c_t.shape
        # Noise at 1/8 spatial resolution with 4 channels; presumably matches
        # the model's latent shape - TODO confirm against Pix2Pix_Turbo.
        noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
        output_image = model(
            c_t, prompt, deterministic=False, r=val_r, noise_map=noise
        )

    # Map the output to [0, 1] for PIL; assumes the model emits values in
    # [-1, 1] - TODO confirm.
    output_pil = TF.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
    if output_pil.size != (WIDTH, HEIGHT):
        output_pil = output_pil.resize((WIDTH, HEIGHT))
    return output_pil


def clear_image_editor():
    """Reset the sketchpad, output images, stored frames, selector and Apply button.

    The return order matches the outputs of the ``image.clear`` handler:
    ``[image, result, frame_result, frames, frame_selector, apply_button]``.
    """
    return (
        {"background": white_background, "layers": None, "composite": None},
        gr.Image(
            value=None,
        ),
        gr.Image(
            value=None,
        ),
        # ``frames`` is a gr.State: return the plain value to store, not a
        # component instance (which would be stored as the state itself).
        [],
        gr.Slider(
            maximum=1,
            value=0,
            interactive=False,
        ),
        gr.Button(interactive=False),
    )


def apply_func_click(frames, frame_selector):
    """Build a sketchpad value whose single layer is the selected frame.

    Returns None if the selector or frames list is unusable (best-effort).
    """
    try:
        selected_frame = frames[int(frame_selector)]
    except (TypeError, ValueError, IndexError):
        # Missing/empty frames or a bad selector value: no frame to apply.
        return None
    return {
        "background": white_background,
        "layers": [selected_frame],
        "composite": None,
    }


def frame_selector_change(frame_idx, frames):
    """Return the frame at ``frame_idx``, or None if it cannot be selected."""
    try:
        return frames[int(frame_idx)]
    except (TypeError, ValueError, IndexError):
        # Missing/empty frames or a bad index: nothing to show.
        return None


with gr.Blocks() as demo:
    gr.Markdown("# Sketch to Image Demo")
    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Sketchpad(
                value={
                    "background": white_background,
                    "layers": None,
                    "composite": white_background,
                },
                image_mode="L",
                type="pil",
                sources=None,
                # container=True,
                label="Sketchpad",
                show_label=True,
                show_download_button=True,
                # show_share_button=True,
                interactive=True,
                layers=False,
                # height="80vw",
                canvas_size=(WIDTH, HEIGHT),
                show_fullscreen_button=False,
                brush=gr.Brush(
                    colors=["#000000"],
                    color_mode="fixed",
                    default_size=4,
                ),
            )
            prompt = gr.Textbox(label="Prompt", value="", show_label=True)
            with gr.Row():
                run_button = gr.Button("Run", scale=1)
                randomize_seed = gr.Button("Random", scale=1, visible=False)
            gr.Markdown(
                """
                ### Instructions
                1. Enter a text prompt (e.g. cat).
                2. Draw some sketches on the Sketchpad.
                3. Click on the **Run** button to generate image in the Final Image.
                4. You may then select a frame by the Frame Selector and click on **Apply** to apply the selected frame to the Sketchpad.
                5. You may then modify the sketches and click on **Run** again to generate new images.
                6. Repeat steps 4 and 5 to generate new images until you are satisfied with the result.
                7. To restart from scratch, click on the **Bin Icon** on the top right corner of the sketchpad

                **Thanks to the [paper](https://arxiv.org/abs/2403.12036) and their open-sourced models!**
                """
            )
            # Frame-preview components (hidden; their event chains are
            # commented out below).
            frame_result = gr.Image(
                height=HEIGHT,
                width=WIDTH,
                label="Sketch Outputs",
                type="pil",
                show_label=True,
                show_download_button=True,
                interactive=False,
                visible=False,
            )
            apply_button = gr.Button(
                "Apply", scale=1, visible=False, interactive=False
            )
            frame_selector = gr.Slider(
                minimum=0,
                maximum=1,
                value=0,
                step=1,
                visible=False,
                interactive=False,
                scale=4,
                label="Frame Selector",
            )

        with gr.Column(scale=1):
            result = gr.Image(
                height=HEIGHT,
                width=WIDTH,
                label="Final Image",
                type="pil",
                show_label=True,
                show_download_button=True,
                interactive=False,
                visible=True,
            )

            # invisible elements
            style = gr.Dropdown(
                label="Style",
                choices=STYLE_NAMES,
                value=DEFAULT_STYLE_NAME,
                scale=1,
                visible=False,
            )
            prompt_temp = gr.Textbox(
                label="Prompt Style Template",
                value=styles[DEFAULT_STYLE_NAME],
                max_lines=1,
                scale=2,
                visible=False,
            )
            val_r = gr.Slider(
                label="Sketch guidance: ",
                show_label=True,
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.01,
                scale=4,
                visible=False,
            )
            seed = gr.Textbox(label="Seed", value=42, scale=4, visible=False)
            frames = gr.State([])
            sketches = gr.Image(
                height=HEIGHT,
                width=WIDTH,
                show_label=False,
                show_download_button=True,
                type="pil",
                visible=False,
            )
            one_frame = gr.Image(
                height=HEIGHT,
                width=WIDTH,
                show_label=False,
                show_download_button=True,
                type="pil",
                interactive=False,
                visible=False,
            )

    # Shared input/output lists for every generation trigger; the order
    # matches run()'s parameters.
    inputs = [image, prompt, prompt_temp, style, seed, val_r]
    outputs = [result]

    randomize_seed_click = randomize_seed.click(
        lambda: random.randint(0, MAX_SEED),
        inputs=[],
        outputs=seed,
    ).then(
        fn=run,
        inputs=inputs,
        outputs=outputs,
    )

    # prompt_submit = (
    #     prompt.submit(
    #         make_button_and_slider_unclickable,
    #         inputs=None,
    #         outputs=[apply_button, frame_selector],
    #     )
    #     .then(fn=run, inputs=inputs, outputs=outputs)
    #     .then(
    #         image_to_sketch_gif,
    #         inputs=[result],
    #         outputs=[frame_result, frames, frame_selector, apply_button],
    #     )
    #     .then(
    #         fn=make_button_and_slider_clickable,
    #         inputs=None,
    #         outputs=[apply_button, frame_selector],
    #     )
    # )

    # Changing the style updates the template, then regenerates.
    style_change = style.change(
        lambda x: styles[x], inputs=[style], outputs=[prompt_temp]
    ).then(
        fn=run,
        inputs=inputs,
        outputs=outputs,
    )

    val_r_change = val_r.change(run, inputs=inputs, outputs=outputs)
    run_button_click = run_button.click(fn=run, inputs=inputs, outputs=outputs)

    # image_apply = (
    #     image.apply(
    #         fn=make_button_and_slider_unclickable,
    #         inputs=None,
    #         outputs=[apply_button, frame_selector],
    #     )
    #     .then(
    #         run,
    #         inputs=inputs,
    #         outputs=outputs,
    #     )
    #     .then(
    #         image_to_sketch_gif,
    #         inputs=[result],
    #         outputs=[frame_result, frames, frame_selector, apply_button],
    #     )
    #     .then(
    #         fn=make_button_and_slider_clickable,
    #         inputs=None,
    #         outputs=[apply_button, frame_selector],
    #     )
    # )

    # Clearing the sketchpad cancels any in-flight generation events...
    image.clear(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[
            run_button_click,
            randomize_seed_click,
            style_change,
            val_r_change,
        ],
    )
    # ...and resets the editor, outputs and frame-preview state.
    image.clear(
        fn=clear_image_editor,
        inputs=None,
        outputs=[
            image,
            result,
            frame_result,
            frames,
            frame_selector,
            apply_button,
        ],
    )

if __name__ == "__main__":
    demo.queue().launch()