import json
import os
import random
import re
from functools import lru_cache
from typing import Any, List, Optional, Tuple

import gradio as gr
from huggingface_hub import InferenceClient, hf_hub_download

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------

# LoRAs in the "Kontext Dev LoRAs" collection.
# NOTE: The list is hard-coded for now. If the collection grows, simply append
# the new model IDs here. The first entry is the dropdown's initial selection.
LORA_MODELS: List[str] = [
    # fal – original author
    "fal/Watercolor-Art-Kontext-Dev-LoRA",
    "fal/Pop-Art-Kontext-Dev-LoRA",
    "fal/Pencil-Drawing-Kontext-Dev-LoRA",
    "fal/Mosaic-Art-Kontext-Dev-LoRA",
    "fal/Minimalist-Art-Kontext-Dev-LoRA",
    "fal/Impressionist-Art-Kontext-Dev-LoRA",
    "fal/Gouache-Art-Kontext-Dev-LoRA",
    "fal/Expressive-Art-Kontext-Dev-LoRA",
    "fal/Cubist-Art-Kontext-Dev-LoRA",
    "fal/Collage-Art-Kontext-Dev-LoRA",
    "fal/Charcoal-Art-Kontext-Dev-LoRA",
    "fal/Acrylic-Art-Kontext-Dev-LoRA",
    "fal/Abstract-Art-Kontext-Dev-LoRA",
    "fal/Plushie-Kontext-Dev-LoRA",
    "fal/Youtube-Thumbnails-Kontext-Dev-LoRA",
    "fal/Broccoli-Hair-Kontext-Dev-LoRA",
    "fal/Wojak-Kontext-Dev-LoRA",
    "fal/3D-Game-Assets-Kontext-Dev-LoRA",
    "fal/Realism-Detailer-Kontext-Dev-LoRA",
    # community LoRAs
    "gokaygokay/Pencil-Drawing-Kontext-Dev-LoRA",
    "gokaygokay/Oil-Paint-Kontext-Dev-LoRA",
    "gokaygokay/Watercolor-Kontext-Dev-LoRA",
    "gokaygokay/Pastel-Flux-Kontext-Dev-LoRA",
    "gokaygokay/Low-Poly-Kontext-Dev-LoRA",
    "gokaygokay/Bronze-Sculpture-Kontext-Dev-LoRA",
    "gokaygokay/Marble-Sculpture-Kontext-Dev-LoRA",
    "gokaygokay/Light-Fix-Kontext-Dev-LoRA",
    "gokaygokay/Fuse-it-Kontext-Dev-LoRA",
    "ilkerzgi/Overlay-Kontext-Dev-LoRA",
]

# Optional metadata cache file. Generated by `generate_lora_metadata.py`.
METADATA_FILE = "lora_metadata.json"


def _load_metadata() -> dict:
    """Load cached preview/trigger data from ``METADATA_FILE``.

    The cache is optional: a missing, unreadable or malformed file is not an
    error and simply yields an empty mapping.

    Returns:
        The decoded JSON object, or ``{}`` when the file cannot be read, does
        not contain valid JSON, or its top-level value is not an object.
    """
    try:
        with open(METADATA_FILE, "r", encoding="utf-8") as fp:
            data = json.load(fp)
    except (OSError, ValueError):
        # OSError: file missing or unreadable. ValueError: invalid JSON or
        # invalid UTF-8 (JSONDecodeError and UnicodeDecodeError derive from it).
        return {}
    # The UI looks entries up by model ID, so only a mapping is usable.
    return data if isinstance(data, dict) else {}


# Name of the environment variable that holds the shared token used for
# anonymous requests, and the number of requests allowed before `run_lora`
# refuses to use that shared token.
FREE_TOKEN_ENV = "HF_TOKEN"
FREE_REQUESTS = 10


def _login_status(token: Optional[str]) -> str:
    """Return the authentication status line shown next to the login controls."""
    if token:
        return "Logged in ✅ Unlimited"
    return f"Not logged in – using free quota (max {FREE_REQUESTS})"


