"""
FLUX.1 ControlNet demo
----------------------

Gradio interface with **one** control-image upload slot that drives the
``Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro`` ControlNet on top of
``black-forest-labs/FLUX.1-dev``.

Key points
~~~~~~~~~~
* Single *control image* input (left).
* *Result* and *Pre-processed Cond* previews side by side (center & right).
* *Prompt* textbox plus a dedicated **ControlNet** panel for choosing the
  mode and strength.
* Seed handling with optional randomisation.
* Sliders for *Guidance scale* and *Inference steps*.
* Works on CUDA (bfloat16) or CPU (float32).
* Minimal Canny preview when the *canny* mode is selected (extend as you
  like for the other modes).

Before running, set the ``HF_TOKEN_NEW`` (or ``HUGGINGFACE_TOKEN``)
environment variable to a token that can access the gated FLUX.1-dev
weights, or call ``login(token="hf_...")`` explicitly.
"""

import os
import random

import numpy as np
import gradio as gr
# Keep `spaces` ahead of torch/diffusers: ZeroGPU expects to be imported
# before CUDA is initialised.
import spaces
import torch
from PIL import Image
import cv2
from huggingface_hub import login
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel

# --------------------------------------------------
# Authentication
# --------------------------------------------------
# Prefer the Space-specific secret, fall back to the conventional name.
HF_TOKEN = os.getenv("HF_TOKEN_NEW") or os.getenv("HUGGINGFACE_TOKEN")
if HF_TOKEN:
    # Only log in when a token is configured: login() without a token
    # falls back to an interactive prompt, which a headless server cannot
    # answer.
    login(token=HF_TOKEN)
# If you prefer to hard-code the token, uncomment:
# login(token="hf_your_token_here")

# --------------------------------------------------
# Model & pipeline setup
# --------------------------------------------------
BASE_MODEL = "black-forest-labs/FLUX.1-dev"
CONTROLNET_MODEL = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

controlnet_single = FluxControlNetModel.from_pretrained(
    CONTROLNET_MODEL, torch_dtype=dtype
)
# The pipeline is driven with list-valued control inputs (one entry per
# ControlNet), so the single model is wrapped in a multi-ControlNet.
controlnet = FluxMultiControlNetModel([controlnet_single])

pipe = FluxControlNetPipeline.from_pretrained(
    BASE_MODEL, controlnet=controlnet, torch_dtype=dtype
).to(device)
pipe.set_progress_bar_config(disable=True)

# --------------------------------------------------
# UI -> model value mapping
# --------------------------------------------------
# Indices follow the Union-Pro model card:
#   canny 0, tile 1, depth 2, blur 3, pose 4, gray 5, low quality 6.
# Key order is kept as-is because it defines the order of the Mode radio.
MODE_MAPPING = {
    "canny": 0,
    "depth": 2,
    "openpose": 4,
    "gray": 5,
    "blur": 3,
    "tile": 1,
    "low quality": 6,
}

MAX_SEED = 100

# diffusers' Flux pipelines need width/height divisible by
# vae_scale_factor * 2 (= 16 for FLUX.1).
FLUX_DIM_MULTIPLE = 16


