# Last change: "Update app.py (#3)" (commit 5f4e326, verified) by reach-vb (HF Staff).
import os
import re
import json
import random
from functools import lru_cache
from typing import List, Tuple, Optional, Any
import gradio as gr
from huggingface_hub import InferenceClient, hf_hub_download
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
# LoRA repos from the "Kontext Dev LoRAs" collection, grouped by author.
# The list is hard-coded for now; when the collection grows, add the new repo
# IDs to the matching group. The dropdown shows them in this exact order.
_FAL_LORAS: List[str] = [
    # fal – original author
    "fal/Watercolor-Art-Kontext-Dev-LoRA",
    "fal/Pop-Art-Kontext-Dev-LoRA",
    "fal/Pencil-Drawing-Kontext-Dev-LoRA",
    "fal/Mosaic-Art-Kontext-Dev-LoRA",
    "fal/Minimalist-Art-Kontext-Dev-LoRA",
    "fal/Impressionist-Art-Kontext-Dev-LoRA",
    "fal/Gouache-Art-Kontext-Dev-LoRA",
    "fal/Expressive-Art-Kontext-Dev-LoRA",
    "fal/Cubist-Art-Kontext-Dev-LoRA",
    "fal/Collage-Art-Kontext-Dev-LoRA",
    "fal/Charcoal-Art-Kontext-Dev-LoRA",
    "fal/Acrylic-Art-Kontext-Dev-LoRA",
    "fal/Abstract-Art-Kontext-Dev-LoRA",
    "fal/Plushie-Kontext-Dev-LoRA",
    "fal/Youtube-Thumbnails-Kontext-Dev-LoRA",
    "fal/Broccoli-Hair-Kontext-Dev-LoRA",
    "fal/Wojak-Kontext-Dev-LoRA",
    "fal/3D-Game-Assets-Kontext-Dev-LoRA",
    "fal/Realism-Detailer-Kontext-Dev-LoRA",
]
_COMMUNITY_LORAS: List[str] = [
    "gokaygokay/Pencil-Drawing-Kontext-Dev-LoRA",
    "gokaygokay/Oil-Paint-Kontext-Dev-LoRA",
    "gokaygokay/Watercolor-Kontext-Dev-LoRA",
    "gokaygokay/Pastel-Flux-Kontext-Dev-LoRA",
    "gokaygokay/Low-Poly-Kontext-Dev-LoRA",
    "gokaygokay/Bronze-Sculpture-Kontext-Dev-LoRA",
    "gokaygokay/Marble-Sculpture-Kontext-Dev-LoRA",
    "gokaygokay/Light-Fix-Kontext-Dev-LoRA",
    "gokaygokay/Fuse-it-Kontext-Dev-LoRA",
    "ilkerzgi/Overlay-Kontext-Dev-LoRA",
]
# Full ordered catalogue: fal's LoRAs first, then community contributions.
LORA_MODELS: List[str] = [*_FAL_LORAS, *_COMMUNITY_LORAS]

# Optional cache of preview URLs / trigger phrases written by
# `generate_lora_metadata.py`; the relative path resolves against the
# process's current working directory.
METADATA_FILE = "lora_metadata.json"
def _load_metadata(path: Optional[str] = None) -> dict:
    """Load cached preview/trigger data from the metadata JSON file.

    Args:
        path: File to read; defaults to ``METADATA_FILE``.

    Returns:
        The decoded JSON object. Returns ``{}`` when the file does not exist,
        cannot be read or decoded, or does not contain a JSON object. The cache
        is optional, so none of these cases is fatal.
    """
    import logging

    if path is None:
        path = METADATA_FILE
    logger = logging.getLogger(__name__)
    try:
        with open(path, "r", encoding="utf-8") as fp:
            data = json.load(fp)
    except FileNotFoundError:
        # Absence is the normal case when the cache was never generated.
        return {}
    except (OSError, ValueError) as err:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # Keep the best-effort behaviour, but don't hide the problem.
        logger.warning("Ignoring unreadable metadata file %r: %s", path, err)
        return {}
    if not isinstance(data, dict):
        # Callers index this by model ID and call .get(); anything else breaks them.
        logger.warning(
            "Ignoring metadata file %r: expected a JSON object, got %s",
            path,
            type(data).__name__,
        )
        return {}
    return data