def _quota_message(token: Optional[str], used: int) -> str:
    """Return the quota line shown after a generation.

    Args:
        token: The user's own token; any truthy value means "logged in".
        used: Number of free requests consumed so far.
    """
    if token:
        return "Logged in ✅ Unlimited"
    return f"Free requests remaining: {max(0, FREE_REQUESTS - used)}"


# -----------------------------------------------------------------------------
# Utility helpers
# -----------------------------------------------------------------------------

# The cache is keyed on user-supplied tokens, so it is bounded: an unbounded
# cache would grow with every distinct token and keep each one in memory.
@lru_cache(maxsize=128)
def get_client(token: str) -> "InferenceClient":
    """Return a cached ``InferenceClient`` (provider ``fal-ai``) for ``token``."""
    return InferenceClient(provider="fal-ai", api_key=token)


IMG_PATTERN = re.compile(r"!\[.*?\]\((.*?)\)")
TRIGGER_PATTERN = re.compile(r"[Tt]rigger[^:]*:\s*([^\n]+)")


def _resolve_image_url(model_id: str, target: str) -> Optional[str]:
    """Turn the target of a markdown image link into an absolute URL.

    Args:
        model_id: Repository ID used to build the URL for relative paths.
        target: Text captured between the parentheses of ``![alt](target)``.

    Returns:
        ``target`` itself when it is already an http(s) URL, a
        ``https://huggingface.co/<model_id>/resolve/main/<path>`` URL for a
        relative path, or ``None`` when the target is empty.
    """
    # Drop an optional markdown title: ![alt](path "title").
    pieces = target.strip().split()
    if not pieces:
        return None
    path = pieces[0].strip("<>")

    if path.startswith(("http://", "https://")):
        return path

    # Remove only literal "./" prefixes and leading slashes. A character-set
    # strip such as lstrip("./") would also eat the dot of ".assets/x.png".
    while path.startswith("./"):
        path = path[2:]
    path = path.lstrip("/")
    if not path:
        return None
    return f"https://huggingface.co/{model_id}/resolve/main/{path}"


def _parse_model_card(model_id: str, text: str) -> Tuple[Optional[str], Optional[str]]:
    """Extract ``(preview image URL, trigger phrase)`` from README text.

    The first markdown image becomes the preview; the trigger phrase is the
    remainder of the line following the first ``Trigger…:`` label. Either
    element is ``None`` when nothing matches.
    """
    image_url: Optional[str] = None
    trigger_phrase: Optional[str] = None

    image_match = IMG_PATTERN.search(text)
    if image_match is not None:
        image_url = _resolve_image_url(model_id, image_match.group(1))

    trigger_match = TRIGGER_PATTERN.search(text)
    if trigger_match is not None:
        trigger_phrase = trigger_match.group(1).strip()

    return image_url, trigger_phrase


@lru_cache(maxsize=None)
def fetch_preview_and_trigger(model_id: str) -> Tuple[Optional[str], Optional[str]]:
    """Try to fetch a preview image URL and trigger phrase from the model card.

    Results are memoised per ``model_id``, including the ``(None, None)``
    result of a failed attempt.

    Returns:
        ``(image_url, trigger_phrase)``; ``(None, None)`` if the README cannot
        be downloaded or read.
    """
    try:
        # Download README.
        readme_path = hf_hub_download(repo_id=model_id, filename="README.md")
    except Exception:
        # Deliberately broad: this is a best-effort lookup and any download
        # failure should just leave the preview empty.
        return None, None

    try:
        with open(readme_path, "r", encoding="utf-8") as fp:
            text = fp.read()
    except (OSError, UnicodeDecodeError):
        return None, None

    return _parse_model_card(model_id, text)


# -----------------------------------------------------------------------------
# Core inference function
# -----------------------------------------------------------------------------

def _image_to_png_bytes(input_image) -> bytes:
    """Return ``input_image`` as bytes, PNG-encoding objects that can ``save``.

    Raises:
        gr.Error: If the value is neither bytes nor an object with ``save``.
    """
    if isinstance(input_image, (bytes, bytearray)):
        return bytes(input_image)
    # Check for the method that is actually called below.
    if hasattr(input_image, "save"):
        import io

        buf = io.BytesIO()
        input_image.save(buf, format="PNG")
        return buf.getvalue()
    raise gr.Error("Unsupported image format.")


