"""FLUX.1 ControlNet demo.

This script builds a Gradio interface with **one** control-image upload slot
and integrates the FLUX.1-dev-ControlNet-Union-Pro model.

Key points
~~~~~~~~~~
* Single *control image* input (left).
* *Result* and *Pre-processed Cond* previews side by side (center & right).
* *Prompt* textbox plus a dedicated **ControlNet** panel for choosing the
  mode and strength.
* Seed handling with optional randomisation.
* Sliders for *Guidance scale* and *Inference steps*.
* Runs on CUDA (bfloat16) or CPU (float32).
* One lightweight pre-processing helper per control mode; the pre-processed
  image is both shown in the UI and passed to the pipeline as the control
  image.

Before running, set the ``HF_TOKEN_NEW`` environment variable **or** call
``login("<token>")`` explicitly.
"""

import os
import random
import subprocess
import warnings
from typing import Tuple

# NOTE: third-party imports keep their original relative order
# (in particular `spaces` is still imported before `torch`).
import numpy as np
import gradio as gr
import spaces
import torch
from PIL import Image
import cv2
from huggingface_hub import login
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel

# Wipe the offload directory at start-up. The shell is required so that the
# `*` glob is expanded; the command is a fixed string (no user input).
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)

# --------------------------------------------------
# Model & pipeline setup
# --------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN_NEW")
if HF_TOKEN:
    login(HF_TOKEN)
else:
    # login(None) would fall back to an interactive prompt, which cannot be
    # answered in a headless deployment, so skip it and rely on any
    # credentials that are already cached.
    warnings.warn(
        "HF_TOKEN_NEW is not set; skipping Hugging Face login.", RuntimeWarning
    )
# If you prefer to hard-code the token, uncomment:
# login("hf_your_token_here")

BASE_MODEL = "black-forest-labs/FLUX.1-dev"
CONTROLNET_MODEL = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

controlnet_single = FluxControlNetModel.from_pretrained(
    CONTROLNET_MODEL, torch_dtype=dtype
)
controlnet = FluxMultiControlNetModel([controlnet_single])
pipe = FluxControlNetPipeline.from_pretrained(
    BASE_MODEL, controlnet=controlnet, torch_dtype=dtype
).to(device)
pipe.set_progress_bar_config(disable=True)

# --------------------------------------------------
# UI -> model value mapping
# --------------------------------------------------
# Maps the mode name shown in the UI to the integer passed as `control_mode`.
MODE_MAPPING = {
    "canny": 0,
    "tile": 1,
    "depth": 2,
    "blur": 3,
    "pose": 4,
    "gray": 5,
    "low quality": 6,
}

# Upper bound (inclusive) of the seed slider and of randomised seeds.
MAX_SEED = 100


# -----------------------------------------------------------------------------
# Preview helpers - one small, self-contained function per mode
# -----------------------------------------------------------------------------
def _preview_canny(
    pil_img: Image.Image, canny_threshold_1: int, canny_threshold_2: int
) -> Image.Image:
    """Return a Canny edge map of *pil_img* as a 3-channel RGB image.

    The two thresholds are passed straight to ``cv2.Canny``.
    """
    rgb = np.array(pil_img.convert("RGB"))
    edges = cv2.Canny(rgb, threshold1=canny_threshold_1, threshold2=canny_threshold_2)
    return Image.fromarray(cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB))


def _preview_tile(pil_img: Image.Image, grid: Tuple[int, int] = (2, 2)) -> Image.Image:
    """Replicate *pil_img* into a ``cols x rows`` tiled grid (default 2x2).

    The returned image is ``cols`` times wider and ``rows`` times taller than
    the input.
    """
    cols, rows = grid
    img_rgb = pil_img.convert("RGB")
    w, h = img_rgb.size
    tiled = Image.new("RGB", (w * cols, h * rows))
    for c in range(cols):
        for r in range(rows):
            tiled.paste(img_rgb, (c * w, r * h))
    return tiled


def _preview_depth(pil_img: Image.Image) -> Image.Image:
    """Return a very rough *depth* proxy: Laplacian magnitude + TURBO colormap.

    Steps: convert to gray, run a 3x3 Laplacian, take the absolute value as
    8-bit, then colour it with ``cv2.COLORMAP_TURBO``.
    """
    # Normalise the mode first so grayscale / palette inputs do not break the
    # RGB->GRAY conversion.
    gray = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2GRAY)
    lap = cv2.Laplacian(gray, cv2.CV_16S, ksize=3)
    depth = cv2.convertScaleAbs(lap)
    depth_bgr = cv2.applyColorMap(depth, cv2.COLORMAP_TURBO)
    # applyColorMap yields BGR; PIL expects RGB, so swap the channels back.
    return Image.fromarray(cv2.cvtColor(depth_bgr, cv2.COLOR_BGR2RGB))


