Create app.py

#1
by reach-vb HF Staff - opened
Files changed (1) hide show
  1. app.py +339 -0
app.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import random
5
+ from functools import lru_cache
6
+ from typing import List, Tuple, Optional, Any
7
+
8
+ import gradio as gr
9
+ from huggingface_hub import InferenceClient, hf_hub_download
10
+
11
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------

# LoRAs in the "Kontext Dev LoRAs" collection.
# NOTE: We hard-code the list for now. If the collection grows you can simply
# append new model IDs here.
# The first entry is used as the dropdown's default selection in the UI.
LORA_MODELS: List[str] = [
    # fal – original author
    "fal/Watercolor-Art-Kontext-Dev-LoRA",
    "fal/Pop-Art-Kontext-Dev-LoRA",
    "fal/Pencil-Drawing-Kontext-Dev-LoRA",
    "fal/Mosaic-Art-Kontext-Dev-LoRA",
    "fal/Minimalist-Art-Kontext-Dev-LoRA",
    "fal/Impressionist-Art-Kontext-Dev-LoRA",
    "fal/Gouache-Art-Kontext-Dev-LoRA",
    "fal/Expressive-Art-Kontext-Dev-LoRA",
    "fal/Cubist-Art-Kontext-Dev-LoRA",
    "fal/Collage-Art-Kontext-Dev-LoRA",
    "fal/Charcoal-Art-Kontext-Dev-LoRA",
    "fal/Acrylic-Art-Kontext-Dev-LoRA",
    "fal/Abstract-Art-Kontext-Dev-LoRA",
    "fal/Plushie-Kontext-Dev-LoRA",
    "fal/Youtube-Thumbnails-Kontext-Dev-LoRA",
    "fal/Broccoli-Hair-Kontext-Dev-LoRA",
    "fal/Wojak-Kontext-Dev-LoRA",
    "fal/3D-Game-Assets-Kontext-Dev-LoRA",
    "fal/Realism-Detailer-Kontext-Dev-LoRA",
    # community LoRAs
    "gokaygokay/Pencil-Drawing-Kontext-Dev-LoRA",
    "gokaygokay/Oil-Paint-Kontext-Dev-LoRA",
    "gokaygokay/Watercolor-Kontext-Dev-LoRA",
    "gokaygokay/Pastel-Flux-Kontext-Dev-LoRA",
    "gokaygokay/Low-Poly-Kontext-Dev-LoRA",
    "gokaygokay/Bronze-Sculpture-Kontext-Dev-LoRA",
    "gokaygokay/Marble-Sculpture-Kontext-Dev-LoRA",
    "gokaygokay/Light-Fix-Kontext-Dev-LoRA",
    "gokaygokay/Fuse-it-Kontext-Dev-LoRA",
    "ilkerzgi/Overlay-Kontext-Dev-LoRA",
]

# Optional metadata cache file. Generated by `generate_lora_metadata.py`.
# Expected shape (as consumed below): model_id -> {"image_url": ..., "trigger_phrase": ...}.
# The app works without this file; it just falls back to fetching READMEs live.
METADATA_FILE = "lora_metadata.json"
54
+
55
+
56
def _load_metadata() -> dict:
    """Load cached preview/trigger data if the JSON file exists.

    Best-effort: returns an empty dict when the file is absent, unreadable,
    or contains invalid JSON, so the app can always fall back to live
    README fetching.
    """
    if os.path.exists(METADATA_FILE):
        try:
            with open(METADATA_FILE, "r", encoding="utf-8") as fp:
                return json.load(fp)
        except (OSError, ValueError):
            # OSError: I/O race between exists() and open(), permissions, etc.
            # ValueError: covers json.JSONDecodeError for a corrupt cache.
            # Anything else (e.g. a programming error) should surface.
            pass
    return {}
65
+
66
+
67
# Token used for anonymous free quota
FREE_TOKEN_ENV = "HF_TOKEN"  # name of the env var holding the shared fallback token
FREE_REQUESTS = 10  # max free generations per session (tracked in gr.State, see run_lora)
70
+
71
+ # -----------------------------------------------------------------------------
72
+ # Utility helpers
73
+ # -----------------------------------------------------------------------------
74
+
75
+
76
@lru_cache(maxsize=None)
def get_client(token: str) -> InferenceClient:
    """Return a fal-ai ``InferenceClient`` for *token*, reusing one per token.

    NOTE(review): the cache is unbounded, so one client object is kept alive
    per distinct token for the lifetime of the process.
    """
    client = InferenceClient(provider="fal-ai", api_key=token)
    return client
