ControlNet / app.py
FrankFacundo's picture
WIP
d76eab0
raw
history blame
13.1 kB
import numpy as np
import gradio as gr
import spaces
import os
import random
import subprocess
import torch
from PIL import Image
import cv2
from huggingface_hub import login
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
import warnings
from typing import Tuple
"""
FLUX‑1 ControlNet demo
----------------------
This script builds a Gradio interface with a **single** control‑image upload
slot and integrates the FLUX.1‑dev‑ControlNet‑Union‑Pro model.
Key points
~~~~~~~~~~
* Single *control image* input (left).
* *Result* and *Pre‑processed Cond* previews side‑by‑side (center & right).
* *Prompt* textbox plus a dedicated **ControlNet** panel for choosing the mode and strength.
* Seed handling with optional randomisation.
* Advanced sliders for *Guidance scale* and *Inference steps*.
* Works on CUDA (bfloat16) or CPU (float32).
* A lightweight preview pre‑processor for every mode (Canny edges, tiling, a rough
  Laplacian depth proxy, blur, an optional MediaPipe pose overlay, grayscale, low quality).
Before running, set the `HF_TOKEN_NEW` environment variable **or** call
`login("<YOUR_HF_TOKEN>")` explicitly.
"""
# Remove leftover files from /data-nvme/zerogpu-offload before loading models
# (presumably stale ZeroGPU offload data from earlier runs — NOTE(review): confirm).
# shell=True is needed so the shell expands the ``*`` glob; the command string is
# a constant, never user input.
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
# --------------------------------------------------
# Model & pipeline setup
# --------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN_NEW")
# Only authenticate when a token is configured: calling ``login()`` without a
# token falls back to an interactive prompt, which nobody can answer in a
# headless deployment.
if HF_TOKEN:
    login(HF_TOKEN)
else:
    warnings.warn(
        "HF_TOKEN_NEW is not set; skipping Hugging Face login. Loading gated "
        "models such as FLUX.1-dev may fail.",
        RuntimeWarning,
    )