def _preview_blur(pil_img: Image.Image, ksize: int = 15) -> Image.Image:
    """Return a Gaussian-blurred RGB copy of *pil_img*.

    *ksize* is the square kernel size; even values are bumped to the next odd
    number because OpenCV requires an odd kernel.
    """
    if ksize % 2 == 0:
        ksize += 1  # kernel must be odd
    # Convert to RGB like the other helpers so palette images are blurred on
    # colours rather than on palette indices.
    rgb = np.array(pil_img.convert("RGB"))
    blurred = cv2.GaussianBlur(rgb, (ksize, ksize), sigmaX=0)
    return Image.fromarray(blurred)


def _preview_pose(pil_img: Image.Image) -> Image.Image:
    """Return *pil_img* with a 2-D pose overlay drawn by *mediapipe*.

    If *mediapipe* cannot be imported or anything fails while running it, a
    ``RuntimeWarning`` is emitted and a Canny edge map (thresholds 100/200)
    is returned instead so the UI never crashes.
    """
    try:
        import mediapipe as mp  # type: ignore

        mp_pose = mp.solutions.pose
        mp_drawing = mp.solutions.drawing_utils

        img_rgb = np.array(pil_img.convert("RGB"))
        with mp_pose.Pose(static_image_mode=True) as pose_estimator:
            # Feed the RGB array directly instead of a reversed view of a
            # BGR copy.
            results = pose_estimator.process(img_rgb)

        # The overlay is drawn on a BGR copy and converted back afterwards,
        # exactly as before, so the drawn colours are unchanged.
        annotated = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        if results.pose_landmarks:
            mp_drawing.draw_landmarks(
                annotated, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
            )
        return Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
    except Exception as exc:  # pragma: no cover - any import / runtime error
        # Deliberate best-effort: report the problem and fall back to edges.
        warnings.warn(
            f"Pose preview failed ({exc!s}); falling back to Canny.", RuntimeWarning
        )
        return _preview_canny(pil_img, 100, 200)


def _preview_gray(pil_img: Image.Image) -> Image.Image:
    """Return a grayscale version of *pil_img*, kept as a 3-channel RGB image."""
    gray = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2GRAY)
    return Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))


