import gradio as gr
import numpy as np
import spaces
import torch
import random
import json
import os
from PIL import Image
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, login
from safetensors.torch import load_file
import requests
import re

device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max

# Adapter name under which the currently selected LoRA is registered on `pipe`.
LORA_ADAPTER_NAME = "selected_lora"

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to(device)

with open("flux_loras.json", "r", encoding="utf-8") as file:
    data = json.load(file)

# Normalised gallery entries; optional JSON fields get defaults here so the
# rest of the app can index them directly.
flux_loras_raw = [
    {
        "image": item["image"],
        "title": item["title"],
        "repo": item["repo"],
        "weights": item.get("weights", "pytorch_lora_weights.safetensors"),
        "prompt": item.get("prompt"),
        "lora_type": item.get("lora_type", "flux"),
        "lora_scale_config": item.get("lora_scale", 1.5),
        # Carried through so classify_gallery's sort by likes has data to use.
        "likes": item.get("likes", 0),
    }
    for item in data
]
print(f"Loaded {len(flux_loras_raw)} LoRAs from JSON")

# Maps (repo_id, weights_filename) -> local path returned by hf_hub_download.
lora_cache = {}


def load_lora_weights(repo_id, weights_filename):
    """Download a LoRA weights file from the Hugging Face Hub, memoised per file.

    Args:
        repo_id: Hub repository id (``"owner/name"``).
        weights_filename: Name of the weights file inside the repository.

    Returns:
        The local path of the downloaded file, or ``None`` if the download
        failed (the error is printed, not raised).
    """
    # Key on both parts: one repo can ship several different weight files.
    cache_key = (repo_id, weights_filename)
    if cache_key in lora_cache:
        return lora_cache[cache_key]
    try:
        lora_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
    except Exception as e:
        print(f"Error loading LoRA from {repo_id}: {e}")
        return None
    lora_cache[cache_key] = lora_path
    return lora_path


def update_selection(selected_state: gr.SelectData, flux_loras):
    """Update the UI after a LoRA is clicked in the gallery.

    Args:
        selected_state: Gradio select event; ``.index`` is the clicked tile.
        flux_loras: Gallery entries in the same order as the displayed tiles.

    Returns:
        A 4-tuple for ``[prompt_title, prompt, selected_state, lora_scale]``:
        header markdown, a placeholder update for the prompt box, the selected
        index (or ``None``), and the LoRA's configured scale (or a no-op update).
    """
    index = selected_state.index
    if not flux_loras or index is None or index >= len(flux_loras):
        return "### No LoRA selected", gr.update(), None, gr.update()

    lora = flux_loras[index]
    lora_repo = lora["repo"]
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
    new_placeholder = lora.get("prompt") or (
        "opt - describe the person/subject, e.g. 'a man with glasses and a beard'"
    )
    print("Selected Index: ", lora)
    optimal_scale = lora.get("lora_scale_config", 1.5)
    print("Optimal Scale: ", optimal_scale)
    return updated_text, gr.update(placeholder=new_placeholder), index, optimal_scale


def get_huggingface_lora(link):
    """Resolve a ``"owner/name"`` Hub id to LoRA metadata.

    Returns:
        ``(repo_name, safetensors_filename, trigger_word)``. The filename is the
        first ``.safetensors`` path containing ``"lora"``, else the diffusers
        default ``pytorch_lora_weights.safetensors``.

    Raises:
        Exception: If ``link`` is not ``owner/name`` or the Hub lookup fails.
    """
    split_link = link.split("/")
    if len(split_link) != 2:
        raise Exception("Invalid HuggingFace repository format")
    try:
        model_card = ModelCard.load(link)
        trigger_word = model_card.data.get("instance_prompt", "")
        fs = HfFileSystem()
        list_of_files = fs.ls(link, detail=False)
        safetensors_file = next(
            (
                path.split("/")[-1]
                for path in list_of_files
                if path.endswith(".safetensors") and "lora" in path.lower()
            ),
            "pytorch_lora_weights.safetensors",
        )
        return split_link[1], safetensors_file, trigger_word
    except Exception as e:
        raise Exception(f"Error loading LoRA: {e}") from e


def classify_gallery(flux_loras):
    """Sort gallery entries by ``likes`` (descending).

    Returns:
        ``(gallery_items, sorted_entries)`` where ``gallery_items`` is a list of
        ``(image, title)`` pairs for ``gr.Gallery``. ``sorted_entries`` replaces
        the state so gallery indices keep matching entries.
    """
    sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
    return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery


def infer_with_lora_wrapper(
    input_image,
    prompt,
    selected_index,
    seed=42,
    randomize_seed=False,
    guidance_scale=2.5,
    lora_scale=1.75,
    flux_loras=None,
):
    """Non-GPU entry point bound to the UI; forwards to ``infer_with_lora``."""
    return infer_with_lora(
        input_image, prompt, selected_index, seed, randomize_seed, guidance_scale, lora_scale, flux_loras
    )