80
+
81
+
82
+ IMG_PATTERN = re.compile(r"!\[.*?\]\((.*?)\)")
83
+ TRIGGER_PATTERN = re.compile(r"[Tt]rigger[^:]*:\s*([^\n]+)")
84
+
85
+
86
+ @lru_cache(maxsize=None)
87
+ def fetch_preview_and_trigger(model_id: str) -> Tuple[Optional[str], Optional[str]]:
88
+ """Try to fetch a preview image URL and trigger phrase from the model card.
89
+
90
+ If unsuccessful, returns (None, None).
91
+ """
92
+ try:
93
+ # Download README.
94
+ readme_path = hf_hub_download(repo_id=model_id, filename="README.md")
95
+ except Exception:
96
+ return None, None
97
+
98
+ image_url: Optional[str] = None
99
+ trigger_phrase: Optional[str] = None
100
+
101
+ try:
102
+ with open(readme_path, "r", encoding="utf-8") as fp:
103
+ text = fp.read()
104
+ # First image in markdown → preview
105
+ if (m := IMG_PATTERN.search(text)) is not None:
106
+ img_path = m.group(1)
107
+ if img_path.startswith("http"):
108
+ image_url = img_path
109
+ else:
110
+ image_url = f"https://huggingface.co/{model_id}/resolve/main/{img_path.lstrip('./')}"
111
+ # Try to parse trigger phrase
112
+ if (m := TRIGGER_PATTERN.search(text)) is not None:
113
+ trigger_phrase = m.group(1).strip()
114
+ except Exception:
115
+ pass
116
+ return image_url, trigger_phrase
117
+
118
+
119
+ # -----------------------------------------------------------------------------
120
+ # Core inference function
121
+ # -----------------------------------------------------------------------------
122
+
123
def run_lora(
    input_image,  # PIL.Image (gradio default) or raw bytes
    prompt: str,
    model_id: str,
    guidance_scale: float,
    token: str | None,
    req_count: int,
):
    """Execute image → image generation via selected LoRA.

    Returns ``(output_image, new_request_count, quota_status_markdown)``.
    Raises ``gr.Error`` for missing input, a missing free-token config, or
    an exhausted free quota.
    """
    if input_image is None:
        raise gr.Error("Please provide an input image.")

    # Determine which token we will use: a user token means unlimited
    # requests; otherwise fall back to the shared free token with a quota.
    if token:
        api_token = token
    else:
        free_token = os.getenv(FREE_TOKEN_ENV)
        if free_token is None:
            raise gr.Error("Service not configured for free usage. Please login.")

        if req_count >= FREE_REQUESTS:
            raise gr.Error("Free quota exceeded – please login with your own HF account to continue.")

        api_token = free_token

    client = get_client(api_token)
    # Gradio delivers PIL.Image by default. InferenceClient accepts bytes.
    # Probe for the `save` method we actually call below (the old check on
    # `tobytes` matched objects like numpy arrays that lack `save`).
    if hasattr(input_image, "save"):
        import io

        buf = io.BytesIO()
        input_image.save(buf, format="PNG")
        img_bytes = buf.getvalue()
    elif isinstance(input_image, bytes):
        img_bytes = input_image
    else:
        raise gr.Error("Unsupported image format.")

    output = client.image_to_image(
        img_bytes,
        prompt=prompt,
        model=model_id,
        guidance_scale=guidance_scale,
    )
    # Only free-token usage counts against the quota.
    new_count = req_count if token else req_count + 1
    if token:
        status = "Logged in ✅ Unlimited"
    else:
        status = f"Free requests remaining: {max(0, FREE_REQUESTS - new_count)}"
    return output, new_count, status