# Name of the environment variable holding the Space's own HF token. That
# token is used for requests from users who have not supplied their own.
FREE_TOKEN_ENV: str = "HF_TOKEN"
# How many requests an anonymous user may make with that shared token.
FREE_REQUESTS: int = 10
# -----------------------------------------------------------------------------
# Utility helpers
# -----------------------------------------------------------------------------
# Bounded cache: keys are user tokens, so an unbounded cache would keep a
# client (and the raw token) alive for every user who ever logged in during
# the Space's lifetime. The limit of 128 is arbitrary; tune it if needed.
@lru_cache(maxsize=128)
def get_client(token: str) -> InferenceClient:
    """Return an ``InferenceClient`` for *token* that uses the "fal-ai" provider.

    Clients are memoised per token, with least-recently-used eviction, so
    repeated requests with the same token reuse one instance.
    """
    return InferenceClient(provider="fal-ai", api_key=token)
# First markdown image: capture only the URL/path, stopping at whitespace so
# an optional title (`![alt](url "title")`) is not included.
IMG_PATTERN = re.compile(r"!\[.*?\]\(\s*([^)\s]+)")
# "Trigger ...: phrase". The label part must stay on one line; `[^:]*` would
# cross newlines and pair a "trigger" word with some unrelated later colon.
TRIGGER_PATTERN = re.compile(r"\b[Tt]rigger[^:\n]*:\s*([^\n]+)")

# Markdown emphasis/code markers and quotes that often wrap the trigger phrase.
_TRIGGER_WRAPPERS = "*`\"' \t"


def parse_model_card(model_id: str, text: str) -> Tuple[Optional[str], Optional[str]]:
    """Extract ``(preview_image_url, trigger_phrase)`` from model-card markdown.

    Args:
        model_id: Hub repo ID, used to build absolute URLs for relative image paths.
        text: README contents.

    Returns:
        Either element is ``None`` when it cannot be found.
    """
    image_url: Optional[str] = None
    if (m := IMG_PATTERN.search(text)) is not None:
        img_path = m.group(1).strip()
        if img_path.startswith(("http://", "https://")):
            image_url = img_path
        else:
            # Relative path inside the repo. Drop a leading "./" or "/" only;
            # str.lstrip("./") would also eat the dot of ".hidden/x.png".
            rel_path = img_path.removeprefix("./").lstrip("/")
            image_url = f"https://huggingface.co/{model_id}/resolve/main/{rel_path}"

    trigger_phrase: Optional[str] = None
    if (m := TRIGGER_PATTERN.search(text)) is not None:
        # e.g. "** `Convert to watercolor`" -> "Convert to watercolor".
        # A phrase that is only markup (such as a code fence) becomes None.
        trigger_phrase = m.group(1).strip().strip(_TRIGGER_WRAPPERS) or None
    return image_url, trigger_phrase


@lru_cache(maxsize=256)
def _read_model_card(model_id: str) -> Tuple[Optional[str], Optional[str]]:
    """Download *model_id*'s README.md and parse it; raises on any failure.

    Raising rather than returning ``(None, None)`` matters: ``lru_cache`` does
    not store exceptions, so a transient Hub error is retried on the next call
    instead of being cached for the life of the process.
    """
    readme_path = hf_hub_download(repo_id=model_id, filename="README.md")
    with open(readme_path, "r", encoding="utf-8") as fp:
        return parse_model_card(model_id, fp.read())


def fetch_preview_and_trigger(model_id: str) -> Tuple[Optional[str], Optional[str]]:
    """Try to fetch a preview image URL and trigger phrase from the model card.

    Successful look-ups are cached. If the README cannot be downloaded or
    read, returns ``(None, None)``; such failures are not cached.
    """
    try:
        return _read_model_card(model_id)
    except Exception:
        # Best effort: a missing preview must never break the UI.
        return None, None
# -----------------------------------------------------------------------------
# Core inference function
# -----------------------------------------------------------------------------
def _resolve_api_token(token: str | None, req_count: int) -> str:
    """Pick the HF token for a request, enforcing the anonymous free quota.

    Returns the user's own token when one is set. Otherwise returns the token
    from the ``FREE_TOKEN_ENV`` environment variable, provided ``req_count``
    (free requests already made) is below ``FREE_REQUESTS``.

    Raises:
        gr.Error: if no free token is configured or the free quota is used up.
    """
    if token:
        return token
    free_token = os.getenv(FREE_TOKEN_ENV)
    # Treat an empty variable like a missing one: "" is not a usable API key.
    if not free_token:
        raise gr.Error("Service not configured for free usage. Please login.")
    if req_count >= FREE_REQUESTS:
        raise gr.Error("Free quota exceeded – please login with your own HF account to continue.")
    return free_token