def run_lora(
    input_image,  # bytes or PIL.Image
    prompt: str,
    model_id: str,
    guidance_scale: float,
    token: Optional[str],
    req_count: int,
):
    """Execute image → image generation via the selected LoRA.

    Args:
        input_image: Source image as bytes or a PIL image.
        prompt: Text prompt forwarded to the model.
        model_id: Repository ID of the LoRA to run.
        guidance_scale: Guidance scale forwarded to the model.
        token: The user's own token; when falsy the shared free token is used.
        req_count: Number of free requests already consumed.

    Returns:
        ``(output, new_req_count, status_message)`` where ``output`` is
        whatever ``client.image_to_image`` returns.

    Raises:
        gr.Error: If no image is given, the image type is unsupported, no free
            token is configured, or the free quota is exhausted.
    """
    if input_image is None:
        raise gr.Error("Please provide an input image.")

    # Determine which token we will use
    if token:
        api_token = token
    else:
        free_token = os.getenv(FREE_TOKEN_ENV)
        # An empty variable is as unusable as an unset one.
        if not free_token:
            raise gr.Error("Service not configured for free usage. Please login.")
        if req_count >= FREE_REQUESTS:
            raise gr.Error("Free quota exceeded – please login with your own HF account to continue.")
        api_token = free_token

    client = get_client(api_token)

    # Gradio delivers PIL.Image by default. InferenceClient accepts bytes.
    img_bytes = _image_to_png_bytes(input_image)

    output = client.image_to_image(
        img_bytes,
        prompt=prompt,
        model=model_id,
        guidance_scale=guidance_scale,
    )

    # Update request count only if using free token
    new_count = req_count if token else req_count + 1
    return output, new_count, _quota_message(token, new_count)


# -----------------------------------------------------------------------------
# UI assembly
# -----------------------------------------------------------------------------