# --------------------------------------------------
# Helpers
# --------------------------------------------------
def _snap_size(size: int, multiple: int = FLUX_DIM_MULTIPLE) -> int:
    """Round *size* down to a multiple of *multiple* (never below *multiple*)."""
    return max(multiple, (int(size) // multiple) * multiple)


def _preview_canny(pil_img: Image.Image) -> Image.Image:
    """Return a Canny edge map of *pil_img* as an RGB image (UI preview only)."""
    arr = np.array(pil_img.convert("RGB"))
    edges = cv2.Canny(arr, 100, 200)
    edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(edges_rgb)


def _make_preview(control_image: Image.Image, mode: str) -> Image.Image:
    """Build the "Pre-processed Cond" preview for *mode*.

    Only ``canny`` has a visualiser; every other mode echoes the input.
    """
    if mode == "canny":
        return _preview_canny(control_image)
    # For other modes you can plug in your own visualiser later
    return control_image


# --------------------------------------------------
# Inference function
# --------------------------------------------------
@spaces.GPU
def infer(
    control_image: Image.Image,
    prompt: str,
    mode: str,
    control_strength: float,
    seed: int,
    randomize_seed: bool,
    guidance_scale: float,
    num_inference_steps: int,
):
    """Run FLUX.1-dev + Union-Pro ControlNet on one control image.

    Args:
        control_image: Uploaded (already processed) control image.
        prompt: Text prompt.
        mode: A key of ``MODE_MAPPING``.
        control_strength: ``controlnet_conditioning_scale`` for the ControlNet.
        seed: RNG seed, used unless *randomize_seed* is set.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]``.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps.

    Returns:
        ``(result_image, seed_used, preview_image)``. The generated image is
        the control image's size rounded down to a multiple of 16.

    Raises:
        gr.Error: If no image was uploaded or *mode* is unknown.
    """
    if control_image is None:
        raise gr.Error("Please upload a control image first.")
    if mode not in MODE_MAPPING:
        raise gr.Error(f"Unknown ControlNet mode: {mode!r}")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders may deliver floats; Generator.manual_seed needs an int.
    seed = int(seed)
    gen = torch.Generator(device).manual_seed(seed)

    # Drop alpha / palette so the VAE always receives 3-channel input.
    control_image = control_image.convert("RGB")

    # Snap to FLUX-compatible dimensions so the output, control image and
    # preview all share one size.
    w = _snap_size(control_image.width)
    h = _snap_size(control_image.height)
    if (w, h) != control_image.size:
        control_image = control_image.resize((w, h), Image.LANCZOS)

    result = pipe(
        prompt=prompt,
        control_image=[control_image],
        control_mode=[MODE_MAPPING[mode]],
        width=w,
        height=h,
        controlnet_conditioning_scale=[control_strength],
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        generator=gen,
    ).images[0]

    # NOTE: the preview is for display only; the pipeline above received the
    # uploaded image unchanged (apart from RGB conversion / resizing).
    preview = _make_preview(control_image, mode)
    return result, seed, preview


# --------------------------------------------------
# Gradio UI
# --------------------------------------------------
css = """#wrapper {max-width: 960px; margin: 0 auto;}"""

with gr.Blocks(css=css, elem_id="wrapper") as demo:
    gr.Markdown("## FLUX.1‑dev‑ControlNet‑Union‑Pro")
    gr.Markdown(
        "A unified ControlNet for **FLUX.1‑dev** from the InstantX team and Shakker Labs. "
        "Recommended strengths: *canny 0.65*, *tile 0.45*, *depth 0.55*, *blur 0.45*, "
        "*openpose 0.55*, *gray 0.45*, *low quality 0.40*. Long prompts usually help."
    )

    # ------------ Image panel row ------------
    with gr.Row():
        control_image = gr.Image(
            label="Upload a processed control image",
            type="pil",
            height=512,
        )
        result_image = gr.Image(label="Result", height=512)
        preview_image = gr.Image(label="Pre‑processed Cond", height=512)

    # ------------ Prompt ------------
    prompt_txt = gr.Textbox(label="Prompt", value="best quality", lines=1)

    # ------------ ControlNet settings ------------
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ControlNet")
            mode_radio = gr.Radio(
                choices=list(MODE_MAPPING.keys()), value="gray", label="Mode"
            )
            strength_slider = gr.Slider(
                0.0, 1.0, value=0.5, step=0.01, label="control strength"
            )
        with gr.Column():
            seed_slider = gr.Slider(0, MAX_SEED, step=1, value=42, label="Seed")
            randomize_chk = gr.Checkbox(label="Randomize seed", value=True)
            guidance_slider = gr.Slider(
                0.0, 10.0, step=0.1, value=3.5, label="Guidance scale"
            )
            steps_slider = gr.Slider(
                1, 50, step=1, value=24, label="Inference steps"
            )

    submit_btn = gr.Button("Submit")

    submit_btn.click(
        fn=infer,
        inputs=[
            control_image,
            prompt_txt,
            mode_radio,
            strength_slider,
            seed_slider,
            randomize_chk,
            guidance_slider,
            steps_slider,
        ],
        # infer returns the seed actually used, so the slider reflects it.
        outputs=[result_image, seed_slider, preview_image],
    )

if __name__ == "__main__":
    demo.launch()