def _image_to_bytes(input_image: Any) -> bytes:
    """Encode *input_image* for InferenceClient.

    PIL images are encoded as PNG bytes; raw bytes pass through unchanged.

    Raises:
        gr.Error: for any other input type.
    """
    import io

    if isinstance(input_image, bytes):
        return input_image
    # Probe for ``save`` (the method actually called) rather than ``tobytes``:
    # numpy arrays also expose ``tobytes`` but have no ``save``.
    if hasattr(input_image, "save"):
        buf = io.BytesIO()
        input_image.save(buf, format="PNG")
        return buf.getvalue()
    raise gr.Error("Unsupported image format.")


def run_lora(
    input_image,  # bytes or PIL.Image
    prompt: str,
    model_id: str,
    guidance_scale: float,
    token: str | None,
    req_count: int,
):
    """Execute image → image generation via the selected LoRA.

    Args:
        input_image: Source image, as a PIL image or raw bytes.
        prompt: Text instruction sent along with the image.
        model_id: Hub repo ID of the LoRA to run.
        guidance_scale: Passed straight through to ``image_to_image``.
        token: The user's HF token; empty/None means use the shared free quota.
        req_count: Number of free-quota requests already made.

    Returns:
        ``(output_image, new_req_count, status_markdown)``. The count only
        increases when the shared free token was used.

    Raises:
        gr.Error: on missing or unsupported input, an unconfigured or exhausted
            free quota, or a failed inference call.
    """
    if input_image is None:
        raise gr.Error("Please provide an input image.")

    api_token = _resolve_api_token(token, req_count)
    client = get_client(api_token)
    img_bytes = _image_to_bytes(input_image)

    try:
        output = client.image_to_image(
            img_bytes,
            prompt=prompt,
            model=model_id,
            guidance_scale=guidance_scale,
        )
    except Exception as err:
        # Report provider/network failures as a user-visible gr.Error. The quota
        # counter is left untouched because nothing was generated.
        raise gr.Error(f"Generation failed: {err}") from err

    if token:
        return output, req_count, "Logged in ✅ Unlimited"
    # Only requests paid for by the shared free token count against the quota.
    new_count = req_count + 1
    return output, new_count, f"Free requests remaining: {max(0, FREE_REQUESTS - new_count)}"
