import gradio as gr import numpy as np import spaces import torch import random import json import os from PIL import Image from diffusers import FluxKontextPipeline from diffusers.utils import load_image from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, login from safetensors.torch import load_file import requests import re device = "cuda" if torch.cuda.is_available() else "cpu" MAX_SEED = np.iinfo(np.int32).max pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to(device) with open("flux_loras.json", "r") as file: data = json.load(file) flux_loras_raw = [ { "image": item["image"], "title": item["title"], "repo": item["repo"], "weights": item.get("weights", "pytorch_lora_weights.safetensors"), "prompt": item.get("prompt"), "lora_type": item.get("lora_type", "flux"), "lora_scale_config": item.get("lora_scale", 1.5), } for item in data ] print(f"Loaded {len(flux_loras_raw)} LoRAs from JSON") lora_cache = {} def load_lora_weights(repo_id, weights_filename): """Load LoRA weights from HuggingFace""" try: if repo_id not in lora_cache: lora_path = hf_hub_download(repo_id=repo_id, filename=weights_filename) lora_cache[repo_id] = lora_path return lora_cache[repo_id] except Exception as e: print(f"Error loading LoRA from {repo_id}: {e}") return None def update_selection(selected_state: gr.SelectData, flux_loras): """Update UI when a LoRA is selected""" if selected_state.index >= len(flux_loras): return "### No LoRA selected", gr.update(), None, gr.update() lora_repo = flux_loras[selected_state.index]["repo"] prompt = flux_loras[selected_state.index]["prompt"] updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})" if prompt: new_placeholder = prompt else: new_placeholder = f"opt - describe the person/subject, e.g. 'a man with glasses and a beard'" print("Selected Index: ", flux_loras[selected_state.index]) optimal_scale = flux_loras[selected_state.index].get("lora_scale_config", 1.5) print("Optimal Scale: ", optimal_scale) return updated_text, gr.update(placeholder=new_placeholder), selected_state.index, optimal_scale def get_huggingface_lora(link): """Download LoRA from HuggingFace link""" split_link = link.split("/") if len(split_link) == 2: try: model_card = ModelCard.load(link) trigger_word = model_card.data.get("instance_prompt", "") fs = HfFileSystem() list_of_files = fs.ls(link, detail=False) safetensors_file = None for file in list_of_files: if file.endswith(".safetensors") and "lora" in file.lower(): safetensors_file = file.split("/")[-1] break if not safetensors_file: safetensors_file = "pytorch_lora_weights.safetensors" return split_link[1], safetensors_file, trigger_word except Exception as e: raise Exception(f"Error loading LoRA: {e}") else: raise Exception("Invalid HuggingFace repository format") def classify_gallery(flux_loras): """Sort gallery by likes""" sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True) return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery def infer_with_lora_wrapper( input_image, prompt, selected_index, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.75, flux_loras=None, ): """Wrapper function to handle state serialization""" return infer_with_lora(input_image, prompt, selected_index, seed, randomize_seed, guidance_scale, lora_scale, flux_loras) @spaces.GPU def infer_with_lora( input_image, prompt, selected_index, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, flux_loras=None, ): """Generate image with selected LoRA""" global pipe if randomize_seed: seed = random.randint(0, MAX_SEED) # Determine which LoRA to use lora_to_use = None if selected_index is not None and flux_loras and selected_index < len(flux_loras): lora_to_use = flux_loras[selected_index] print(f"Loaded {len(flux_loras)} LoRAs from JSON") # Load LoRA if needed print(f"LoRA to use: {lora_to_use}") if lora_to_use: try: if "selected_lora" in pipe.get_active_adapters(): pipe.unload_lora_weights() lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"]) if lora_path: pipe.load_lora_weights(lora_path, adapter_name="selected_lora") pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale]) print(f"loaded: {lora_path} with scale {lora_scale}") except Exception as e: print(f"Error loading LoRA: {e}") input_image = input_image.convert("RGB") prompt = lora_to_use["prompt"] try: image = pipe(image=input_image, width=input_image.size[0], height=input_image.size[1], prompt=prompt, guidance_scale=guidance_scale, generator=torch.Generator().manual_seed(seed)).images[0] return image, seed, gr.update(visible=True), lora_scale except Exception as e: print(f"Error during inference: {e}") return None, seed, gr.update(visible=False), lora_scale # CSS styling css = """ #main_app { display: flex; gap: 20px; } #box_column { min-width: 400px; } #title{text-align: center} #title h1{font-size: 3em; display:inline-flex; align-items:center} #title img{width: 100px; margin-right: 0.5em} #selected_lora { color: #2563eb; font-weight: bold; } #prompt { flex-grow: 1; } #run_button { background: linear-gradient(45deg, #2563eb, #3b82f6); color: white; border: none; padding: 8px 16px; border-radius: 6px; font-weight: bold; } .custom_lora_card { background: #f8fafc; border: 1px solid #e2e8f0; border-radius: 8px; padding: 12px; margin: 8px 0; } #gallery{ overflow: scroll !important } """ # Create Gradio interface with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend Deca"), "sans-serif"])) as demo: gr_flux_loras = gr.State(value=flux_loras_raw) title = gr.HTML( """
FLUX.1 Kontext for Segmentation