# If you prefer to hard‑code the token, uncomment:
# login("hf_your_token_here")
BASE_MODEL = "black-forest-labs/FLUX.1-dev"
CONTROLNET_MODEL = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"
device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 on CUDA, float32 on CPU (same choice as before, derived from `device`
# so the two can never disagree).
dtype = torch.bfloat16 if device == "cuda" else torch.float32
# A single Union-Pro ControlNet wrapped in a multi-ControlNet container; callers
# therefore pass control images / modes / scales as one-element lists.
controlnet_single = FluxControlNetModel.from_pretrained(
    CONTROLNET_MODEL, torch_dtype=dtype
)
controlnet = FluxMultiControlNetModel([controlnet_single])
pipe = FluxControlNetPipeline.from_pretrained(
    BASE_MODEL, controlnet=controlnet, torch_dtype=dtype
).to(device)
pipe.set_progress_bar_config(disable=True)
# --------------------------------------------------
# UI ‑> model value mapping
# --------------------------------------------------
# UI mode name -> integer id passed to the pipeline as ``control_mode``.
# The id is simply each name's position in this ordered tuple.
MODE_MAPPING = {
    mode_name: mode_id
    for mode_id, mode_name in enumerate(
        ("canny", "tile", "depth", "blur", "pose", "gray", "low quality")
    )
}
# Inclusive upper bound shared by the seed slider and seed randomisation.
MAX_SEED = 100
# -----------------------------------------------------------------------------
# Preview helpers – one small, self‑contained function per mode
# -----------------------------------------------------------------------------
def _preview_canny(
    pil_img: Image.Image, canny_threshold_1: int, canny_threshold_2: int
) -> Image.Image:
    """Return the Canny edge map of *pil_img* as a 3‑channel RGB image.

    The two thresholds are forwarded unchanged to ``cv2.Canny``.
    """
    edge_map = cv2.Canny(
        np.array(pil_img.convert("RGB")),
        threshold1=canny_threshold_1,
        threshold2=canny_threshold_2,
    )
    # Canny yields a single channel; expand it so every preview is RGB.
    return Image.fromarray(cv2.cvtColor(edge_map, cv2.COLOR_GRAY2RGB))
# ――― tile ―――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
def _preview_tile(pil_img: Image.Image, grid: Tuple[int, int] = (2, 2)) -> Image.Image:
    """Repeat *pil_img* across a ``cols × rows`` grid (default 2 × 2).

    *grid* is ``(cols, rows)``. The output is an RGB canvas of size
    ``(width * cols, height * rows)``. It gives a quick visual hint of what a
    *tiling* control mode will do (repeatable textures, etc.).
    """
    n_cols, n_rows = grid
    source = pil_img.convert("RGB")
    tile_w, tile_h = source.size
    canvas = Image.new("RGB", (tile_w * n_cols, tile_h * n_rows))
    # Tiles never overlap, so the paste order does not affect the result.
    offsets = [
        (col * tile_w, row * tile_h) for row in range(n_rows) for col in range(n_cols)
    ]
    for offset in offsets:
        canvas.paste(source, offset)
    return canvas
# ――― depth ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
def _preview_depth(pil_img: Image.Image) -> Image.Image:
    """Return a very rough, colour‑mapped *depth* proxy for *pil_img*.

    This is not a real depth estimate. It highlights local intensity changes
    with a Laplacian and paints them with a colormap so the preview resembles
    a depth heat‑map.

    ▸ Convert to grayscale
    ▸ Run a 3×3 Laplacian to highlight depth‑like gradients
    ▸ Apply the TURBO colormap, then convert OpenCV's BGR output to RGB
    """
    # Normalise the mode first (like the other helpers) so grayscale, RGBA or
    # palette images do not break the RGB -> GRAY conversion.
    rgb = np.array(pil_img.convert("RGB"))
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    # CV_16S keeps negative Laplacian responses; convertScaleAbs folds them
    # back into an 8‑bit magnitude image.
    lap = cv2.Laplacian(gray, cv2.CV_16S, ksize=3)
    depth = cv2.convertScaleAbs(lap)
    # applyColorMap returns BGR; PIL interprets arrays as RGB, so swap channels
    # (previously red and blue were inverted in the preview/control image).
    depth_bgr = cv2.applyColorMap(depth, cv2.COLORMAP_TURBO)
    return Image.fromarray(cv2.cvtColor(depth_bgr, cv2.COLOR_BGR2RGB))
# ――― blur ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
def _preview_blur(pil_img: Image.Image, ksize: int = 15) -> Image.Image:
    """Return a Gaussian‑blurred RGB copy of *pil_img*.

    A single, relatively large kernel is enough for UI illustration.

    Parameters
    ----------
    pil_img : PIL.Image
        Source image; converted to RGB first, like the other preview helpers.
    ksize : int, optional
        Side length of the square kernel. An even value is bumped to the next
        odd number because OpenCV requires odd Gaussian kernel sizes.
    """
    if ksize % 2 == 0:
        ksize += 1  # kernel must be odd
    # Without this conversion a palette ("P") image would have its palette
    # *indices* blurred, and RGBA/L inputs would yield non‑RGB previews.
    rgb = np.array(pil_img.convert("RGB"))
    # sigmaX=0 lets OpenCV derive sigma from the kernel size.
    blurred = cv2.GaussianBlur(rgb, (ksize, ksize), sigmaX=0)
    return Image.fromarray(blurred)
# ――― pose ―――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
def _preview_pose(pil_img: Image.Image) -> Image.Image:
    """Overlay a 2‑D MediaPipe pose skeleton on *pil_img* when possible.

    The landmarks are drawn on top of the input image, and an RGB image is
    returned. If *mediapipe* is not installed, or any step of pose estimation
    or drawing raises, a warning is emitted. In that case a Canny edge map
    (thresholds 100/200) is returned instead, so the UI never crashes.
    """
    try:
        import mediapipe as mp  # type: ignore

        mp_pose = mp.solutions.pose
        mp_drawing = mp.solutions.drawing_utils
        # One contiguous RGB array feeds MediaPipe directly (it expects RGB).
        # This replaces the former RGB -> BGR -> reversed-view round trip.
        img_rgb = np.array(pil_img.convert("RGB"))
        with mp_pose.Pose(static_image_mode=True) as pose_estimator:
            results = pose_estimator.process(img_rgb)
        # Draw on a BGR canvas, as before, so draw_landmarks' colours keep
        # following OpenCV's BGR convention; convert back to RGB for PIL.
        annotated = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        if results.pose_landmarks:
            mp_drawing.draw_landmarks(
                annotated, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
            )
        return Image.fromarray(cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB))
    except Exception as exc:  # pragma: no cover – deliberate best-effort fallback
        warnings.warn(
            f"Pose preview failed ({exc!s}); falling back to Canny.", RuntimeWarning
        )
        # Return an edge map as a sensible fallback rather than exploding the UI
        return _preview_canny(pil_img, 100, 200)