def _preview_low_quality(pil_img: Image.Image, factor: int = 8) -> Image.Image:
    """Mimic a low-quality thumbnail: downsample by *factor*, then upscale.

    The image is shrunk with bilinear filtering and blown back up to its
    original size with nearest-neighbour to exaggerate the blocks. *factor*
    values below 1 are treated as 1.
    """
    factor = max(1, int(factor))  # avoid division by zero / negative sizes
    img_rgb = pil_img.convert("RGB")
    w, h = img_rgb.size
    small = img_rgb.resize((max(1, w // factor), max(1, h // factor)), Image.BILINEAR)
    return small.resize((w, h), Image.NEAREST)


# Previewers that only need the image; "canny" is handled separately because
# it also takes the two thresholds.
_SIMPLE_PREVIEWS = {
    "tile": _preview_tile,
    "depth": _preview_depth,
    "blur": _preview_blur,
    "pose": _preview_pose,
    "gray": _preview_gray,
    "low quality": _preview_low_quality,
}


# -----------------------------------------------------------------------------
# Master dispatch
# -----------------------------------------------------------------------------
def _make_preview(
    control_image: Image.Image,
    mode: str,
    canny_threshold_1: int = 100,
    canny_threshold_2: int = 200,
) -> Image.Image:
    """Return a quick pre-processed image for the requested *mode*.

    Parameters
    ----------
    control_image : PIL.Image
        The input image selected by the user.
    mode : str
        One of the keys of :data:`MODE_MAPPING` (matched case-insensitively).
    canny_threshold_1, canny_threshold_2 : int, optional
        Only used if *mode* is "canny" (passed straight to OpenCV Canny).

    Returns
    -------
    PIL.Image
        The pre-processed image, or *control_image* untouched (with a
        warning) when *mode* is unknown.
    """
    mode = mode.lower()
    if mode not in MODE_MAPPING:
        warnings.warn(f"Unknown preview mode '{mode}'. Returning untouched image.")
        return control_image

    if mode == "canny":
        return _preview_canny(control_image, canny_threshold_1, canny_threshold_2)

    previewer = _SIMPLE_PREVIEWS.get(mode)
    if previewer is None:
        # A mode listed in MODE_MAPPING without a dedicated helper.
        return control_image
    return previewer(control_image)


# --------------------------------------------------
# Inference function
# --------------------------------------------------
@spaces.GPU
def infer(
    control_image: Image.Image,
    prompt: str,
    mode: str,
    control_strength: float,
    seed: int,
    randomize_seed: bool,
    guidance_scale: float,
    num_inference_steps: int,
    canny_threshold_1: int,
    canny_threshold_2: int,
):
    """Run the ControlNet pipeline for one UI submission.

    The control image is pre-processed with :func:`_make_preview` and the
    pipeline is run at the control image's own width and height.

    Returns
    -------
    tuple
        ``(result_image, seed_used, preprocessed_image)`` - matching the
        three Gradio outputs wired to the submit button.

    Raises
    ------
    gr.Error
        If no control image was uploaded or *mode* is not a known mode.
    """
    if control_image is None:
        raise gr.Error("Please upload a control image first.")

    # Normalise once so the preview and the MODE_MAPPING lookup agree.
    mode = mode.lower()
    if mode not in MODE_MAPPING:
        raise gr.Error(f"Unknown ControlNet mode '{mode}'.")

    # Slider values may arrive as floats; manual_seed needs an int.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    gen = torch.Generator(device).manual_seed(seed)

    w, h = control_image.size
    preprocessed = _make_preview(
        control_image, mode, canny_threshold_1, canny_threshold_2
    )

    result = pipe(
        prompt=prompt,
        control_image=[preprocessed],
        control_mode=[MODE_MAPPING[mode]],
        width=w,
        height=h,
        controlnet_conditioning_scale=[control_strength],
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        generator=gen,
    ).images[0]

    return result, seed, preprocessed


# --------------------------------------------------
# Gradio UI
# --------------------------------------------------
css = """#wrapper {max-width: 960px; margin: 0 auto;}"""

with gr.Blocks(css=css, elem_id="wrapper") as demo:
    gr.Markdown("## FLUX.1‑dev‑ControlNet‑Union‑Pro by Frank")
    gr.Markdown(
        "A unified ControlNet for **FLUX.1‑dev** from the InstantX team and Shakker Labs. "
        + "Recommended strengths: *canny 0.76*. Long prompts usually help."
    )

    # ------------ Image panel row ------------
    with gr.Row():
        control_image = gr.Image(
            label="Upload animage",
            type="pil",
            height=512 + 256,
        )
        result_image = gr.Image(label="Result", height=512 + 256)
        preview_image = gr.Image(label="Pre‑processed Cond", height=512 + 256)

    # ------------ Prompt ------------
    prompt_txt = gr.Textbox(label="Prompt", value="White background", lines=1)

    # ------------ ControlNet settings ------------
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ControlNet")
            mode_radio = gr.Radio(
                choices=list(MODE_MAPPING.keys()), value="canny", label="Mode"
            )
            strength_slider = gr.Slider(
                0.0, 1.0, value=0.76, step=0.01, label="control strength"
            )
            gr.Markdown("### Preprocess")
            canny_threshold_1 = gr.Slider(
                0, 500, step=1, value=100, label="Canny threshold 1"
            )
            canny_threshold_2 = gr.Slider(
                0, 500, step=1, value=200, label="Canny threshold 2"
            )
        with gr.Column():
            seed_slider = gr.Slider(0, MAX_SEED, step=1, value=42, label="Seed")
            randomize_chk = gr.Checkbox(label="Randomize seed", value=False)
            guidance_slider = gr.Slider(
                0.0, 10.0, step=0.1, value=3.5, label="Guidance scale"
            )
            steps_slider = gr.Slider(1, 50, step=1, value=50, label="Inference steps")

    submit_btn = gr.Button("Submit")
    # Input order must match the positional parameters of `infer`.
    submit_btn.click(
        fn=infer,
        inputs=[
            control_image,
            prompt_txt,
            mode_radio,
            strength_slider,
            seed_slider,
            randomize_chk,
            guidance_slider,
            steps_slider,
            canny_threshold_1,
            canny_threshold_2,
        ],
        outputs=[result_image, seed_slider, preview_image],
    )

if __name__ == "__main__":
    demo.launch()