# -----------------------------------------------------------------------------
# UI assembly
# -----------------------------------------------------------------------------
def build_interface():
    """Assemble the Gradio Blocks UI and return it without launching it.

    The UI contains:
    - login or token controls;
    - a LoRA picker with its sample image and suggested trigger phrase;
    - input image, prompt and guidance controls;
    - the generated output with a quota/status line;
    - a gallery of example outputs, where clicking an item selects that LoRA.
    """
    # Pre-load metadata once; models present in it are previewed without a Hub call.
    metadata_cache = _load_metadata()

    # Status strings shared by both authentication paths.
    not_logged_in_msg = f"Not logged in – using free quota (max {FREE_REQUESTS})"
    logged_in_msg = "Logged in ✅ Unlimited"

    # Theme & CSS
    theme = gr.themes.Soft(primary_hue="violet", secondary_hue="indigo")
    custom_css = """
    .gradio-container {max-width: 980px; margin: auto;}
    .gallery-item {border-radius: 8px; overflow: hidden;}
    """

    with gr.Blocks(title="Kontext-Dev LoRA Playground", theme=theme, css=custom_css) as demo:
        # Per-session state: the user's HF token ("" = anonymous) and the
        # number of free-quota requests made so far. Because this lives in
        # session state, the quota resets with a new session.
        token_state = gr.State(value="")
        request_count_state = gr.State(value=0)

        # --- Authentication UI -------------------------------------------
        if hasattr(gr, "LoginButton"):
            login_btn = gr.LoginButton()
            token_status = gr.Markdown(value=not_logged_in_msg)

            def _handle_login(login_data: Any):
                """Extract HF token from login payload returned by LoginButton."""
                token: str = ""
                if isinstance(login_data, dict):
                    token = login_data.get("access_token") or login_data.get("token") or ""
                elif isinstance(login_data, str):
                    token = login_data
                return token, (logged_in_msg if token else not_logged_in_msg)

            # NOTE(review): Gradio's documented OAuth flow hands the token to an
            # event handler through a `gr.OAuthToken | None` parameter. Confirm
            # that the installed Gradio's LoginButton really has a `.login` event.
            login_btn.login(_handle_login, outputs=[token_state, token_status])
        else:
            # Fallback manual token input if LoginButton not available (local dev)
            with gr.Accordion("🔑 Paste your HF token (optional)", open=False):
                token_input = gr.Textbox(label="HF Token", type="password", placeholder="Paste your token here…")
                save_token_btn = gr.Button("Save token")
            token_status = gr.Markdown(value=not_logged_in_msg)

            def _save_token(tok):
                """Store the pasted token (whitespace-stripped) and update the status line."""
                tok = (tok or "").strip()
                return tok, (logged_in_msg if tok else not_logged_in_msg)

            save_token_btn.click(_save_token, inputs=token_input, outputs=[token_state, token_status])

        gr.Markdown(
            "# Kontext-Dev LoRA Playground\n"
            "Select one of the available LoRAs from the dropdown, upload an image, "
            "tweak the prompt, and generate!"
        )

        with gr.Row():
            # LEFT column – model selection + preview
            with gr.Column(scale=1):
                model_dropdown = gr.Dropdown(
                    choices=LORA_MODELS,
                    value=LORA_MODELS[0],
                    label="Select LoRA model",
                )
                preview_image = gr.Image(label="Sample image", interactive=False, height=256)
                trigger_text = gr.Textbox(
                    label="Trigger phrase (suggested)",
                    interactive=False,
                )

            # RIGHT column – user inputs
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="Input image",
                    type="pil",
                )
                prompt_box = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe your transformation…",
                )
                guidance = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=2.5,
                    step=0.1,
                    label="Guidance scale",
                )
                generate_btn = gr.Button("🚀 Generate")
                output_image = gr.Image(label="Output", interactive=False)
                quota_display = gr.Markdown(value=f"Free requests remaining: {FREE_REQUESTS}")

        # Showcase Gallery --------------------------------------------------
        gr.Markdown("## ✨ Example outputs from selected LoRAs")
        example_gallery = gr.Gallery(
            label="Examples",
            columns=[4],
            height="auto",
            elem_id="example_gallery",
        )
        # Mirrors the gallery's [image_url, model_id] items so a click index
        # can be mapped back to a model.
        gallery_data_state = gr.State([])

        # ------------------------------------------------------------------
        # Callbacks
        # ------------------------------------------------------------------
        def _update_preview(model_id):
            """Show the LoRA's sample image and trigger phrase, and seed the prompt with it."""
            if model_id in metadata_cache:
                entry = metadata_cache[model_id]
                img_url = entry.get("image_url")
                trig = entry.get("trigger_phrase")
            else:
                img_url, trig = fetch_preview_and_trigger(model_id)
            return {
                preview_image: gr.Image(value=img_url or None),
                trigger_text: gr.Textbox(value=trig if trig is not None else "(no trigger phrase provided)"),
                # Only pre-fill the prompt with a real trigger phrase. The
                # placeholder text above would otherwise be sent to the model.
                prompt_box: gr.Textbox(value=trig or ""),
            }

        model_dropdown.change(_update_preview, inputs=model_dropdown, outputs=[preview_image, trigger_text, prompt_box])

        generate_btn.click(
            fn=run_lora,
            inputs=[input_image, prompt_box, model_dropdown, guidance, token_state, request_count_state],
            outputs=[output_image, request_count_state, quota_display],
        )

        def _load_gallery():
            """Pick up to 12 random cached [image_url, model_id] samples for the gallery."""
            samples = []
            for model_id in LORA_MODELS:
                info = metadata_cache.get(model_id)
                if info and info.get("image_url"):
                    samples.append([info["image_url"], model_id])
            # shuffle and take first 12
            random.shuffle(samples)
            return samples[:12], samples[:12]

        # Initialise preview and gallery on launch
        demo.load(_update_preview, inputs=model_dropdown, outputs=[preview_image, trigger_text, prompt_box])
        demo.load(fn=_load_gallery, inputs=None, outputs=[example_gallery, gallery_data_state])

        def _on_gallery_select(evt: gr.SelectData, data):
            """Select the LoRA whose example image was clicked.

            Uses gr.update(): Component.update() no longer exists in Gradio 4.
            """
            idx = evt.index
            if idx is None or idx >= len(data):
                return gr.update()
            return gr.update(value=data[idx][1])

        example_gallery.select(_on_gallery_select, inputs=gallery_data_state, outputs=model_dropdown)

    return demo
def main():
    """Entry point: build the Gradio app and start serving it."""
    build_interface().launch()


if __name__ == "__main__":
    main()