import numpy as np
import gradio as gr
import spaces
import os
import random
import subprocess
import torch
from PIL import Image
import cv2
from huggingface_hub import login
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
import warnings
from typing import Tuple
""" | |
FLUXβ1 ControlNet demo | |
---------------------- | |
This script rebuilds the Gradio interface shown in your screenshot with **one** controlβimage upload | |
slot and integrates the FLUX.1βdevβControlNetβUnionβPro model. | |
Key points | |
~~~~~~~~~~ | |
* Single *control image* input (left). | |
* *Result* and *Preβprocessed Cond* previews sideβbyβside (center & right). | |
* *Prompt* textbox plus a dedicated **ControlNet** panel for choosing the mode and strength. | |
* Seed handling with optional randomisation. | |
* Advanced sliders for *Guidance scale* and *Inference steps*. | |
* Works on CUDA (bfloat16) or CPU (float32). | |
* Minimal Canny preview implementation when the *canny* mode is selected (extend as you like for the | |
other modes). | |
Before running, set the `HUGGINGFACE_TOKEN` environment variable **or** call | |
`login("<YOUR_HF_TOKEN>")` explicitly. | |
""" | |
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
# --------------------------------------------------
# Model & pipeline setup
# --------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN_NEW")
if HF_TOKEN:  # only authenticate when a token is actually provided
    login(HF_TOKEN)
# If you prefer to hard-code the token, uncomment:
# login("hf_your_token_here")
BASE_MODEL = "black-forest-labs/FLUX.1-dev"
CONTROLNET_MODEL = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
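# Load the Union ControlNet and wrap it in FluxMultiControlNetModel so the
# pipeline accepts parallel lists for control_image / control_mode / scales.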
controlnet_single = FluxControlNetModel.from_pretrained(
    CONTROLNET_MODEL, torch_dtype=dtype
)
controlnet = FluxMultiControlNetModel([controlnet_single])
pipe = FluxControlNetPipeline.from_pretrained(
    BASE_MODEL, controlnet=controlnet, torch_dtype=dtype
).to(device)
pipe.set_progress_bar_config(disable=True)
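# If VRAM is tight, one option (standard diffusers API, assuming `accelerate` is
# installed) is to drop the `.to(device)` call above and use
# `pipe.enable_model_cpu_offload()` instead, trading speed for memory.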
# --------------------------------------------------
# UI -> model value mapping
# --------------------------------------------------
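# Integer control_mode indices the Union-Pro ControlNet expects for each UI choice.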
MODE_MAPPING = {
    "canny": 0,
    "tile": 1,
    "depth": 2,
    "blur": 3,
    "pose": 4,
    "gray": 5,
    "low quality": 6,
}
MAX_SEED = 100
# -----------------------------------------------------------------------------
# Preview helpers - one small, self-contained function per mode
# -----------------------------------------------------------------------------
def _preview_canny(
    pil_img: Image.Image, canny_threshold_1: int, canny_threshold_2: int
) -> Image.Image:
    """Fast Canny-edge preview."""
    arr = np.array(pil_img.convert("RGB"))
    edges = cv2.Canny(arr, threshold1=canny_threshold_1, threshold2=canny_threshold_2)
    edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(edges_rgb)
# --- tile -------------------------------------------------------------------- #
def _preview_tile(pil_img: Image.Image, grid: Tuple[int, int] = (2, 2)) -> Image.Image:
    """Replicates *pil_img* into an *n x m* tiled grid (default 2x2).
    This offers a quick visual hint of what a *tiling* control mode will do
    (repeatable textures, etc.)."""
    cols, rows = grid
    img_rgb = pil_img.convert("RGB")
    w, h = img_rgb.size
    tiled = Image.new("RGB", (w * cols, h * rows))
    for c in range(cols):
        for r in range(rows):
            tiled.paste(img_rgb, (c * w, r * h))
    return tiled
# --- depth ------------------------------------------------------------------- #
def _preview_depth(pil_img: Image.Image) -> Image.Image:
    """Very rough *depth* proxy using the Laplacian and a colormap.
    * Convert to gray
    * Run the Laplacian to highlight depth-like gradients
    * Apply a TURBO colormap to mimic a depth heat-map appearance"""
    gray = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2GRAY)
    lap = cv2.Laplacian(gray, cv2.CV_16S, ksize=3)
    depth = cv2.convertScaleAbs(lap)
    depth_color = cv2.applyColorMap(depth, cv2.COLORMAP_TURBO)  # returns BGR
    return Image.fromarray(cv2.cvtColor(depth_color, cv2.COLOR_BGR2RGB))
# --- blur -------------------------------------------------------------------- #
def _preview_blur(pil_img: Image.Image, ksize: int = 15) -> Image.Image:
    """Gaussian blur preview.
    A single, relatively large kernel is enough for UI illustration."""
    if ksize % 2 == 0:
        ksize += 1  # kernel size must be odd
    blurred = cv2.GaussianBlur(np.array(pil_img.convert("RGB")), (ksize, ksize), sigmaX=0)
    return Image.fromarray(blurred)
# --- pose -------------------------------------------------------------------- #
def _preview_pose(pil_img: Image.Image) -> Image.Image:
    """Attempt a lightweight 2-D pose overlay using *mediapipe* if available.
    If *mediapipe* is not installed (or CPU inference fails), we gracefully
    fall back to an edge-map preview so the UI never crashes."""
    try:
        import mediapipe as mp  # type: ignore

        mp_pose = mp.solutions.pose
        mp_drawing = mp.solutions.drawing_utils
        img_rgb = np.array(pil_img.convert("RGB"))
        img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        with mp_pose.Pose(static_image_mode=True) as pose_estimator:
            results = pose_estimator.process(img_rgb)  # Mediapipe expects RGB
        annotated = img_bgr.copy()
        if results.pose_landmarks:
            mp_drawing.draw_landmarks(
                annotated, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
            )
        annotated_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
        return Image.fromarray(annotated_rgb)
    except Exception as exc:  # pragma: no cover - any import / runtime error
        warnings.warn(
            f"Pose preview failed ({exc!s}); falling back to Canny.", RuntimeWarning
        )
        # Return an edge map as a sensible fallback rather than exploding the UI
        return _preview_canny(pil_img, 100, 200)
# --- gray -------------------------------------------------------------------- #
def _preview_gray(pil_img: Image.Image) -> Image.Image:
    """Simple grayscale conversion, but keep a 3-channel RGB image so the UI
    widget pipeline stays consistent."""
    gray = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2GRAY)
    gray_rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(gray_rgb)
# --- low quality ------------------------------------------------------------- #
def _preview_low_quality(pil_img: Image.Image, factor: int = 8) -> Image.Image:
    """Mimic a low-quality thumbnail: aggressively downsample then upscale.
    The default *factor* (8x) is chosen to make artefacts obvious."""
    img_rgb = pil_img.convert("RGB")
    w, h = img_rgb.size
    small = img_rgb.resize((max(1, w // factor), max(1, h // factor)), Image.BILINEAR)
    low_q = small.resize(
        (w, h), Image.NEAREST
    )  # upsample with NEAREST to exaggerate blocks
    return low_q
# -----------------------------------------------------------------------------
# Master dispatch
# -----------------------------------------------------------------------------
def _make_preview(
    control_image: Image.Image,
    mode: str,
    canny_threshold_1: int = 100,
    canny_threshold_2: int = 200,
) -> Image.Image:
    """Return a *quick-and-dirty* preview image for the requested *mode*.

    Parameters
    ----------
    control_image : PIL.Image
        The input image selected by the user.
    mode : str
        One of the keys of :data:`MODE_MAPPING`.
    canny_threshold_1 / 2 : int, optional
        Only used if *mode* is "canny" (passed straight to OpenCV Canny).
    """
    mode = mode.lower()
    if mode not in MODE_MAPPING:
        warnings.warn(f"Unknown preview mode '{mode}'. Returning untouched image.")
        return control_image
    if mode == "canny":
        return _preview_canny(control_image, canny_threshold_1, canny_threshold_2)
    if mode == "tile":
        return _preview_tile(control_image)
    if mode == "depth":
        return _preview_depth(control_image)
    if mode == "blur":
        return _preview_blur(control_image)
    if mode == "pose":
        return _preview_pose(control_image)
    if mode == "gray":
        return _preview_gray(control_image)
    if mode == "low quality":
        return _preview_low_quality(control_image)
    # Fallback - should never happen due to the early mode check
    return control_image
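# Example usage outside the UI (illustrative; assumes a local "input.png" exists):
#   preview = _make_preview(Image.open("input.png"), "depth")
#   preview.save("depth_preview.png")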
# --------------------------------------------------
# Inference function
# --------------------------------------------------
@spaces.GPU  # allocate a GPU slice per call when running on a ZeroGPU Space
def infer(
    control_image: Image.Image,
    prompt: str,
    mode: str,
    control_strength: float,
    seed: int,
    randomize_seed: bool,
    guidance_scale: float,
    num_inference_steps: int,
    canny_threshold_1: int,
    canny_threshold_2: int,
):
    if control_image is None:
        raise gr.Error("Please upload a control image first.")
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    gen = torch.Generator(device).manual_seed(seed)
    w, h = control_image.size
    preprocessed = _make_preview(
        control_image, mode, canny_threshold_1, canny_threshold_2
    )
    # The multi-ControlNet wrapper expects parallel lists: one control image,
    # one mode index and one conditioning scale per ControlNet.
    result = pipe(
        prompt=prompt,
        control_image=[preprocessed],
        control_mode=[MODE_MAPPING[mode]],
        width=w,
        height=h,
        controlnet_conditioning_scale=[control_strength],
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=gen,
    ).images[0]
    return result, seed, preprocessed
# --------------------------------------------------
# Gradio UI
# --------------------------------------------------
css = """#wrapper {max-width: 960px; margin: 0 auto;}"""
with gr.Blocks(css=css, elem_id="wrapper") as demo:
    gr.Markdown("## FLUX.1-dev-ControlNet-Union-Pro by Frank")
    gr.Markdown(
        "A unified ControlNet for **FLUX.1-dev** from the InstantX team and Shakker Labs. "
        + "Recommended strengths: *canny 0.76*. Long prompts usually help."
    )
    # ------------ Image panel row ------------
    with gr.Row():
        control_image = gr.Image(
            label="Upload an image",
            type="pil",
            height=512 + 256,
        )
        result_image = gr.Image(label="Result", height=512 + 256)
        preview_image = gr.Image(label="Pre-processed Cond", height=512 + 256)
    # ------------ Prompt ------------
    prompt_txt = gr.Textbox(label="Prompt", value="White background", lines=1)
    # ------------ ControlNet settings ------------
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ControlNet")
            mode_radio = gr.Radio(
                choices=list(MODE_MAPPING.keys()), value="canny", label="Mode"
            )
            strength_slider = gr.Slider(
                0.0, 1.0, value=0.76, step=0.01, label="Control strength"
            )
            gr.Markdown("### Preprocess")
            canny_threshold_1 = gr.Slider(
                0, 500, step=1, value=100, label="Canny threshold 1"
            )
            canny_threshold_2 = gr.Slider(
                0, 500, step=1, value=200, label="Canny threshold 2"
            )
        with gr.Column():
            seed_slider = gr.Slider(0, MAX_SEED, step=1, value=42, label="Seed")
            randomize_chk = gr.Checkbox(label="Randomize seed", value=False)
            guidance_slider = gr.Slider(
                0.0, 10.0, step=0.1, value=3.5, label="Guidance scale"
            )
            steps_slider = gr.Slider(1, 50, step=1, value=50, label="Inference steps")
    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=infer,
        inputs=[
            control_image,
            prompt_txt,
            mode_radio,
            strength_slider,
            seed_slider,
            randomize_chk,
            guidance_slider,
            steps_slider,
            canny_threshold_1,
            canny_threshold_2,
        ],
        outputs=[result_image, seed_slider, preview_image],
    )
if __name__ == "__main__":
    demo.launch()