import os
import re
import json
import random
from functools import lru_cache
from typing import List, Tuple, Optional, Any

import gradio as gr
from huggingface_hub import InferenceClient, hf_hub_download

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
# LoRAs in the "Kontext Dev LoRAs" collection.
# NOTE: We hard-code the list for now. If the collection grows you can simply
# append new model IDs here.
LORA_MODELS: List[str] = [
    # fal – original author
    "fal/Watercolor-Art-Kontext-Dev-LoRA",
    "fal/Pop-Art-Kontext-Dev-LoRA",
    "fal/Pencil-Drawing-Kontext-Dev-LoRA",
    "fal/Mosaic-Art-Kontext-Dev-LoRA",
    "fal/Minimalist-Art-Kontext-Dev-LoRA",
    "fal/Impressionist-Art-Kontext-Dev-LoRA",
    "fal/Gouache-Art-Kontext-Dev-LoRA",
    "fal/Expressive-Art-Kontext-Dev-LoRA",
    "fal/Cubist-Art-Kontext-Dev-LoRA",
    "fal/Collage-Art-Kontext-Dev-LoRA",
    "fal/Charcoal-Art-Kontext-Dev-LoRA",
    "fal/Acrylic-Art-Kontext-Dev-LoRA",
    "fal/Abstract-Art-Kontext-Dev-LoRA",
    "fal/Plushie-Kontext-Dev-LoRA",
    "fal/Youtube-Thumbnails-Kontext-Dev-LoRA",
    "fal/Broccoli-Hair-Kontext-Dev-LoRA",
    "fal/Wojak-Kontext-Dev-LoRA",
    "fal/3D-Game-Assets-Kontext-Dev-LoRA",
    "fal/Realism-Detailer-Kontext-Dev-LoRA",
    # community LoRAs
    "gokaygokay/Pencil-Drawing-Kontext-Dev-LoRA",
    "gokaygokay/Oil-Paint-Kontext-Dev-LoRA",
    "gokaygokay/Watercolor-Kontext-Dev-LoRA",
    "gokaygokay/Pastel-Flux-Kontext-Dev-LoRA",
    "gokaygokay/Low-Poly-Kontext-Dev-LoRA",
    "gokaygokay/Bronze-Sculpture-Kontext-Dev-LoRA",
    "gokaygokay/Marble-Sculpture-Kontext-Dev-LoRA",
    "gokaygokay/Light-Fix-Kontext-Dev-LoRA",
    "gokaygokay/Fuse-it-Kontext-Dev-LoRA",
    "ilkerzgi/Overlay-Kontext-Dev-LoRA",
]

# Optional metadata cache file. Generated by `generate_lora_metadata.py`.
METADATA_FILE = "lora_metadata.json"


def _load_metadata() -> dict:
    """Load cached preview/trigger data if the JSON file exists."""
    if os.path.exists(METADATA_FILE):
        try:
            with open(METADATA_FILE, "r", encoding="utf-8") as fp:
                return json.load(fp)
        except Exception:
            pass
    return {}
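
# Assumed shape of `lora_metadata.json`, inferred from how the cache is read in
# `_update_preview` below (the file itself is produced by generate_lora_metadata.py):
# {
#     "fal/Watercolor-Art-Kontext-Dev-LoRA": {
#         "image_url": "https://huggingface.co/.../preview.png",
#         "trigger_phrase": "..."
#     },
#     ...
# }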

# Token used for anonymous free quota
FREE_TOKEN_ENV = "HF_TOKEN"
FREE_REQUESTS = 10
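# Note: the count lives in a gr.State (`request_count_state` below), so the free
# quota is effectively tracked per browser session and resets on page reload.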

# -----------------------------------------------------------------------------
# Utility helpers
# -----------------------------------------------------------------------------
@lru_cache(maxsize=None)
def get_client(token: str) -> InferenceClient:
    """Return a cached InferenceClient instance for the supplied token."""
    return InferenceClient(provider="fal-ai", api_key=token)


IMG_PATTERN = re.compile(r"!\[.*?\]\((.*?)\)")
TRIGGER_PATTERN = re.compile(r"[Tt]rigger[^:]*:\s*([^\n]+)")
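# Illustrative matches (not taken from any particular model card):
#   IMG_PATTERN:     "![preview](images/sample.png)"      -> captures "images/sample.png"
#   TRIGGER_PATTERN: "Trigger phrase: <your phrase here>"  -> captures "<your phrase here>"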


def fetch_preview_and_trigger(model_id: str) -> Tuple[Optional[str], Optional[str]]:
    """Try to fetch a preview image URL and trigger phrase from the model card.

    If unsuccessful, returns (None, None).
    """
    try:
        # Download README.
        readme_path = hf_hub_download(repo_id=model_id, filename="README.md")
    except Exception:
        return None, None

    image_url: Optional[str] = None
    trigger_phrase: Optional[str] = None
    try:
        with open(readme_path, "r", encoding="utf-8") as fp:
            text = fp.read()
        # First image in markdown → preview
        if (m := IMG_PATTERN.search(text)) is not None:
            img_path = m.group(1)
            if img_path.startswith("http"):
                image_url = img_path
            else:
                image_url = f"https://huggingface.co/{model_id}/resolve/main/{img_path.lstrip('./')}"
        # Try to parse trigger phrase
        if (m := TRIGGER_PATTERN.search(text)) is not None:
            trigger_phrase = m.group(1).strip()
    except Exception:
        pass
    return image_url, trigger_phrase
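
# Illustrative usage (actual values depend on the model card contents):
#   fetch_preview_and_trigger("fal/Watercolor-Art-Kontext-Dev-LoRA")
#   -> ("https://huggingface.co/fal/Watercolor-Art-Kontext-Dev-LoRA/resolve/main/...", "...")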


# -----------------------------------------------------------------------------
# Core inference function
# -----------------------------------------------------------------------------
def run_lora(
    input_image,  # bytes or PIL.Image
    prompt: str,
    model_id: str,
    guidance_scale: float,
    token: str | None,
    req_count: int,
):
    """Execute image → image generation via the selected LoRA."""
    if input_image is None:
        raise gr.Error("Please provide an input image.")

    # Determine which token we will use.
    if token:
        api_token = token
    else:
        free_token = os.getenv(FREE_TOKEN_ENV)
        if free_token is None:
            raise gr.Error("Service not configured for free usage. Please log in.")
        if req_count >= FREE_REQUESTS:
            raise gr.Error("Free quota exceeded – please log in with your own HF account to continue.")
        api_token = free_token

    client = get_client(api_token)

    # Gradio delivers a PIL.Image by default; InferenceClient accepts raw bytes.
    if hasattr(input_image, "save"):  # PIL.Image
        import io

        buf = io.BytesIO()
        input_image.save(buf, format="PNG")
        img_bytes = buf.getvalue()
    elif isinstance(input_image, bytes):
        img_bytes = input_image
    else:
        raise gr.Error("Unsupported image format.")
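
    # The call below is routed to the fal-ai provider; image_to_image is
    # documented by huggingface_hub to return a PIL image, which the output
    # gr.Image component can display directly.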
    output = client.image_to_image(
        img_bytes,
        prompt=prompt,
        model=model_id,
        guidance_scale=guidance_scale,
    )

    # Update request count only if using the free token.
    new_count = req_count if token else req_count + 1
    if token:
        status = "Logged in ✅ Unlimited"
    else:
        status = f"Free requests remaining: {max(0, FREE_REQUESTS - new_count)}"
    return output, new_count, status


# -----------------------------------------------------------------------------
# UI assembly
# -----------------------------------------------------------------------------
def build_interface():
    # Pre-load metadata into closure for fast look-ups.
    metadata_cache = _load_metadata()

    # Theme & CSS
    theme = gr.themes.Soft(primary_hue="violet", secondary_hue="indigo")
    custom_css = """
    .gradio-container {max-width: 980px; margin: auto;}
    .gallery-item {border-radius: 8px; overflow: hidden;}
    """

    with gr.Blocks(title="Kontext-Dev LoRA Playground", theme=theme, css=custom_css) as demo:
        token_state = gr.State(value="")
        request_count_state = gr.State(value=0)

        # --- Authentication UI -------------------------------------------
        if hasattr(gr, "LoginButton"):
            login_btn = gr.LoginButton()
            token_status = gr.Markdown(value=f"Not logged in – using free quota (max {FREE_REQUESTS})")

            def _handle_login(login_data: Any):
                """Extract the HF token from the login payload returned by LoginButton."""
                token: str = ""
                if isinstance(login_data, dict):
                    token = login_data.get("access_token") or login_data.get("token") or ""
                elif isinstance(login_data, str):
                    token = login_data
                status = "Logged in ✅ Unlimited" if token else f"Not logged in – using free quota (max {FREE_REQUESTS})"
                return token, status

            login_btn.login(_handle_login, outputs=[token_state, token_status])
        else:
            # Fallback: manual token input if LoginButton is not available (local dev)
            with gr.Accordion("🔑 Paste your HF token (optional)", open=False):
                token_input = gr.Textbox(label="HF Token", type="password", placeholder="Paste your token here…")
                save_token_btn = gr.Button("Save token")
            token_status = gr.Markdown(value=f"Not logged in – using free quota (max {FREE_REQUESTS})")

            # Handlers to store the token
            def _save_token(tok):
                return tok or ""

            def _token_status(tok):
                return "Logged in ✅ Unlimited" if tok else f"Not logged in – using free quota (max {FREE_REQUESTS})"

            save_token_btn.click(_save_token, inputs=token_input, outputs=token_state)
            save_token_btn.click(_token_status, inputs=token_input, outputs=token_status)

        gr.Markdown(
            """
            # Kontext-Dev LoRA Playground
            Select one of the available LoRAs from the dropdown, upload an image, tweak the prompt, and generate!
            """
        )

        with gr.Row():
            # LEFT column – model selection + preview
            with gr.Column(scale=1):
                model_dropdown = gr.Dropdown(
                    choices=LORA_MODELS,
                    value=LORA_MODELS[0],
                    label="Select LoRA model",
                )
                preview_image = gr.Image(label="Sample image", interactive=False, height=256)
                trigger_text = gr.Textbox(
                    label="Trigger phrase (suggested)",
                    interactive=False,
                )

            # RIGHT column – user inputs
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="Input image",
                    type="pil",
                )
                prompt_box = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe your transformation…",
                )
                guidance = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=2.5,
                    step=0.1,
                    label="Guidance scale",
                )
                generate_btn = gr.Button("🚀 Generate")

        output_image = gr.Image(label="Output", interactive=False)
        quota_display = gr.Markdown(value=f"Free requests remaining: {FREE_REQUESTS}")

        # Showcase Gallery --------------------------------------------------
        gr.Markdown("## ✨ Example outputs from selected LoRAs")
        example_gallery = gr.Gallery(
            label="Examples",
            columns=[4],
            height="auto",
            elem_id="example_gallery",
        )
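        # The gallery's backing data is kept in its own gr.State because the
        # gallery select event only reports an index; this list lets us map a
        # click back to a model id (see _on_gallery_select below).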
        gallery_data_state = gr.State([])

        # ------------------------------------------------------------------
        # Callbacks
        # ------------------------------------------------------------------
        def _update_preview(model_id, _meta=metadata_cache):
            if model_id in _meta:
                img_url = _meta[model_id].get("image_url")
                trig = _meta[model_id].get("trigger_phrase")
            else:
                img_url, trig = fetch_preview_and_trigger(model_id)
            # Fallbacks
            if trig is None:
                trig = "(no trigger phrase provided)"
            return {
                preview_image: gr.Image(value=img_url) if img_url else gr.Image(value=None),
                trigger_text: gr.Textbox(value=trig),
                prompt_box: gr.Textbox(value=trig),
            }
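
        # Returning a dict keyed by the output components (as _update_preview
        # does above) lets a single callback update the preview image, trigger
        # text, and prompt box together.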

        model_dropdown.change(_update_preview, inputs=model_dropdown, outputs=[preview_image, trigger_text, prompt_box])

        generate_btn.click(
            fn=run_lora,
            inputs=[input_image, prompt_box, model_dropdown, guidance, token_state, request_count_state],
            outputs=[output_image, request_count_state, quota_display],
        )

        # Helper to populate gallery once on launch
        def _load_gallery(_meta=metadata_cache):
            samples = []
            for model_id in LORA_MODELS:
                info = _meta.get(model_id)
                if info and info.get("image_url"):
                    samples.append([info["image_url"], model_id])
            # shuffle and take first 12
            random.shuffle(samples)
            return samples[:12], samples[:12]
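
        # Each gallery entry is an [image, caption] pair; the caption holds the
        # model id so a click can be mapped back to the dropdown selection.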

        # Initialise preview and gallery on launch
        demo.load(_update_preview, inputs=model_dropdown, outputs=[preview_image, trigger_text, prompt_box])
        demo.load(fn=_load_gallery, inputs=None, outputs=[example_gallery, gallery_data_state])

        # Handle gallery click to update dropdown
        def _on_gallery_select(evt: gr.SelectData, data):
            idx = evt.index
            if idx is None or idx >= len(data):
                return gr.update()
            model_id = data[idx][1]
            return gr.update(value=model_id)

        example_gallery.select(_on_gallery_select, inputs=gallery_data_state, outputs=model_dropdown)

    return demo


def main():
    demo = build_interface()
    demo.launch()


if __name__ == "__main__":
    main()