# ――― gray ――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
def _preview_gray(pil_img: Image.Image) -> Image.Image:
    """Return a grayscale version of *pil_img*, still stored as 3‑channel RGB.

    The single luminance channel is replicated into R, G and B so every preview
    widget receives the same image mode.
    """
    rgb = np.array(pil_img.convert("RGB"))
    luminance = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    return Image.fromarray(cv2.cvtColor(luminance, cv2.COLOR_GRAY2RGB))
# ――― low quality ――――――――――――――――――――――――――――――――――――――――――――――――――――――――― #
def _preview_low_quality(pil_img: Image.Image, factor: int = 8) -> Image.Image:
    """Mimic a low‑quality thumbnail: aggressively downsample, then upscale.

    The default *factor* (8×) is chosen to make artefacts obvious.

    Parameters
    ----------
    pil_img : PIL.Image
        Source image; converted to RGB first.
    factor : int, optional
        Integer downsampling factor (>= 1). Each side is divided by it with
        floor division, but never shrinks below one pixel.

    Returns
    -------
    PIL.Image
        RGB image with the same size as *pil_img*.

    Raises
    ------
    ValueError
        If *factor* is smaller than 1. Zero would divide by zero, and negative
        values would silently collapse the image to a single pixel.
    """
    if factor < 1:
        raise ValueError(f"factor must be >= 1, got {factor!r}")
    img_rgb = pil_img.convert("RGB")
    w, h = img_rgb.size
    small = img_rgb.resize((max(1, w // factor), max(1, h // factor)), Image.BILINEAR)
    # Upsample with nearest-neighbour to exaggerate the blocky artefacts.
    return small.resize((w, h), Image.NEAREST)
# -----------------------------------------------------------------------------
# Master dispatch
# -----------------------------------------------------------------------------
def _make_preview(
    control_image: Image.Image,
    mode: str,
    canny_threshold_1: int = 100,
    canny_threshold_2: int = 200,
) -> Image.Image:
    """Dispatch *control_image* to the preview helper registered for *mode*.

    Parameters
    ----------
    control_image : PIL.Image
        The input image selected by the user.
    mode : str
        Case‑insensitive key of :data:`MODE_MAPPING`.
    canny_threshold_1 / 2 : int, optional
        Forwarded to the Canny helper; ignored by every other mode.

    Returns
    -------
    PIL.Image
        The pre‑processed image. An unknown *mode* triggers a warning and the
        input is returned unchanged.
    """
    mode = mode.lower()
    # Only the canny handler needs extra arguments; the rest use their defaults.
    handlers = {
        "canny": lambda img: _preview_canny(img, canny_threshold_1, canny_threshold_2),
        "tile": _preview_tile,
        "depth": _preview_depth,
        "blur": _preview_blur,
        "pose": _preview_pose,
        "gray": _preview_gray,
        "low quality": _preview_low_quality,
    }
    if mode not in MODE_MAPPING:
        warnings.warn(f"Unknown preview mode '{mode}'. Returning untouched image.")
        return control_image
    handler = handlers.get(mode)
    # A mode listed in MODE_MAPPING but without a handler passes through as-is.
    return control_image if handler is None else handler(control_image)
# --------------------------------------------------
# Inference function
# --------------------------------------------------
@spaces.GPU
def infer(
    control_image: Image.Image,
    prompt: str,
    mode: str,
    control_strength: float,
    seed: int,
    randomize_seed: bool,
    guidance_scale: float,
    num_inference_steps: int,
    canny_threshold_1: int,
    canny_threshold_2: int,
):
    """Generate an image with the FLUX ControlNet pipeline (Submit handler).

    The control image is first pre‑processed for *mode* via ``_make_preview``.
    Generation then runs at the control image's size with the chosen strength,
    guidance, step count and seed.

    Returns
    -------
    tuple
        ``(result_image, seed_used, preprocessed_control_image)``, in the order
        of the UI outputs ``[result_image, seed_slider, preview_image]``.

    Raises
    ------
    gr.Error
        If no control image was uploaded, or *mode* is not a key of
        :data:`MODE_MAPPING` (compared case-insensitively).
    """
    if control_image is None:
        raise gr.Error("Please upload a control image first.")
    # Normalise exactly as _make_preview does so the preview and the
    # control_mode lookup can never disagree (previously a differently-cased
    # mode produced a preview and then a raw KeyError).
    mode_key = (mode or "").lower()
    if mode_key not in MODE_MAPPING:
        raise gr.Error(f"Unknown ControlNet mode: {mode!r}.")
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may hand over numbers as floats; torch's manual_seed and the
    # scheduler's step count need ints.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    gen = torch.Generator(device).manual_seed(seed)
    w, h = control_image.size
    preprocessed = _make_preview(
        control_image, mode_key, canny_threshold_1, canny_threshold_2
    )
    # One-element lists because the pipeline wraps a FluxMultiControlNetModel.
    result = pipe(
        prompt=prompt,
        control_image=[preprocessed],
        control_mode=[MODE_MAPPING[mode_key]],
        width=w,
        height=h,
        controlnet_conditioning_scale=[control_strength],
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=gen,
    ).images[0]
    return result, seed, preprocessed
# --------------------------------------------------
# Gradio UI
# --------------------------------------------------
css = """#wrapper {max-width: 960px; margin: 0 auto;}"""
with gr.Blocks(css=css, elem_id="wrapper") as demo:
    gr.Markdown("## FLUX.1‑dev‑ControlNet‑Union‑Pro by Frank")
    gr.Markdown(
        "A unified ControlNet for **FLUX.1‑dev** from the InstantX team and Shakker Labs. "
        + "Recommended strengths: *canny 0.76*. Long prompts usually help."
    )
    # ------------ Image panel row ------------
    # Input image, generated result and pre-processed control image side by side.
    with gr.Row():
        control_image = gr.Image(
            label="Upload an image",
            type="pil",
            height=512 + 256,
        )
        result_image = gr.Image(label="Result", height=512 + 256)
        preview_image = gr.Image(label="Pre‑processed Cond", height=512 + 256)
    # ------------ Prompt ------------
    prompt_txt = gr.Textbox(label="Prompt", value="White background", lines=1)
    # ------------ ControlNet settings ------------
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ControlNet")
            mode_radio = gr.Radio(
                choices=list(MODE_MAPPING.keys()), value="canny", label="Mode"
            )
            strength_slider = gr.Slider(
                0.0, 1.0, value=0.76, step=0.01, label="control strength"
            )
            gr.Markdown("### Preprocess")
            # Only used by the "canny" mode's pre-processor.
            canny_threshold_1 = gr.Slider(
                0, 500, step=1, value=100, label="Canny threshold 1"
            )
            canny_threshold_2 = gr.Slider(
                0, 500, step=1, value=200, label="Canny threshold 2"
            )
        with gr.Column():
            seed_slider = gr.Slider(0, MAX_SEED, step=1, value=42, label="Seed")
            randomize_chk = gr.Checkbox(label="Randomize seed", value=False)
            guidance_slider = gr.Slider(
                0.0, 10.0, step=0.1, value=3.5, label="Guidance scale"
            )
            steps_slider = gr.Slider(1, 50, step=1, value=50, label="Inference steps")
    submit_btn = gr.Button("Submit")
    # Input order must match infer()'s positional parameters. The seed slider is
    # also an output so a randomised seed is shown back to the user.
    submit_btn.click(
        fn=infer,
        inputs=[
            control_image,
            prompt_txt,
            mode_radio,
            strength_slider,
            seed_slider,
            randomize_chk,
            guidance_slider,
            steps_slider,
            canny_threshold_1,
            canny_threshold_2,
        ],
        outputs=[result_image, seed_slider, preview_image],
    )
if __name__ == "__main__":
    demo.launch()