import gradio as gr
import numpy as np
import spaces
import torch
import random
import json
from diffusers import FluxKontextPipeline
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to(device)
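# The gallery entries are read from flux_loras.json. A minimal sketch of the
# expected entry shape, based on the fields accessed below (values are
# illustrative placeholders, not real repos):
#
# [
#   {
#     "image": "https://example.com/thumbnail.png",
#     "title": "My Kontext LoRA",
#     "repo": "some-user/some-flux-kontext-lora",
#     "weights": "pytorch_lora_weights.safetensors",
#     "prompt": "an optional preset editing prompt",
#     "lora_type": "flux",
#     "lora_scale": 1.5,
#     "likes": 42
#   }
# ]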
with open("flux_loras.json", "r") as file:
data = json.load(file)
flux_loras_raw = [
{
"image": item["image"],
"title": item["title"],
"repo": item["repo"],
"weights": item.get("weights", "pytorch_lora_weights.safetensors"),
"prompt": item.get("prompt"),
"lora_type": item.get("lora_type", "flux"),
"lora_scale_config": item.get("lora_scale", 1.5),
}
for item in data
]
print(f"Loaded {len(flux_loras_raw)} LoRAs from JSON")
lora_cache = {}
def load_lora_weights(repo_id, weights_filename):
"""Load LoRA weights from HuggingFace"""
try:
if repo_id not in lora_cache:
lora_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
lora_cache[repo_id] = lora_path
return lora_cache[repo_id]
except Exception as e:
print(f"Error loading LoRA from {repo_id}: {e}")
return None
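# Illustrative usage (hypothetical repo id, not one of the curated entries):
#   path = load_lora_weights("some-user/some-flux-kontext-lora", "pytorch_lora_weights.safetensors")
#   # -> local path to the cached .safetensors file, or None if the download failed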
def update_selection(selected_state: gr.SelectData, flux_loras):
"""Update UI when a LoRA is selected"""
if selected_state.index >= len(flux_loras):
return "### No LoRA selected", gr.update(), None, gr.update()
lora_repo = flux_loras[selected_state.index]["repo"]
prompt = flux_loras[selected_state.index]["prompt"]
updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
if prompt:
new_placeholder = prompt
else:
        new_placeholder = "Optional: describe the person/subject, e.g. 'a man with glasses and a beard'"
    print("Selected LoRA:", flux_loras[selected_state.index])
optimal_scale = flux_loras[selected_state.index].get("lora_scale_config", 1.5)
print("Optimal Scale: ", optimal_scale)
return updated_text, gr.update(placeholder=new_placeholder), selected_state.index, optimal_scale
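# The tuple returned above feeds the (prompt_title, prompt, selected_state, lora_scale)
# outputs wired up in the gallery.select handler further below.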
def get_huggingface_lora(link):
"""Download LoRA from HuggingFace link"""
split_link = link.split("/")
if len(split_link) == 2:
try:
model_card = ModelCard.load(link)
trigger_word = model_card.data.get("instance_prompt", "")
fs = HfFileSystem()
list_of_files = fs.ls(link, detail=False)
safetensors_file = None
for file in list_of_files:
if file.endswith(".safetensors") and "lora" in file.lower():
safetensors_file = file.split("/")[-1]
break
if not safetensors_file:
safetensors_file = "pytorch_lora_weights.safetensors"
return split_link[1], safetensors_file, trigger_word
except Exception as e:
raise Exception(f"Error loading LoRA: {e}")
else:
raise Exception("Invalid HuggingFace repository format")
def classify_gallery(flux_loras):
"""Sort gallery by likes"""
sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
def infer_with_lora_wrapper(
input_image,
prompt,
selected_index,
seed=42,
randomize_seed=False,
guidance_scale=2.5,
lora_scale=1.75,
flux_loras=None,
):
"""Wrapper function to handle state serialization"""
return infer_with_lora(input_image, prompt, selected_index, seed, randomize_seed, guidance_scale, lora_scale, flux_loras)
@spaces.GPU
def infer_with_lora(
input_image,
prompt,
selected_index,
seed=42,
randomize_seed=False,
guidance_scale=2.5,
lora_scale=1.0,
flux_loras=None,
):
"""Generate image with selected LoRA"""
global pipe
if randomize_seed:
seed = random.randint(0, MAX_SEED)
# Determine which LoRA to use
lora_to_use = None
if selected_index is not None and flux_loras and selected_index < len(flux_loras):
lora_to_use = flux_loras[selected_index]
print(f"Loaded {len(flux_loras)} LoRAs from JSON")
# Load LoRA if needed
print(f"LoRA to use: {lora_to_use}")
if lora_to_use:
try:
if "selected_lora" in pipe.get_active_adapters():
pipe.unload_lora_weights()
lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"])
if lora_path:
pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
print(f"loaded: {lora_path} with scale {lora_scale}")
except Exception as e:
print(f"Error loading LoRA: {e}")
    if input_image is None:
        print("No input image provided")
        return None, seed, gr.update(visible=False), lora_scale
    input_image = input_image.convert("RGB")
    # Use the selected LoRA's preset prompt when one is defined
    if lora_to_use and lora_to_use.get("prompt"):
        prompt = lora_to_use["prompt"]
try:
        image = pipe(
            image=input_image,
            width=input_image.size[0],
            height=input_image.size[1],
            prompt=prompt,
            guidance_scale=guidance_scale,
            generator=torch.Generator().manual_seed(seed),
        ).images[0]
return image, seed, gr.update(visible=True), lora_scale
except Exception as e:
print(f"Error during inference: {e}")
return None, seed, gr.update(visible=False), lora_scale
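# The tuple returned above maps to the (result, seed, reuse_button, lora_state)
# outputs wired up in the gr.on(...) call at the bottom of this file.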
# CSS styling
css = """
#main_app {
display: flex;
gap: 20px;
}
#box_column {
min-width: 400px;
}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#selected_lora {
color: #2563eb;
font-weight: bold;
}
#prompt {
flex-grow: 1;
}
#run_button {
background: linear-gradient(45deg, #2563eb, #3b82f6);
color: white;
border: none;
padding: 8px 16px;
border-radius: 6px;
font-weight: bold;
}
.custom_lora_card {
background: #f8fafc;
border: 1px solid #e2e8f0;
border-radius: 8px;
padding: 12px;
margin: 8px 0;
}
#gallery{
overflow: scroll !important
}
"""
# Create Gradio interface
with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend Deca"), "sans-serif"])) as demo:
gr_flux_loras = gr.State(value=flux_loras_raw)
title = gr.HTML(
"""<h1><img src="https://huggingface.co/jroessler/flux-kontext-segmentation-sweatshirt/resolve/main/t-shirt-emoji.png" alt="LoRA"> FLUX.1 Kontext for Segmentation</h1>""",
elem_id="title",
)
selected_state = gr.State(value=None)
lora_state = gr.State(value=1.0)
with gr.Row(elem_id="main_app"):
with gr.Column(scale=4, elem_id="box_column"):
with gr.Group(elem_id="gallery_box"):
input_image = gr.Image(label="Upload an image", type="pil", height=300)
gallery = gr.Gallery(label="Pick a LoRA", allow_preview=False, columns=3, elem_id="gallery", show_share_button=False, height=400)
with gr.Column(scale=5):
with gr.Row():
prompt = gr.Textbox(
label="Editing Prompt",
show_label=False,
lines=1,
max_lines=1,
placeholder="",
elem_id="prompt",
interactive=False,
)
run_button = gr.Button("Generate", elem_id="run_button")
result = gr.Image(label="Generated Image", interactive=False)
reuse_button = gr.Button("Reuse this image", visible=False)
with gr.Accordion("Advanced Settings", open=False):
lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2, step=0.1, value=1.5, info="Controls the strength of the LoRA effect")
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
guidance_scale = gr.Slider(
label="Guidance Scale",
minimum=1,
maximum=10,
step=0.1,
value=2.5,
)
prompt_title = gr.Markdown(
value="### Click on a LoRA in the gallery to select it",
visible=True,
elem_id="selected_lora",
)
gallery.select(fn=update_selection, inputs=[gr_flux_loras], outputs=[prompt_title, prompt, selected_state, lora_scale], show_progress=False)
gr.on(
triggers=[run_button.click, prompt.submit],
fn=infer_with_lora_wrapper,
inputs=[input_image, prompt, selected_state, seed, randomize_seed, guidance_scale, lora_scale, gr_flux_loras],
outputs=[result, seed, reuse_button, lora_state],
)
reuse_button.click(fn=lambda image: image, inputs=[result], outputs=[input_image])
# Initialize gallery
demo.load(fn=classify_gallery, inputs=[gr_flux_loras], outputs=[gallery, gr_flux_loras])
demo.queue(default_concurrency_limit=None)
demo.launch()