169
+
170
+
171
+ # -----------------------------------------------------------------------------
172
+ # UI assembly
173
+ # -----------------------------------------------------------------------------
174
+
175
def build_interface() -> gr.Blocks:
    """Assemble and return the Gradio Blocks UI for the LoRA playground."""
    # Pre-load metadata into closure for fast look-ups.
    metadata_cache = _load_metadata()

    # Theme & CSS
    theme = gr.themes.Soft(primary_hue="violet", secondary_hue="indigo")
    custom_css = """
    .gradio-container {max-width: 980px; margin: auto;}
    .gallery-item {border-radius: 8px; overflow: hidden;}
    """

    with gr.Blocks(title="Kontext-Dev LoRA Playground", theme=theme, css=custom_css) as demo:
        # Per-session state: the user's HF token ("" = anonymous) and how many
        # free-quota requests this session has consumed.
        token_state = gr.State(value="")
        request_count_state = gr.State(value=0)

        # --- Authentication UI -------------------------------------------
        if hasattr(gr, "LoginButton"):
            login_btn = gr.LoginButton(label="🔐 Sign in with Hugging Face")
            token_status = gr.Markdown(value=f"Not logged in – using free quota (max {FREE_REQUESTS})")

            def _handle_login(login_data: Any):
                """Extract HF token from login payload returned by LoginButton."""
                # Accept either a dict payload or a bare token string.
                token: str = ""
                if isinstance(login_data, dict):
                    token = login_data.get("access_token") or login_data.get("token") or ""
                elif isinstance(login_data, str):
                    token = login_data
                status = "Logged in ✅ Unlimited" if token else f"Not logged in – using free quota (max {FREE_REQUESTS})"
                return token, status

            # NOTE(review): recent gradio versions implement LoginButton as an
            # OAuth redirect and do not expose a `.login` event carrying a
            # token payload — confirm this works on the deployed gradio version.
            login_btn.login(_handle_login, outputs=[token_state, token_status])
        else:
            # Fallback manual token input if LoginButton not available (local dev)
            with gr.Accordion("🔑 Paste your HF token (optional)", open=False):
                token_input = gr.Textbox(label="HF Token", type="password", placeholder="Paste your token here…")
                save_token_btn = gr.Button("Save token")
                token_status = gr.Markdown(value=f"Not logged in – using free quota (max {FREE_REQUESTS})")

            # Handlers to store token
            def _save_token(tok):
                # Normalise None to "" so token_state stays a string.
                return tok or ""

            def _token_status(tok):
                return "Logged in ✅ Unlimited" if tok else f"Not logged in – using free quota (max {FREE_REQUESTS})"

            # Two click handlers on the same button: one updates the state,
            # the other the status markdown.
            save_token_btn.click(_save_token, inputs=token_input, outputs=token_state)
            save_token_btn.click(_token_status, inputs=token_input, outputs=token_status)

        gr.Markdown(
            """
            # Kontext-Dev LoRA Playground
            Select one of the available LoRAs from the dropdown, upload an image, tweak the prompt, and generate!
            """
        )

        with gr.Row():
            # LEFT column – model selection + preview
            with gr.Column(scale=1):
                model_dropdown = gr.Dropdown(
                    choices=LORA_MODELS,
                    value=LORA_MODELS[0],
                    label="Select LoRA model",
                )
                preview_image = gr.Image(label="Sample image", interactive=False, height=256)
                trigger_text = gr.Textbox(
                    label="Trigger phrase (suggested)",
                    interactive=False,
                )

            # RIGHT column – user inputs
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="Input image",
                    type="pil",
                )
                prompt_box = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe your transformation…",
                )
                guidance = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=2.5,
                    step=0.1,
                    label="Guidance scale",
                )
                generate_btn = gr.Button("🚀 Generate")
                output_image = gr.Image(label="Output", interactive=False)
                quota_display = gr.Markdown(value=f"Free requests remaining: {FREE_REQUESTS}")

        # Showcase Gallery --------------------------------------------------

        gr.Markdown("## ✨ Example outputs from selected LoRAs")

        example_gallery = gr.Gallery(
            label="Examples",
            columns=[4],
            height="auto",
            elem_id="example_gallery",
        )

        # Holds the same [image_url, model_id] pairs shown in the gallery so
        # the select handler can map a clicked index back to a model id.
        gallery_data_state = gr.State([])

        # ------------------------------------------------------------------
        # Callbacks
        # ------------------------------------------------------------------

        def _update_preview(model_id, _meta=metadata_cache):
            """Refresh preview image, trigger phrase and prompt for a model.

            Prefers the local metadata cache; falls back to fetching the
            model card from the Hub.
            """
            if model_id in _meta:
                img_url = _meta[model_id].get("image_url")
                trig = _meta[model_id].get("trigger_phrase")
            else:
                img_url, trig = fetch_preview_and_trigger(model_id)
            # Fallbacks
            if trig is None:
                trig = "(no trigger phrase provided)"
            # Dict-style return keyed by component; the trigger phrase also
            # seeds the prompt box as a starting point.
            return {
                preview_image: gr.Image(value=img_url) if img_url else gr.Image(value=None),
                trigger_text: gr.Textbox(value=trig),
                prompt_box: gr.Textbox(value=trig),
            }

        model_dropdown.change(_update_preview, inputs=model_dropdown, outputs=[preview_image, trigger_text, prompt_box])

        generate_btn.click(
            fn=run_lora,
            inputs=[input_image, prompt_box, model_dropdown, guidance, token_state, request_count_state],
            outputs=[output_image, request_count_state, quota_display],
        )

        # Helper to populate gallery once on launch
        def _load_gallery(_meta=metadata_cache):
            """Collect cached preview images; return the same sample both for
            the gallery display and for the lookup state (they must match
            index-for-index for _on_gallery_select to work)."""
            samples = []
            for model_id in LORA_MODELS:
                info = _meta.get(model_id)
                if info and info.get("image_url"):
                    samples.append([info["image_url"], model_id])
            # shuffle and take first 12
            random.shuffle(samples)
            return samples[:12], samples[:12]

        # Initialise preview and gallery on launch
        demo.load(_update_preview, inputs=model_dropdown, outputs=[preview_image, trigger_text, prompt_box])
        demo.load(fn=_load_gallery, inputs=None, outputs=[example_gallery, gallery_data_state])

        # Handle gallery click to update dropdown
        def _on_gallery_select(evt: gr.SelectData, data):
            """Map a clicked gallery tile back to its model id."""
            idx = evt.index
            if idx is None or idx >= len(data):
                return gr.Dropdown.update()
            model_id = data[idx][1]
            # NOTE(review): gr.Dropdown.update() was removed in Gradio 4.x
            # (replaced by gr.update(...)) — confirm the pinned gradio version.
            return gr.Dropdown.update(value=model_id)

        example_gallery.select(_on_gallery_select, inputs=gallery_data_state, outputs=model_dropdown)

    return demo
331
+
332
+
333
def main():
    """Entry point: build the UI and start the Gradio server."""
    build_interface().launch()


if __name__ == "__main__":
    main()