import numpy as np
import gradio as gr
import spaces
import os
import random
import torch
from PIL import Image
import cv2
from huggingface_hub import login
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel

""" | |
FLUX‑1 ControlNet demo | |
---------------------- | |
This script rebuilds the Gradio interface shown in your screenshot with **one** control‑image upload | |
slot and integrates the FLUX.1‑dev‑ControlNet‑Union‑Pro model. | |
Key points | |
~~~~~~~~~~ | |
* Single *control image* input (left). | |
* *Result* and *Pre‑processed Cond* previews side‑by‑side (center & right). | |
* *Prompt* textbox plus a dedicated **ControlNet** panel for choosing the mode and strength. | |
* Seed handling with optional randomisation. | |
* Advanced sliders for *Guidance scale* and *Inference steps*. | |
* Works on CUDA (bfloat16) or CPU (float32). | |
* Minimal Canny preview implementation when the *canny* mode is selected (extend as you like for the | |
other modes). | |
Before running, set the `HUGGINGFACE_TOKEN` environment variable **or** call | |
`login("<YOUR_HF_TOKEN>")` explicitly. | |
""" | |
# --------------------------------------------------
# Model & pipeline setup
# --------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN_NEW")
if HF_TOKEN:  # only log in when a token is actually provided, otherwise login() would prompt
    login(HF_TOKEN)
# If you prefer to hard-code the token, uncomment:
# login("hf_your_token_here")

BASE_MODEL = "black-forest-labs/FLUX.1-dev"
CONTROLNET_MODEL = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

controlnet_single = FluxControlNetModel.from_pretrained(
    CONTROLNET_MODEL, torch_dtype=dtype
)
controlnet = FluxMultiControlNetModel([controlnet_single])

pipe = FluxControlNetPipeline.from_pretrained(
    BASE_MODEL, controlnet=controlnet, torch_dtype=dtype
).to(device)
pipe.set_progress_bar_config(disable=True)
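# Tip (an optional alternative, not part of the original flow): if VRAM is tight you can drop the
# explicit `.to(device)` above and let diffusers stream weights to the GPU on demand instead:
#   pipe.enable_model_cpu_offload()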
# --------------------------------------------------
# UI -> model value mapping
# --------------------------------------------------
# These integers are passed to the pipeline as `control_mode`; they must follow the index
# order documented on the ControlNet-Union-Pro model card, so verify them there before use.
MODE_MAPPING = {
    "canny": 0,
    "depth": 1,
    "openpose": 2,
    "gray": 3,
    "blur": 4,
    "tile": 5,
    "low quality": 6,
}
MAX_SEED = 100  # upper bound for the seed slider (raise it if you want a larger seed space)

# --------------------------------------------------
# Helper: quick-n-dirty Canny preview (only for UI display)
# --------------------------------------------------
def _preview_canny(pil_img: Image.Image) -> Image.Image:
    arr = np.array(pil_img.convert("RGB"))
    edges = cv2.Canny(arr, 100, 200)
    edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(edges_rgb)


def _make_preview(control_image: Image.Image, mode: str) -> Image.Image:
    if mode == "canny":
        return _preview_canny(control_image)
    # For other modes you can plug in your own visualiser later (see the optional depth
    # sketch below for one possible approach).
    return control_image
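
# Optional extension point: a rough depth preview, sketched here only to show how another mode
# could be visualised. It assumes the `transformers` depth-estimation pipeline is available; it
# is NOT wired into `_make_preview` above, and the estimator should be cached if you adopt it.
def _preview_depth(pil_img: Image.Image) -> Image.Image:
    from transformers import pipeline as hf_pipeline  # local import keeps the dependency optional

    depth_estimator = hf_pipeline("depth-estimation")  # downloads a default depth model on first use
    out = depth_estimator(pil_img.convert("RGB"))
    return out["depth"].convert("RGB")  # the pipeline returns a PIL depth map under the "depth" key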
# --------------------------------------------------
# Inference function
# --------------------------------------------------
@spaces.GPU  # allocate a GPU when hosted on a ZeroGPU Space (the reason `spaces` is imported)
def infer(
    control_image: Image.Image,
    prompt: str,
    mode: str,
    control_strength: float,
    seed: int,
    randomize_seed: bool,
    guidance_scale: float,
    num_inference_steps: int,
):
    if control_image is None:
        raise gr.Error("Please upload a control image first.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    gen = torch.Generator(device).manual_seed(seed)

    w, h = control_image.size
    result = pipe(
        prompt=prompt,
        control_image=[control_image],
        control_mode=[MODE_MAPPING[mode]],
        width=w,
        height=h,
        controlnet_conditioning_scale=[control_strength],
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=gen,
    ).images[0]

    preview = _make_preview(control_image, mode)
    return result, seed, preview
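
# Quick smoke test without launching the UI (hypothetical file names; adjust to your own assets):
#   img = Image.open("control.png")
#   out, used_seed, cond = infer(img, "best quality", "canny", 0.65, 42, False, 3.5, 24)
#   out.save("result.png")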
# --------------------------------------------------
# Gradio UI
# --------------------------------------------------
css = """#wrapper {max-width: 960px; margin: 0 auto;}"""

with gr.Blocks(css=css, elem_id="wrapper") as demo:
    gr.Markdown("## FLUX.1-dev-ControlNet-Union-Pro")
    gr.Markdown(
        "A unified ControlNet for **FLUX.1-dev** from the InstantX team and Shakker Labs. "
        + "Recommended strengths: *canny 0.65*, *tile 0.45*, *depth 0.55*, *blur 0.45*, "
        + "*openpose 0.55*, *gray 0.45*, *low quality 0.40*. Long prompts usually help."
    )
    # ------------ Image panel row ------------
    with gr.Row():
        control_image = gr.Image(
            label="Upload a processed control image",
            type="pil",
            height=512,
        )
        result_image = gr.Image(label="Result", height=512)
        preview_image = gr.Image(label="Pre-processed Cond", height=512)

    # ------------ Prompt ------------
    prompt_txt = gr.Textbox(label="Prompt", value="best quality", lines=1)

    # ------------ ControlNet settings ------------
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ControlNet")
            mode_radio = gr.Radio(
                choices=list(MODE_MAPPING.keys()), value="gray", label="Mode"
            )
            strength_slider = gr.Slider(
                0.0, 1.0, value=0.5, step=0.01, label="Control strength"
            )
        with gr.Column():
            seed_slider = gr.Slider(0, MAX_SEED, step=1, value=42, label="Seed")
            randomize_chk = gr.Checkbox(label="Randomize seed", value=True)
            guidance_slider = gr.Slider(
                0.0, 10.0, step=0.1, value=3.5, label="Guidance scale"
            )
            steps_slider = gr.Slider(1, 50, step=1, value=24, label="Inference steps")

    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=infer,
        inputs=[
            control_image,
            prompt_txt,
            mode_radio,
            strength_slider,
            seed_slider,
            randomize_chk,
            guidance_slider,
            steps_slider,
        ],
        outputs=[result_image, seed_slider, preview_image],
    )

if __name__ == "__main__":
    demo.launch()