def build_interface():
    """Assemble the Gradio Blocks app and wire up its event handlers.

    Returns:
        The ``gr.Blocks`` instance; it is not launched here.
    """
    # Loaded once; the nested callbacks read it through the closure.
    metadata_cache = _load_metadata()

    # Theme & CSS
    theme = gr.themes.Soft(primary_hue="violet", secondary_hue="indigo")
    custom_css = """
    .gradio-container {max-width: 980px; margin: auto;}
    .gallery-item {border-radius: 8px; overflow: hidden;}
    """

    with gr.Blocks(title="Kontext-Dev LoRA Playground", theme=theme, css=custom_css) as demo:
        token_state = gr.State(value="")
        request_count_state = gr.State(value=0)

        # --- Authentication UI -------------------------------------------
        login_btn = gr.LoginButton() if hasattr(gr, "LoginButton") else None
        # NOTE(review): confirm which Gradio releases expose a `login` event on
        # LoginButton. It is looked up defensively so that a release without it
        # falls back to manual token entry instead of failing at build time.
        login_event = getattr(login_btn, "login", None)

        if callable(login_event):
            token_status = gr.Markdown(value=_login_status(""))

            def _handle_login(login_data: Any = None):
                """Extract HF token from login payload returned by LoginButton."""
                token: str = ""
                if isinstance(login_data, dict):
                    token = login_data.get("access_token") or login_data.get("token") or ""
                elif isinstance(login_data, str):
                    token = login_data
                return token, _login_status(token)

            login_event(_handle_login, outputs=[token_state, token_status])
        else:
            # Fallback manual token input if the login event is not available
            # (e.g. local dev).
            with gr.Accordion("🔑 Paste your HF token (optional)", open=False):
                token_input = gr.Textbox(label="HF Token", type="password", placeholder="Paste your token here…")
                save_token_btn = gr.Button("Save token")
                token_status = gr.Markdown(value=_login_status(""))

            def _save_token(tok):
                """Store the pasted token and refresh the status line."""
                cleaned = (tok or "").strip()
                return cleaned, _login_status(cleaned)

            save_token_btn.click(_save_token, inputs=token_input, outputs=[token_state, token_status])

        gr.Markdown(
            """
            # Kontext-Dev LoRA Playground
            Select one of the available LoRAs from the dropdown, upload an image, tweak the prompt, and generate!
            """
        )

        with gr.Row():
            # LEFT column – model selection + preview
            with gr.Column(scale=1):
                model_dropdown = gr.Dropdown(
                    choices=LORA_MODELS,
                    value=LORA_MODELS[0],
                    label="Select LoRA model",
                )
                preview_image = gr.Image(label="Sample image", interactive=False, height=256)
                trigger_text = gr.Textbox(
                    label="Trigger phrase (suggested)",
                    interactive=False,
                )

            # RIGHT column – user inputs
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="Input image",
                    type="pil",
                )
                prompt_box = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe your transformation…",
                )
                guidance = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=2.5,
                    step=0.1,
                    label="Guidance scale",
                )
                generate_btn = gr.Button("🚀 Generate")

        output_image = gr.Image(label="Output", interactive=False)
        quota_display = gr.Markdown(value=f"Free requests remaining: {FREE_REQUESTS}")

        # Showcase Gallery --------------------------------------------------
        gr.Markdown("## ✨ Example outputs from selected LoRAs")
        example_gallery = gr.Gallery(
            label="Examples",
            columns=[4],
            height="auto",
            elem_id="example_gallery",
        )
        gallery_data_state = gr.State([])

        # ------------------------------------------------------------------
        # Callbacks
        # ------------------------------------------------------------------
        def _update_preview(model_id):
            """Refresh the preview image, trigger text and prompt for a model."""
            entry = metadata_cache.get(model_id)
            if isinstance(entry, dict):
                img_url = entry.get("image_url")
                trig = entry.get("trigger_phrase")
            else:
                img_url, trig = fetch_preview_and_trigger(model_id)

            return {
                preview_image: gr.update(value=img_url or None),
                trigger_text: gr.update(value=trig or "(no trigger phrase provided)"),
                # The placeholder is for display only; it must not end up in
                # the prompt that is sent to the model.
                prompt_box: gr.update(value=trig or ""),
            }

        model_dropdown.change(
            _update_preview,
            inputs=model_dropdown,
            outputs=[preview_image, trigger_text, prompt_box],
        )

        generate_btn.click(
            fn=run_lora,
            inputs=[input_image, prompt_box, model_dropdown, guidance, token_state, request_count_state],
            outputs=[output_image, request_count_state, quota_display],
        )

        # Helper to populate gallery once on launch
        def _load_gallery():
            """Pick up to 12 random (image URL, model ID) pairs from the cache."""
            samples = []
            for model_id in LORA_MODELS:
                info = metadata_cache.get(model_id)
                if isinstance(info, dict) and info.get("image_url"):
                    samples.append([info["image_url"], model_id])
            # shuffle and take first 12
            random.shuffle(samples)
            picked = samples[:12]
            # The state keeps the same list so a click can be mapped to a model.
            return picked, picked

        # Initialise preview and gallery on launch
        demo.load(_update_preview, inputs=model_dropdown, outputs=[preview_image, trigger_text, prompt_box])
        demo.load(fn=_load_gallery, inputs=None, outputs=[example_gallery, gallery_data_state])

        # Handle gallery click to update dropdown
        def _on_gallery_select(evt: gr.SelectData, data):
            """Select the model that belongs to the clicked gallery item."""
            idx = evt.index
            if not isinstance(idx, int) or not data or not 0 <= idx < len(data):
                # `gr.update()` with no arguments leaves the dropdown untouched.
                return gr.update()
            return gr.update(value=data[idx][1])

        example_gallery.select(_on_gallery_select, inputs=gallery_data_state, outputs=model_dropdown)

    return demo


def main():
    """Build the interface and start the Gradio server."""
    demo = build_interface()
    demo.launch()


if __name__ == "__main__":
    main()