def _activate_lora(lora_entry, lora_scale):
    """Make ``lora_entry`` the only custom adapter on ``pipe``; ``None`` removes it.

    Errors are printed rather than raised, so generation can still proceed.
    """
    try:
        # Drop the previous selection first so adapters never stack.
        if LORA_ADAPTER_NAME in pipe.get_active_adapters():
            pipe.unload_lora_weights()
        if lora_entry is None:
            return
        lora_path = load_lora_weights(lora_entry["repo"], lora_entry["weights"])
        if lora_path:
            pipe.load_lora_weights(lora_path, adapter_name=LORA_ADAPTER_NAME)
            pipe.set_adapters([LORA_ADAPTER_NAME], adapter_weights=[lora_scale])
            print(f"loaded: {lora_path} with scale {lora_scale}")
    except Exception as e:
        print(f"Error loading LoRA: {e}")


@spaces.GPU
def infer_with_lora(
    input_image,
    prompt,
    selected_index,
    seed=42,
    randomize_seed=False,
    guidance_scale=2.5,
    lora_scale=1.0,
    flux_loras=None,
):
    """Edit ``input_image`` with FLUX.1 Kontext and the selected LoRA.

    The prompt stored with the selected LoRA takes precedence. ``prompt`` is
    used only when the entry defines none, or when no LoRA is selected.

    Returns:
        ``(image_or_None, seed_used, reuse_button_update, lora_scale)``. The
        image is ``None`` and the reuse button is hidden if inference fails.

    Raises:
        gr.Error: If no image was uploaded or no prompt is available.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Sliders may deliver floats; torch.Generator.manual_seed needs an int.
    seed = int(seed)

    lora_to_use = None
    if selected_index is not None and flux_loras and 0 <= selected_index < len(flux_loras):
        lora_to_use = flux_loras[selected_index]
    print(f"LoRA to use: {lora_to_use}")

    effective_prompt = (lora_to_use or {}).get("prompt") or prompt
    if not effective_prompt:
        raise gr.Error("No prompt available: select a LoRA from the gallery that defines a prompt.")

    _activate_lora(lora_to_use, lora_scale)

    input_image = input_image.convert("RGB")
    try:
        image = pipe(
            image=input_image,
            width=input_image.size[0],
            height=input_image.size[1],
            prompt=effective_prompt,
            guidance_scale=guidance_scale,
            generator=torch.Generator().manual_seed(seed),
        ).images[0]
        return image, seed, gr.update(visible=True), lora_scale
    except Exception as e:
        print(f"Error during inference: {e}")
        return None, seed, gr.update(visible=False), lora_scale


# CSS styling
css = """
#main_app { display: flex; gap: 20px; }
#box_column { min-width: 400px; }
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#selected_lora { color: #2563eb; font-weight: bold; }
#prompt { flex-grow: 1; }
#run_button { background: linear-gradient(45deg, #2563eb, #3b82f6); color: white; border: none; padding: 8px 16px; border-radius: 6px; font-weight: bold; }
.custom_lora_card { background: #f8fafc; border: 1px solid #e2e8f0; border-radius: 8px; padding: 12px; margin: 8px 0; }
#gallery{ overflow: scroll !important }
"""

# Create Gradio interface
with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend Deca"), "sans-serif"])) as demo:
    gr_flux_loras = gr.State(value=flux_loras_raw)

    title = gr.HTML(
        """

LoRA FLUX.1 Kontext for Segmentation

""",
        elem_id="title",
    )

    selected_state = gr.State(value=None)
    lora_state = gr.State(value=1.0)

    with gr.Row(elem_id="main_app"):
        with gr.Column(scale=4, elem_id="box_column"):
            with gr.Group(elem_id="gallery_box"):
                input_image = gr.Image(label="Upload an image", type="pil", height=300)
                gallery = gr.Gallery(
                    label="Pick a LoRA",
                    allow_preview=False,
                    columns=3,
                    elem_id="gallery",
                    show_share_button=False,
                    height=400,
                )

        with gr.Column(scale=5):
            with gr.Row():
                prompt = gr.Textbox(
                    label="Editing Prompt",
                    show_label=False,
                    lines=1,
                    max_lines=1,
                    placeholder="",
                    elem_id="prompt",
                    interactive=False,
                )
                run_button = gr.Button("Generate", elem_id="run_button")

            result = gr.Image(label="Generated Image", interactive=False)
            reuse_button = gr.Button("Reuse this image", visible=False)

            with gr.Accordion("Advanced Settings", open=False):
                lora_scale = gr.Slider(
                    label="LoRA Scale",
                    minimum=0,
                    maximum=2,
                    step=0.1,
                    value=1.5,
                    info="Controls the strength of the LoRA effect",
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=10,
                    step=0.1,
                    value=2.5,
                )

            prompt_title = gr.Markdown(
                value="### Click on a LoRA in the gallery to select it",
                visible=True,
                elem_id="selected_lora",
            )

    gallery.select(
        fn=update_selection,
        inputs=[gr_flux_loras],
        outputs=[prompt_title, prompt, selected_state, lora_scale],
        show_progress=False,
    )

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer_with_lora_wrapper,
        inputs=[input_image, prompt, selected_state, seed, randomize_seed, guidance_scale, lora_scale, gr_flux_loras],
        outputs=[result, seed, reuse_button, lora_state],
    )

    reuse_button.click(fn=lambda image: image, inputs=[result], outputs=[input_image])

    # Initialize gallery (sorted by likes; state replaced so indices line up).
    demo.load(fn=classify_gallery, inputs=[gr_flux_loras], outputs=[gallery, gr_flux_loras])

demo.queue(default_concurrency_limit=None)
demo.launch()