Spaces:
Running
Running
Commit
·
9b0c0fa
1
Parent(s):
13af073
ONS2
Browse files
app.py
CHANGED
@@ -35,14 +35,12 @@ def load_model(hf_token):
|
|
35 |
|
36 |
try:
|
37 |
# Try different model versions from smallest to largest
|
38 |
-
# Prioritize instruction-tuned models
|
39 |
model_options = [
|
40 |
"google/gemma-2b-it",
|
41 |
"google/gemma-7b-it",
|
42 |
"google/gemma-2b",
|
43 |
"google/gemma-7b",
|
44 |
-
#
|
45 |
-
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
46 |
]
|
47 |
|
48 |
print(f"Attempting to load models with token starting with: {hf_token[:5]}...")
|
@@ -51,10 +49,8 @@ def load_model(hf_token):
|
|
51 |
try:
|
52 |
print(f"\n--- Attempting to load model: {model_name} ---")
|
53 |
is_gemma = "gemma" in model_name.lower()
|
54 |
-
|
55 |
-
current_token = hf_token if is_gemma else None # Only use token for Gemma models
|
56 |
|
57 |
-
# Load tokenizer
|
58 |
print("Loading tokenizer...")
|
59 |
global_tokenizer = AutoTokenizer.from_pretrained(
|
60 |
model_name,
|
@@ -62,13 +58,11 @@ def load_model(hf_token):
|
|
62 |
)
|
63 |
print("Tokenizer loaded successfully.")
|
64 |
|
65 |
-
# Load model
|
66 |
print(f"Loading model {model_name}...")
|
67 |
global_model = AutoModelForCausalLM.from_pretrained(
|
68 |
model_name,
|
69 |
-
# torch_dtype=torch.bfloat16, # Use bfloat16 for better performance/compatibility if available - fallback to float16 if needed
|
70 |
torch_dtype=torch.float16, # Using float16 for broader compatibility
|
71 |
-
device_map="auto",
|
72 |
token=current_token
|
73 |
)
|
74 |
print(f"Model {model_name} loaded successfully!")
|
@@ -76,34 +70,26 @@ def load_model(hf_token):
|
|
76 |
model_loaded = True
|
77 |
loaded_model_name = model_name
|
78 |
loaded_successfully = True
|
79 |
-
tabs_update = gr.Tabs.update(visible=True)
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
|
85 |
except ImportError as import_err:
|
86 |
-
|
87 |
-
|
88 |
-
continue # Try next model
|
89 |
except Exception as specific_e:
|
90 |
print(f"Failed to load {model_name}: {specific_e}")
|
91 |
-
|
92 |
-
if "401 Client Error" in str(specific_e) and is_gemma:
|
93 |
print("Authentication error likely. Check token and license agreement.")
|
94 |
-
# Don't immediately fail, try next model
|
95 |
-
elif "requires you to be logged in" in str(specific_e) and is_gemma:
|
96 |
-
print("Authentication error likely. Check token and license agreement.")
|
97 |
-
# Don't immediately fail, try next model
|
98 |
-
# Continue to the next model option
|
99 |
continue
|
100 |
|
101 |
-
# If loop finishes without loading
|
102 |
if not loaded_successfully:
|
103 |
model_loaded = False
|
104 |
loaded_model_name = "None"
|
105 |
print("Could not load any model version.")
|
106 |
-
return "❌ Could not load any model. Please check your token
|
107 |
|
108 |
except Exception as e:
|
109 |
model_loaded = False
|
@@ -111,16 +97,14 @@ def load_model(hf_token):
|
|
111 |
error_msg = str(e)
|
112 |
print(f"Error in load_model: {error_msg}")
|
113 |
traceback.print_exc()
|
114 |
-
|
115 |
if "401 Client Error" in error_msg or "requires you to be logged in" in error_msg :
|
116 |
-
return "❌ Authentication failed.
|
117 |
else:
|
118 |
-
return f"❌
|
119 |
|
120 |
|
121 |
def generate_prompt(task_type, **kwargs):
|
122 |
"""Generate appropriate prompts based on task type and parameters"""
|
123 |
-
# Using a dictionary-based approach for cleaner prompt generation
|
124 |
prompts = {
|
125 |
"creative": "Write a {style} about {topic}. Be creative and engaging.",
|
126 |
"informational": "Write an {format_type} about {topic}. Be clear, factual, and informative.",
|
@@ -138,24 +122,17 @@ def generate_prompt(task_type, **kwargs):
|
|
138 |
"classify": "Classify the following text into one of these categories: {categories}\n\nText: {text}\n\nCategory:",
|
139 |
"data_extract": "Extract the following data points ({data_points}) from the text below:\n\nText: {text}\n\nExtracted Data:",
|
140 |
}
|
141 |
-
|
142 |
prompt_template = prompts.get(task_type)
|
143 |
if prompt_template:
|
144 |
try:
|
145 |
-
# Prepare kwargs safely for formatting
|
146 |
-
# Find placeholders like {key}
|
147 |
keys_in_template = [k[1:-1] for k in prompt_template.split('{') if '}' in k for k in [k.split('}')[0]]]
|
148 |
-
final_kwargs = {key: kwargs.get(key, f"[{key}]") for key in keys_in_template}
|
149 |
-
|
150 |
-
# Add any extra kwargs provided that weren't in the template (e.g., for 'custom' type)
|
151 |
-
final_kwargs.update(kwargs)
|
152 |
-
|
153 |
return prompt_template.format(**final_kwargs)
|
154 |
except KeyError as e:
|
155 |
print(f"Warning: Missing key for prompt template '{task_type}': {e}")
|
156 |
-
return kwargs.get("prompt", f"Generate text based on: {kwargs}")
|
157 |
else:
|
158 |
-
# Fallback for custom or undefined task types
|
159 |
return kwargs.get("prompt", "Generate text based on the input.")
|
160 |
|
161 |
|
@@ -169,689 +146,435 @@ def generate_text(prompt, max_new_tokens=1024, temperature=0.7, top_p=0.9):
|
|
169 |
print(f"Prompt (start): {prompt[:150]}...")
|
170 |
|
171 |
if not model_loaded or global_model is None or global_tokenizer is None:
|
172 |
-
print("Model not loaded error.")
|
173 |
return "⚠️ Model not loaded. Please authenticate first."
|
174 |
-
|
175 |
if not prompt:
|
176 |
return "⚠️ Please enter a prompt or configure a task."
|
177 |
|
178 |
try:
|
179 |
-
|
180 |
-
# Simple check based on model name conventions
|
181 |
if loaded_model_name and ("it" in loaded_model_name.lower() or "instruct" in loaded_model_name.lower() or "chat" in loaded_model_name.lower()):
|
182 |
-
# Simple chat structure assumed by many instruction models
|
183 |
-
# Using Gemma's specific format if it's a Gemma IT model
|
184 |
if "gemma" in loaded_model_name.lower():
|
|
|
185 |
chat_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
|
|
|
|
|
|
|
186 |
else: # Generic instruction format
|
187 |
chat_prompt = f"User: {prompt}\nAssistant:"
|
188 |
-
else:
|
189 |
-
# Base models might not need specific turn indicators
|
190 |
-
chat_prompt = prompt
|
191 |
|
192 |
inputs = global_tokenizer(chat_prompt, return_tensors="pt", add_special_tokens=True).to(global_model.device)
|
193 |
input_length = inputs.input_ids.shape[1]
|
194 |
print(f"Input token length: {input_length}")
|
195 |
|
196 |
-
|
197 |
-
|
198 |
-
|
|
|
|
|
|
|
|
|
199 |
|
200 |
generation_args = {
|
201 |
"input_ids": inputs.input_ids,
|
202 |
-
"attention_mask": inputs.attention_mask,
|
203 |
"max_new_tokens": effective_max_new_tokens,
|
204 |
"do_sample": True,
|
205 |
-
"temperature": float(temperature),
|
206 |
-
"top_p": float(top_p),
|
207 |
-
"pad_token_id":
|
208 |
}
|
209 |
|
210 |
print(f"Generation args: {generation_args}")
|
211 |
|
212 |
-
|
213 |
-
with torch.no_grad(): # Disable gradient calculation for inference
|
214 |
outputs = global_model.generate(**generation_args)
|
215 |
|
216 |
-
# Decode only the newly generated tokens
|
217 |
generated_ids = outputs[0, input_length:]
|
218 |
generated_text = global_tokenizer.decode(generated_ids, skip_special_tokens=True)
|
219 |
|
220 |
print(f"Generated text length: {len(generated_text)}")
|
221 |
print(f"Generated text (start): {generated_text[:150]}...")
|
222 |
-
return generated_text.strip()
|
223 |
|
224 |
except Exception as e:
|
225 |
error_msg = str(e)
|
226 |
print(f"Generation error: {error_msg}")
|
227 |
-
print(f"Error type: {type(e)}")
|
228 |
traceback.print_exc()
|
229 |
-
# Check for common CUDA errors
|
230 |
if "CUDA out of memory" in error_msg:
|
231 |
-
return f"❌ Error: CUDA out of memory. Try reducing 'Max New Tokens' or using a smaller model
|
232 |
-
elif "probability tensor contains nan" in error_msg:
|
233 |
-
return f"❌ Error: Generation failed (
|
234 |
else:
|
235 |
-
return f"❌ Error during text generation: {error_msg}
|
|
|
|
|
236 |
|
237 |
-
# Create parameters UI component (reusable function)
|
238 |
def create_parameter_ui():
|
239 |
with gr.Accordion("✨ Generation Parameters", open=False):
|
240 |
with gr.Row():
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
maximum=2048, # Set a reasonable max limit
|
245 |
-
value=512, # Default to a moderate length
|
246 |
-
step=64,
|
247 |
-
label="Max New Tokens",
|
248 |
-
info="Max number of tokens to generate.",
|
249 |
-
elem_id="max_new_tokens_slider"
|
250 |
-
)
|
251 |
-
temperature = gr.Slider(
|
252 |
-
minimum=0.1, # Avoid 0 which disables sampling
|
253 |
-
maximum=1.5,
|
254 |
-
value=0.7,
|
255 |
-
step=0.1,
|
256 |
-
label="Temperature",
|
257 |
-
info="Controls randomness. Lower is focused, higher is diverse.",
|
258 |
-
elem_id="temperature_slider"
|
259 |
-
)
|
260 |
-
top_p = gr.Slider(
|
261 |
-
minimum=0.1,
|
262 |
-
maximum=1.0, # Can be 1.0
|
263 |
-
value=0.9,
|
264 |
-
step=0.05,
|
265 |
-
label="Top-P (Nucleus Sampling)",
|
266 |
-
info="Considers tokens with cumulative probability >= top_p.",
|
267 |
-
elem_id="top_p_slider"
|
268 |
-
)
|
269 |
return [max_new_tokens, temperature, top_p]
|
270 |
|
|
|
|
|
|
|
271 |
# --- Gradio Interface ---
|
272 |
-
# Use the soft theme for a clean look, allow light/dark switching
|
273 |
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True, title="Gemma Capabilities Demo") as demo:
|
274 |
|
275 |
# Header
|
276 |
gr.Markdown(
|
277 |
"""
|
278 |
-
<div style="text-align: center; margin-bottom: 20px;">
|
279 |
-
|
280 |
-
|
281 |
-
</h1>
|
282 |
-
<p style="font-size: 1.1em; color: #555;">
|
283 |
-
Explore the text generation capabilities of Google's Gemma models (or a fallback).
|
284 |
-
</p>
|
285 |
-
<p style="font-size: 0.9em; color: #777;">
|
286 |
-
Requires a Hugging Face token with access to Gemma models.
|
287 |
-
<a href="https://huggingface.co/google/gemma-7b-it" target="_blank">[Accept Gemma License Here]</a>
|
288 |
-
</p>
|
289 |
-
</div>
|
290 |
-
"""
|
291 |
)
|
292 |
|
293 |
-
# --- Authentication
|
294 |
-
#
|
295 |
-
|
296 |
-
gr.Markdown("### 🔑 Authentication") # Added heading inside group
|
297 |
with gr.Row():
|
298 |
with gr.Column(scale=4):
|
299 |
-
hf_token = gr.Textbox(
|
300 |
-
label="Hugging Face Token",
|
301 |
-
placeholder="Paste your HF token here (hf_...)",
|
302 |
-
type="password",
|
303 |
-
value=DEFAULT_HF_TOKEN,
|
304 |
-
info="Get your token from https://huggingface.co/settings/tokens",
|
305 |
-
elem_id="hf_token_input"
|
306 |
-
)
|
307 |
with gr.Column(scale=1, min_width=150):
|
308 |
-
auth_button = gr.Button("Load Model", variant="primary"
|
309 |
-
|
310 |
-
auth_status = gr.Markdown("ℹ️ Enter your Hugging Face token and click 'Load Model'. This might take a minute.", elem_id="auth_status")
|
311 |
-
# Add instructions on getting token inside the auth group
|
312 |
gr.Markdown(
|
313 |
-
""
|
314 |
-
|
315 |
-
1. Go to [Hugging Face Token Settings](https://huggingface.co/settings/tokens)
|
316 |
-
2. Create a new token with **read** access.
|
317 |
-
3. Ensure you've accepted the [Gemma model license](https://huggingface.co/google/gemma-7b-it) on the model page.
|
318 |
-
"""
|
319 |
)
|
320 |
|
321 |
-
|
322 |
-
# --- Main Content Tabs (Initially Hidden) ---
|
323 |
-
# Define the tabs variable here
|
324 |
with gr.Tabs(elem_id="main_tabs", visible=False) as tabs:
|
325 |
|
326 |
# --- Text Generation Tab ---
|
327 |
-
with gr.TabItem("📝 Creative & Informational"
|
328 |
-
with gr.Row(
|
329 |
-
# Input Column
|
330 |
with gr.Column(scale=1):
|
331 |
-
gr.Markdown("
|
332 |
-
text_gen_type = gr.Radio(
|
333 |
-
|
334 |
-
label="
|
335 |
-
value="
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
style = gr.Dropdown(["short story", "poem", "script", "song lyrics", "joke", "dialogue"], label="Style", value="short story", elem_id="creative_style")
|
342 |
-
creative_topic = gr.Textbox(label="Topic", placeholder="e.g., a lonely astronaut on Mars", value="a robot discovering music", elem_id="creative_topic", lines=2)
|
343 |
-
|
344 |
-
with gr.Group(visible=False, elem_id="info_options") as info_options:
|
345 |
-
format_type = gr.Dropdown(["article", "summary", "explanation", "report", "comparison"], label="Format", value="article", elem_id="info_format")
|
346 |
-
info_topic = gr.Textbox(label="Topic", placeholder="e.g., the basics of quantum physics", value="the impact of AI on healthcare", elem_id="info_topic", lines=2)
|
347 |
-
|
348 |
-
with gr.Group(visible=False, elem_id="custom_prompt_group") as custom_prompt_group:
|
349 |
-
custom_prompt = gr.Textbox(label="Custom Prompt", placeholder="Enter your full prompt here...", lines=5, elem_id="custom_prompt")
|
350 |
-
|
351 |
-
# Show/hide logic (using gr.update for better practice)
|
352 |
-
def update_text_gen_visibility(choice):
|
353 |
-
is_creative = choice == "Creative Writing"
|
354 |
-
is_info = choice == "Informational Writing"
|
355 |
-
is_custom = choice == "Custom Prompt"
|
356 |
-
return {
|
357 |
-
creative_options: gr.update(visible=is_creative),
|
358 |
-
info_options: gr.update(visible=is_info),
|
359 |
-
custom_prompt_group: gr.update(visible=is_custom)
|
360 |
-
}
|
361 |
-
text_gen_type.change(update_text_gen_visibility, inputs=text_gen_type, outputs=[creative_options, info_options, custom_prompt_group], queue=False)
|
362 |
-
|
363 |
-
# Parameters
|
364 |
text_gen_params = create_parameter_ui()
|
365 |
-
gr.Spacer
|
366 |
-
generate_text_btn = gr.Button("Generate Text", variant="primary"
|
367 |
-
|
368 |
-
# Output Column
|
369 |
with gr.Column(scale=1):
|
370 |
-
gr.Markdown("
|
371 |
-
text_output = gr.Textbox(label="Result", lines=25, interactive=False,
|
372 |
-
|
373 |
-
#
|
374 |
-
def
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
elif task_type == "custom":
|
391 |
-
kwargs["prompt"] = safe_value(custom_prompt_text, "Write something interesting.")
|
392 |
-
|
393 |
-
|
394 |
final_prompt = generate_prompt(task_type, **kwargs)
|
395 |
-
return generate_text(final_prompt,
|
396 |
-
|
397 |
-
generate_text_btn.click(
|
398 |
-
text_generation_handler,
|
399 |
-
inputs=[text_gen_type, style, creative_topic, format_type, info_topic, custom_prompt, *text_gen_params],
|
400 |
-
outputs=text_output
|
401 |
-
)
|
402 |
|
403 |
# Examples
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
["Custom Prompt", "", "", "", "", "Write a short dialogue between a cat and a dog discussing their humans.", 512, 0.8, 0.95],
|
410 |
-
],
|
411 |
-
# Ensure the order matches the handler's inputs
|
412 |
-
inputs=[text_gen_type, style, creative_topic, format_type, info_topic, custom_prompt, *text_gen_params[:3]], # Pass only the UI elements needed
|
413 |
-
outputs=text_output,
|
414 |
-
label="Try these examples...",
|
415 |
-
#fn=text_generation_handler # fn is deprecated, click event handles execution
|
416 |
-
)
|
417 |
|
418 |
|
419 |
# --- Brainstorming Tab ---
|
420 |
-
with gr.TabItem("🧠 Brainstorming"
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
gr.
|
446 |
-
|
447 |
-
|
448 |
-
["business", "eco-friendly subscription boxes", 768, 0.75, 0.9],
|
449 |
-
["creative", "themes for a fantasy novel", 512, 0.85, 0.95],
|
450 |
-
],
|
451 |
-
inputs=[brainstorm_category, brainstorm_topic, *brainstorm_params[:3]],
|
452 |
-
outputs=brainstorm_output,
|
453 |
-
label="Try these examples...",
|
454 |
-
)
|
455 |
-
|
456 |
-
# --- Code Capabilities Tab ---
|
457 |
-
with gr.TabItem("💻 Code", id="tab_code"):
|
458 |
-
# Language mapping for syntax highlighting (defined once)
|
459 |
-
lang_map = {"Python": "python", "JavaScript": "javascript", "Java": "java", "C++": "cpp", "HTML": "html", "CSS": "css", "SQL": "sql", "Bash": "bash", "Rust": "rust", "Other": "plaintext"}
|
460 |
-
|
461 |
-
with gr.Tabs() as code_tabs:
|
462 |
-
# --- Code Generation ---
|
463 |
-
with gr.TabItem("Generate Code", id="subtab_code_gen"):
|
464 |
-
with gr.Row(equal_height=False):
|
465 |
-
# Input Column
|
466 |
with gr.Column(scale=1):
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
# Output Column
|
475 |
with gr.Column(scale=1):
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
result = generate_text(prompt, max_tokens, temp, top_p_val)
|
485 |
-
# Try to extract code block if markdown is used
|
486 |
-
if "```" in result:
|
487 |
parts = result.split("```")
|
488 |
if len(parts) >= 2:
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
# Update output language display based on dropdown
|
506 |
-
def update_code_language_display(lang):
|
507 |
-
return gr.Code.update(language=lang_map.get(lang, "plaintext")) # Use update method
|
508 |
-
|
509 |
-
code_language_gen.change(update_code_language_display, inputs=code_language_gen, outputs=code_output, queue=False)
|
510 |
-
code_gen_btn.click(code_gen_handler, inputs=[code_language_gen, code_task, *code_gen_params], outputs=code_output)
|
511 |
-
|
512 |
-
gr.Examples(
|
513 |
-
examples=[
|
514 |
-
["JavaScript", "function to validate an email address using regex", 768, 0.6, 0.9],
|
515 |
-
["SQL", "query to select users older than 30 from a 'users' table", 512, 0.5, 0.8],
|
516 |
-
["HTML", "basic structure for a personal portfolio website", 1024, 0.7, 0.9],
|
517 |
-
],
|
518 |
-
inputs=[code_language_gen, code_task, *code_gen_params[:3]],
|
519 |
-
outputs=code_output,
|
520 |
-
label="Try these examples...",
|
521 |
-
)
|
522 |
-
|
523 |
-
# --- Code Explanation ---
|
524 |
-
with gr.TabItem("Explain Code", id="subtab_code_explain"):
|
525 |
-
with gr.Row(equal_height=False):
|
526 |
-
# Input Column
|
527 |
-
with gr.Column(scale=1):
|
528 |
-
gr.Markdown("### Code Explanation Setup")
|
529 |
-
code_language_explain = gr.Dropdown(list(lang_map.keys()), label="Code Language (for context)", value="Python", elem_id="code_language_explain")
|
530 |
-
code_to_explain = gr.Code(label="Paste Code Here", language="python", lines=15, elem_id="code_to_explain")
|
531 |
-
explain_code_params = create_parameter_ui()
|
532 |
-
gr.Spacer(height=15)
|
533 |
-
explain_code_btn = gr.Button("Explain Code", variant="primary", elem_id="explain_code_btn")
|
534 |
-
|
535 |
-
# Output Column
|
536 |
-
with gr.Column(scale=1):
|
537 |
-
gr.Markdown("### Explanation")
|
538 |
-
code_explanation = gr.Textbox(label="Result", lines=25, interactive=False, elem_id="code_explanation", show_copy_button=True)
|
539 |
-
|
540 |
-
# Update code input language display
|
541 |
-
def update_explain_language_display(lang):
|
542 |
-
return gr.Code.update(language=lang_map.get(lang, "plaintext"))
|
543 |
-
code_language_explain.change(update_explain_language_display, inputs=code_language_explain, outputs=code_to_explain, queue=False)
|
544 |
-
|
545 |
-
# Handler
|
546 |
-
def explain_code_handler(language, code, max_tokens, temp, top_p_val):
|
547 |
-
code_content = safe_value(code['code'] if isinstance(code, dict) else code, "# Add code here") # Handle potential dict input from gr.Code
|
548 |
-
language = safe_value(language, "code") # Use selected language in prompt
|
549 |
-
prompt = generate_prompt("code_explain", language=language, code=code_content)
|
550 |
-
return generate_text(prompt, max_tokens, temp, top_p_val)
|
551 |
-
|
552 |
-
explain_code_btn.click(explain_code_handler, inputs=[code_language_explain, code_to_explain, *explain_code_params], outputs=code_explanation)
|
553 |
-
|
554 |
-
# --- Code Debugging ---
|
555 |
-
with gr.TabItem("Debug Code", id="subtab_code_debug"):
|
556 |
-
with gr.Row(equal_height=False):
|
557 |
-
# Input Column
|
558 |
with gr.Column(scale=1):
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
value="def calculate_average(numbers):\n sum = 0\n for n in numbers:\n sum += n\n # Bug: potential division by zero if numbers is empty\n return sum / len(numbers)", # Example with potential bug
|
566 |
-
elem_id="code_to_debug"
|
567 |
-
)
|
568 |
-
debug_code_params = create_parameter_ui()
|
569 |
-
gr.Spacer(height=15)
|
570 |
-
debug_code_btn = gr.Button("Debug Code", variant="primary", elem_id="debug_code_btn")
|
571 |
-
|
572 |
-
# Output Column
|
573 |
with gr.Column(scale=1):
|
574 |
-
|
575 |
-
|
576 |
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
|
|
|
|
|
|
581 |
|
582 |
-
# Handler
|
583 |
-
def debug_code_handler(language, code, max_tokens, temp, top_p_val):
|
584 |
-
code_content = safe_value(code['code'] if isinstance(code, dict) else code, "# Add potentially buggy code here")
|
585 |
-
language = safe_value(language, "code")
|
586 |
-
prompt = generate_prompt("code_debug", language=language, code=code_content)
|
587 |
-
return generate_text(prompt, max_tokens, temp, top_p_val)
|
588 |
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
600 |
with gr.Column(scale=1):
|
601 |
-
gr.Markdown("
|
602 |
-
summarize_text = gr.Textbox(label="Text to Summarize", placeholder="Paste long text
|
603 |
summarize_params = create_parameter_ui()
|
604 |
-
gr.Spacer
|
605 |
-
summarize_btn = gr.Button("Summarize Text", variant="primary"
|
606 |
-
# Output Column
|
607 |
with gr.Column(scale=1):
|
608 |
-
gr.Markdown("
|
609 |
-
summary_output = gr.Textbox(label="Result", lines=15, interactive=False,
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
# --- Question Answering ---
|
622 |
-
with gr.TabItem("Q & A", id="subtab_qa"):
|
623 |
-
with gr.Row(equal_height=False):
|
624 |
-
# Input Column
|
625 |
with gr.Column(scale=1):
|
626 |
-
gr.Markdown("
|
627 |
-
qa_text = gr.Textbox(label="Context Text", placeholder="Paste
|
628 |
-
qa_question = gr.Textbox(label="Question", placeholder="Ask
|
629 |
qa_params = create_parameter_ui()
|
630 |
-
gr.Spacer
|
631 |
-
qa_btn = gr.Button("Get Answer", variant="primary"
|
632 |
-
# Output Column
|
633 |
with gr.Column(scale=1):
|
634 |
-
gr.Markdown("
|
635 |
-
qa_output = gr.Textbox(label="Result", lines=10, interactive=False,
|
636 |
-
|
637 |
-
|
638 |
-
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
qa_btn.click(qa_handler, inputs=[qa_text, qa_question, *qa_params], outputs=qa_output)
|
647 |
-
|
648 |
-
# --- Translation ---
|
649 |
-
with gr.TabItem("Translate", id="subtab_translate"):
|
650 |
-
with gr.Row(equal_height=False):
|
651 |
-
# Input Column
|
652 |
with gr.Column(scale=1):
|
653 |
-
gr.Markdown("
|
654 |
-
translate_text = gr.Textbox(label="Text to Translate", placeholder="Enter text
|
655 |
-
target_lang = gr.Dropdown(
|
656 |
-
["French", "Spanish", "German", "Japanese", "Chinese", "Russian", "Arabic", "Hindi", "Portuguese", "Italian"],
|
657 |
-
label="Translate To", value="French", elem_id="target_lang"
|
658 |
-
)
|
659 |
translate_params = create_parameter_ui()
|
660 |
-
gr.Spacer
|
661 |
-
translate_btn = gr.Button("Translate Text", variant="primary"
|
662 |
-
# Output Column
|
663 |
with gr.Column(scale=1):
|
664 |
-
gr.Markdown("
|
665 |
-
translation_output = gr.Textbox(label="Result", lines=8, interactive=False,
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
685 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
686 |
with gr.Column(scale=1):
|
687 |
-
gr.Markdown("
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
gr.
|
693 |
-
content_btn = gr.Button("Generate Content", variant="primary", elem_id="content_btn")
|
694 |
with gr.Column(scale=1):
|
695 |
-
gr.Markdown("
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
710 |
with gr.Column(scale=1):
|
711 |
-
gr.Markdown("
|
712 |
-
|
713 |
-
|
714 |
-
|
715 |
-
gr.Spacer
|
716 |
-
|
717 |
with gr.Column(scale=1):
|
718 |
-
gr.Markdown("
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
-
|
729 |
-
|
730 |
-
with gr.TabItem("Document Editing", id="tab_edit"):
|
731 |
-
with gr.Row(equal_height=False):
|
732 |
with gr.Column(scale=1):
|
733 |
-
gr.Markdown("
|
734 |
-
|
735 |
-
|
736 |
-
|
737 |
-
gr.Spacer
|
738 |
-
|
739 |
with gr.Column(scale=1):
|
740 |
-
gr.Markdown("
|
741 |
-
|
742 |
-
|
743 |
-
|
744 |
-
|
745 |
-
|
746 |
-
prompt = generate_prompt("document_edit", text=text, edit_type=e_type)
|
747 |
-
# Editing might expand text, give it reasonable token count based on input + max_new
|
748 |
-
input_tokens_estimate = len(text.split()) # Rough estimate
|
749 |
-
max_tok = max(int(max_tok), input_tokens_estimate + 64) # Ensure enough room
|
750 |
-
return generate_text(prompt, max_tok, temp, top_p_val)
|
751 |
-
edit_btn.click(edit_handler, inputs=[edit_text, edit_type, *edit_params], outputs=edit_output)
|
752 |
-
|
753 |
-
|
754 |
-
# --- Classification ---
|
755 |
-
with gr.TabItem("Classification", id="tab_classify"):
|
756 |
-
with gr.Row(equal_height=False):
|
757 |
-
with gr.Column(scale=1):
|
758 |
-
gr.Markdown("### Classification Setup")
|
759 |
-
classify_text = gr.Textbox(label="Text to Classify", placeholder="Enter text...", lines=8, value="This new sci-fi movie explores themes of AI consciousness and interstellar travel.")
|
760 |
-
classify_categories = gr.Textbox(label="Categories (comma-separated)", placeholder="e.g., positive, negative, neutral", value="Technology, Entertainment, Science, Politics, Sports, Health")
|
761 |
-
classify_params = create_parameter_ui()
|
762 |
-
gr.Spacer(height=15)
|
763 |
-
classify_btn = gr.Button("Classify Text", variant="primary")
|
764 |
-
with gr.Column(scale=1):
|
765 |
-
gr.Markdown("### Classification Result")
|
766 |
-
classify_output = gr.Textbox(label="Predicted Category", lines=2, interactive=False, show_copy_button=True)
|
767 |
-
|
768 |
-
def classify_handler(text, cats, max_tok, temp, top_p_val):
|
769 |
-
text = safe_value(text, "Text to classify needed.")
|
770 |
-
cats = safe_value(cats, "category1, category2")
|
771 |
-
# Classification usually needs short output
|
772 |
-
max_tok = min(max(int(max_tok), 16), 128) # Ensure int, constrain tightly
|
773 |
-
prompt = generate_prompt("classify", text=text, categories=cats)
|
774 |
-
# Often the model just outputs the category, so we might not need the prompt structure removal
|
775 |
-
raw_output = generate_text(prompt, max_tok, temp, top_p_val)
|
776 |
-
# Post-process to get just the category if possible
|
777 |
-
lines = raw_output.split('\n')
|
778 |
-
if lines:
|
779 |
-
last_line = lines[-1].strip()
|
780 |
-
# Check if the last line seems like one of the categories
|
781 |
-
possible_cats = [c.strip().lower() for c in cats.split(',')]
|
782 |
-
if last_line.lower() in possible_cats:
|
783 |
-
return last_line
|
784 |
-
# Fallback to raw output
|
785 |
-
return raw_output
|
786 |
-
|
787 |
-
classify_btn.click(classify_handler, inputs=[classify_text, classify_categories, *classify_params], outputs=classify_output)
|
788 |
-
|
789 |
-
|
790 |
-
# --- Data Extraction ---
|
791 |
-
with gr.TabItem("Data Extraction", id="tab_extract"):
|
792 |
-
with gr.Row(equal_height=False):
|
793 |
-
with gr.Column(scale=1):
|
794 |
-
gr.Markdown("### Extraction Setup")
|
795 |
-
extract_text = gr.Textbox(label="Source Text", placeholder="Paste text containing data...", lines=10, value="Order #12345 placed on 2024-03-15 by Jane Doe (jane.d@email.com). Total amount: $99.95. Shipping to 123 Main St, Anytown, USA.")
|
796 |
-
extract_data_points = gr.Textbox(label="Data to Extract (comma-separated)", placeholder="e.g., name, email, order number", value="order number, date, customer name, email, total amount, address")
|
797 |
-
extract_params = create_parameter_ui()
|
798 |
-
gr.Spacer(height=15)
|
799 |
-
extract_btn = gr.Button("Extract Data", variant="primary")
|
800 |
-
with gr.Column(scale=1):
|
801 |
-
gr.Markdown("### Extracted Data")
|
802 |
-
extract_output = gr.Textbox(label="Result (e.g., JSON or key-value pairs)", lines=10, interactive=False, show_copy_button=True)
|
803 |
|
804 |
-
def extract_handler(text, points, max_tok, temp, top_p_val):
|
805 |
-
text = safe_value(text, "Provide text for extraction.")
|
806 |
-
points = safe_value(points, "key information")
|
807 |
-
prompt = generate_prompt("data_extract", text=text, data_points=points)
|
808 |
-
return generate_text(prompt, max_tok, temp, top_p_val)
|
809 |
-
extract_btn.click(extract_handler, inputs=[extract_text, extract_data_points, *extract_params], outputs=extract_output)
|
810 |
|
|
|
|
|
811 |
|
812 |
-
# Define authentication handler AFTER tabs is defined
|
813 |
def handle_auth(token):
|
814 |
-
|
815 |
-
yield "⏳ Authenticating and loading model... Please wait.", gr.Tabs.update(visible=False)
|
816 |
-
# Call the actual model loading function
|
817 |
status_message, tabs_update = load_model(token)
|
818 |
yield status_message, tabs_update
|
819 |
|
820 |
-
#
|
821 |
-
|
822 |
-
|
823 |
-
|
824 |
-
|
825 |
-
queue=True # Run in queue for potentially long operation
|
826 |
-
)
|
827 |
-
|
828 |
-
# --- Footer ---
|
829 |
-
footer_status = gr.Markdown( # Use a separate Markdown for dynamic updates
|
830 |
-
f"""
|
831 |
-
---
|
832 |
-
<div style="text-align: center; font-size: 0.9em; color: #777;">
|
833 |
-
<p>Powered by Google's Gemma models via Hugging Face 🤗 Transformers & Gradio.</p>
|
834 |
-
<p>Remember to review generated content. Model outputs may be inaccurate or incomplete.</p>
|
835 |
-
<p>Model Loaded: <strong>{loaded_model_name if model_loaded else 'None'}</strong></p>
|
836 |
-
</div>
|
837 |
-
"""
|
838 |
-
)
|
839 |
|
840 |
-
|
841 |
-
|
842 |
-
|
843 |
-
|
844 |
-
|
845 |
-
<div style="text-align: center; font-size: 0.9em; color: #777;">
|
846 |
-
<p>Powered by Google's Gemma models via Hugging Face 🤗 Transformers & Gradio.</p>
|
847 |
-
<p>Remember to review generated content. Model outputs may be inaccurate or incomplete.</p>
|
848 |
-
<p>Model Loaded: <strong>{loaded_model_name if model_loaded else 'None'}</strong></p>
|
849 |
-
</div>
|
850 |
-
""")
|
851 |
-
auth_status.change(fn=update_footer_status, inputs=auth_status, outputs=footer_status, queue=False)
|
852 |
|
853 |
|
854 |
# --- Launch App ---
|
855 |
-
# Allow built-in theme switching
|
856 |
-
# Use queue() to handle multiple requests better
|
857 |
demo.launch(share=False, allowed_themes=["light", "dark"])
|
|
|
35 |
|
36 |
try:
|
37 |
# Try different model versions from smallest to largest
|
|
|
38 |
model_options = [
|
39 |
"google/gemma-2b-it",
|
40 |
"google/gemma-7b-it",
|
41 |
"google/gemma-2b",
|
42 |
"google/gemma-7b",
|
43 |
+
"TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Fallback
|
|
|
44 |
]
|
45 |
|
46 |
print(f"Attempting to load models with token starting with: {hf_token[:5]}...")
|
|
|
49 |
try:
|
50 |
print(f"\n--- Attempting to load model: {model_name} ---")
|
51 |
is_gemma = "gemma" in model_name.lower()
|
52 |
+
current_token = hf_token if is_gemma else None
|
|
|
53 |
|
|
|
54 |
print("Loading tokenizer...")
|
55 |
global_tokenizer = AutoTokenizer.from_pretrained(
|
56 |
model_name,
|
|
|
58 |
)
|
59 |
print("Tokenizer loaded successfully.")
|
60 |
|
|
|
61 |
print(f"Loading model {model_name}...")
|
62 |
global_model = AutoModelForCausalLM.from_pretrained(
|
63 |
model_name,
|
|
|
64 |
torch_dtype=torch.float16, # Using float16 for broader compatibility
|
65 |
+
device_map="auto",
|
66 |
token=current_token
|
67 |
)
|
68 |
print(f"Model {model_name} loaded successfully!")
|
|
|
70 |
model_loaded = True
|
71 |
loaded_model_name = model_name
|
72 |
loaded_successfully = True
|
73 |
+
tabs_update = gr.Tabs.update(visible=True)
|
74 |
+
status_msg = f"✅ Model '{model_name}' loaded successfully!"
|
75 |
+
if "tinyllama" in model_name.lower():
|
76 |
+
status_msg = f"✅ Fallback model '{model_name}' loaded successfully! Limited capabilities compared to Gemma."
|
77 |
+
return status_msg, tabs_update
|
78 |
|
79 |
except ImportError as import_err:
|
80 |
+
print(f"Import Error loading {model_name}: {import_err}. Check dependencies (e.g., bitsandbytes, accelerate).")
|
81 |
+
continue
|
|
|
82 |
except Exception as specific_e:
|
83 |
print(f"Failed to load {model_name}: {specific_e}")
|
84 |
+
if "401 Client Error" in str(specific_e) or "requires you to be logged in" in str(specific_e) and is_gemma:
|
|
|
85 |
print("Authentication error likely. Check token and license agreement.")
|
|
|
|
|
|
|
|
|
|
|
86 |
continue
|
87 |
|
|
|
88 |
if not loaded_successfully:
|
89 |
model_loaded = False
|
90 |
loaded_model_name = "None"
|
91 |
print("Could not load any model version.")
|
92 |
+
return "❌ Could not load any model. Please check your token, license acceptance, dependencies, and network connection.", initial_tabs_update
|
93 |
|
94 |
except Exception as e:
|
95 |
model_loaded = False
|
|
|
97 |
error_msg = str(e)
|
98 |
print(f"Error in load_model: {error_msg}")
|
99 |
traceback.print_exc()
|
|
|
100 |
if "401 Client Error" in error_msg or "requires you to be logged in" in error_msg :
|
101 |
+
return "❌ Authentication failed. Check token/license.", initial_tabs_update
|
102 |
else:
|
103 |
+
return f"❌ Unexpected error during model loading: {error_msg}", initial_tabs_update
|
104 |
|
105 |
|
106 |
def generate_prompt(task_type, **kwargs):
|
107 |
"""Generate appropriate prompts based on task type and parameters"""
|
|
|
108 |
prompts = {
|
109 |
"creative": "Write a {style} about {topic}. Be creative and engaging.",
|
110 |
"informational": "Write an {format_type} about {topic}. Be clear, factual, and informative.",
|
|
|
122 |
"classify": "Classify the following text into one of these categories: {categories}\n\nText: {text}\n\nCategory:",
|
123 |
"data_extract": "Extract the following data points ({data_points}) from the text below:\n\nText: {text}\n\nExtracted Data:",
|
124 |
}
|
|
|
125 |
prompt_template = prompts.get(task_type)
|
126 |
if prompt_template:
|
127 |
try:
|
|
|
|
|
128 |
keys_in_template = [k[1:-1] for k in prompt_template.split('{') if '}' in k for k in [k.split('}')[0]]]
|
129 |
+
final_kwargs = {key: kwargs.get(key, f"[{key}]") for key in keys_in_template}
|
130 |
+
final_kwargs.update(kwargs) # Add extras
|
|
|
|
|
|
|
131 |
return prompt_template.format(**final_kwargs)
|
132 |
except KeyError as e:
|
133 |
print(f"Warning: Missing key for prompt template '{task_type}': {e}")
|
134 |
+
return kwargs.get("prompt", f"Generate text based on: {kwargs}")
|
135 |
else:
|
|
|
136 |
return kwargs.get("prompt", "Generate text based on the input.")
|
137 |
|
138 |
|
|
|
146 |
print(f"Prompt (start): {prompt[:150]}...")
|
147 |
|
148 |
if not model_loaded or global_model is None or global_tokenizer is None:
|
|
|
149 |
return "⚠️ Model not loaded. Please authenticate first."
|
|
|
150 |
if not prompt:
|
151 |
return "⚠️ Please enter a prompt or configure a task."
|
152 |
|
153 |
try:
|
154 |
+
chat_prompt = prompt # Default to raw prompt
|
|
|
155 |
if loaded_model_name and ("it" in loaded_model_name.lower() or "instruct" in loaded_model_name.lower() or "chat" in loaded_model_name.lower()):
|
|
|
|
|
156 |
if "gemma" in loaded_model_name.lower():
|
157 |
+
# Use Gemma's specific format
|
158 |
chat_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
|
159 |
+
elif "tinyllama" in loaded_model_name.lower():
|
160 |
+
# Use TinyLlama's chat format
|
161 |
+
chat_prompt = f"<|system|>\nYou are a friendly chatbot.</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"
|
162 |
else: # Generic instruction format
|
163 |
chat_prompt = f"User: {prompt}\nAssistant:"
|
|
|
|
|
|
|
164 |
|
165 |
inputs = global_tokenizer(chat_prompt, return_tensors="pt", add_special_tokens=True).to(global_model.device)
|
166 |
input_length = inputs.input_ids.shape[1]
|
167 |
print(f"Input token length: {input_length}")
|
168 |
|
169 |
+
effective_max_new_tokens = min(int(max_new_tokens), 2048)
|
170 |
+
|
171 |
+
# Handle potential None for eos_token_id
|
172 |
+
eos_token_id = global_tokenizer.eos_token_id
|
173 |
+
if eos_token_id is None:
|
174 |
+
print("Warning: eos_token_id is None, using default 50256.")
|
175 |
+
eos_token_id = 50256 # A common default EOS token ID
|
176 |
|
177 |
generation_args = {
|
178 |
"input_ids": inputs.input_ids,
|
179 |
+
"attention_mask": inputs.attention_mask,
|
180 |
"max_new_tokens": effective_max_new_tokens,
|
181 |
"do_sample": True,
|
182 |
+
"temperature": float(temperature),
|
183 |
+
"top_p": float(top_p),
|
184 |
+
"pad_token_id": eos_token_id # Use determined EOS or default
|
185 |
}
|
186 |
|
187 |
print(f"Generation args: {generation_args}")
|
188 |
|
189 |
+
with torch.no_grad():
|
|
|
190 |
outputs = global_model.generate(**generation_args)
|
191 |
|
|
|
192 |
generated_ids = outputs[0, input_length:]
|
193 |
generated_text = global_tokenizer.decode(generated_ids, skip_special_tokens=True)
|
194 |
|
195 |
print(f"Generated text length: {len(generated_text)}")
|
196 |
print(f"Generated text (start): {generated_text[:150]}...")
|
197 |
+
return generated_text.strip()
|
198 |
|
199 |
except Exception as e:
|
200 |
error_msg = str(e)
|
201 |
print(f"Generation error: {error_msg}")
|
|
|
202 |
traceback.print_exc()
|
|
|
203 |
if "CUDA out of memory" in error_msg:
|
204 |
+
return f"❌ Error: CUDA out of memory. Try reducing 'Max New Tokens' or using a smaller model."
|
205 |
+
elif "probability tensor contains nan" in error_msg or "invalid value encountered" in error_msg:
|
206 |
+
return f"❌ Error: Generation failed (invalid probability). Try adjusting Temperature/Top-P or modifying the prompt."
|
207 |
else:
|
208 |
+
return f"❌ Error during text generation: {error_msg}"
|
209 |
+
|
210 |
+
# --- UI Components & Layout ---
|
211 |
|
|
|
212 |
def create_parameter_ui():
|
213 |
with gr.Accordion("✨ Generation Parameters", open=False):
|
214 |
with gr.Row():
|
215 |
+
max_new_tokens = gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Max New Tokens", info="Max tokens to generate.")
|
216 |
+
temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature", info="Controls randomness.")
|
217 |
+
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P", info="Nucleus sampling probability.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
return [max_new_tokens, temperature, top_p]
|
219 |
|
220 |
+
# Language map (defined once)
|
221 |
+
lang_map = {"Python": "python", "JavaScript": "javascript", "Java": "java", "C++": "cpp", "HTML": "html", "CSS": "css", "SQL": "sql", "Bash": "bash", "Rust": "rust", "Other": "plaintext"}
|
222 |
+
|
223 |
# --- Gradio Interface ---
|
|
|
224 |
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True, title="Gemma Capabilities Demo") as demo:
|
225 |
|
226 |
# Header
|
227 |
gr.Markdown(
|
228 |
"""
|
229 |
+
<div style="text-align: center; margin-bottom: 20px;"><h1><span style="font-size: 1.5em;">🤖</span> Gemma Capabilities Demo</h1>
|
230 |
+
<p>Explore text generation with Google's Gemma models (or a fallback).</p>
|
231 |
+
<p style="font-size: 0.9em;"><a href="https://huggingface.co/google/gemma-7b-it" target="_blank">[Accept Gemma License Here]</a></p></div>"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
)
|
233 |
|
234 |
+
# --- Authentication ---
|
235 |
+
with gr.Group(): # Removed variant="panel"
|
236 |
+
gr.Markdown("### 🔑 Authentication")
|
|
|
237 |
with gr.Row():
|
238 |
with gr.Column(scale=4):
|
239 |
+
hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Paste token (hf_...)", type="password", value=DEFAULT_HF_TOKEN, info="Needed for Gemma models.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
with gr.Column(scale=1, min_width=150):
|
241 |
+
auth_button = gr.Button("Load Model", variant="primary")
|
242 |
+
auth_status = gr.Markdown("ℹ️ Enter token & click 'Load Model'. May take time.")
|
|
|
|
|
243 |
gr.Markdown(
|
244 |
+
"**Token Info:** Get from [HF Settings](https://huggingface.co/settings/tokens) (read access). Ensure Gemma license is accepted.",
|
245 |
+
elem_id="token-info" # Optional ID for styling if needed later
|
|
|
|
|
|
|
|
|
246 |
)
|
247 |
|
248 |
+
# --- Main Content Tabs ---
|
|
|
|
|
249 |
with gr.Tabs(elem_id="main_tabs", visible=False) as tabs:
|
250 |
|
251 |
# --- Text Generation Tab ---
|
252 |
+
with gr.TabItem("📝 Creative & Informational"):
|
253 |
+
with gr.Row():
|
|
|
254 |
with gr.Column(scale=1):
|
255 |
+
gr.Markdown("#### Configure Task")
|
256 |
+
text_gen_type = gr.Radio(["Creative Writing", "Informational Writing", "Custom Prompt"], label="Writing Type", value="Creative Writing")
|
257 |
+
with gr.Group(visible=True) as creative_options:
|
258 |
+
style = gr.Dropdown(["short story", "poem", "script", "song lyrics", "joke", "dialogue"], label="Style", value="short story")
|
259 |
+
creative_topic = gr.Textbox(label="Topic", placeholder="e.g., a lonely astronaut", value="a robot discovering music", lines=2)
|
260 |
+
with gr.Group(visible=False) as info_options:
|
261 |
+
format_type = gr.Dropdown(["article", "summary", "explanation", "report", "comparison"], label="Format", value="article")
|
262 |
+
info_topic = gr.Textbox(label="Topic", placeholder="e.g., quantum physics basics", value="AI impact on healthcare", lines=2)
|
263 |
+
with gr.Group(visible=False) as custom_prompt_group:
|
264 |
+
custom_prompt = gr.Textbox(label="Custom Prompt", placeholder="Enter full prompt...", lines=5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
text_gen_params = create_parameter_ui()
|
266 |
+
# Removed gr.Spacer
|
267 |
+
generate_text_btn = gr.Button("Generate Text", variant="primary")
|
|
|
|
|
268 |
with gr.Column(scale=1):
|
269 |
+
gr.Markdown("#### Generated Output")
|
270 |
+
text_output = gr.Textbox(label="Result", lines=25, interactive=False, show_copy_button=True)
|
271 |
+
|
272 |
+
# Visibility logic
|
273 |
+
def update_text_gen_visibility(choice):
|
274 |
+
return { creative_options: gr.update(visible=choice == "Creative Writing"),
|
275 |
+
info_options: gr.update(visible=choice == "Informational Writing"),
|
276 |
+
custom_prompt_group: gr.update(visible=choice == "Custom Prompt") }
|
277 |
+
text_gen_type.change(update_text_gen_visibility, text_gen_type, [creative_options, info_options, custom_prompt_group], queue=False)
|
278 |
+
|
279 |
+
# Click handler
|
280 |
+
def text_gen_click(gen_type, style, c_topic, fmt_type, i_topic, custom_pr, *params):
|
281 |
+
task_map = {"Creative Writing": ("creative", {"style": style, "topic": c_topic}),
|
282 |
+
"Informational Writing": ("informational", {"format_type": fmt_type, "topic": i_topic}),
|
283 |
+
"Custom Prompt": ("custom", {"prompt": custom_pr})}
|
284 |
+
task_type, kwargs = task_map.get(gen_type, ("custom", {"prompt": custom_pr}))
|
285 |
+
# Apply safe_value inside handler where needed
|
286 |
+
if task_type == "creative": kwargs = {"style": safe_value(style, "story"), "topic": safe_value(c_topic, "[topic]")}
|
287 |
+
elif task_type == "informational": kwargs = {"format_type": safe_value(fmt_type, "article"), "topic": safe_value(i_topic, "[topic]")}
|
288 |
+
else: kwargs = {"prompt": safe_value(custom_pr, "Write something.")}
|
|
|
|
|
|
|
|
|
289 |
final_prompt = generate_prompt(task_type, **kwargs)
|
290 |
+
return generate_text(final_prompt, *params)
|
291 |
+
generate_text_btn.click(text_gen_click, [text_gen_type, style, creative_topic, format_type, info_topic, custom_prompt, *text_gen_params], text_output)
|
|
|
|
|
|
|
|
|
|
|
292 |
|
293 |
# Examples
|
294 |
+
gr.Examples( examples=[ ["Creative Writing", "poem", "sound of rain", "", "", "", 512, 0.7, 0.9],
|
295 |
+
["Informational Writing", "", "", "explanation", "photosynthesis", "", 768, 0.6, 0.9],
|
296 |
+
["Custom Prompt", "", "", "", "", "Dialogue: cat and dog discuss humans.", 512, 0.8, 0.95] ],
|
297 |
+
inputs=[text_gen_type, style, creative_topic, format_type, info_topic, custom_prompt, *text_gen_params[:3]], # Pass UI elements
|
298 |
+
outputs=text_output, label="Try examples...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
299 |
|
300 |
|
301 |
# --- Brainstorming Tab ---
|
302 |
+
with gr.TabItem("🧠 Brainstorming"):
|
303 |
+
with gr.Row():
|
304 |
+
with gr.Column(scale=1):
|
305 |
+
gr.Markdown("#### Setup")
|
306 |
+
brainstorm_category = gr.Dropdown(["project", "business", "creative", "solution", "content", "feature", "product name"], label="Category", value="project")
|
307 |
+
brainstorm_topic = gr.Textbox(label="Topic/Problem", placeholder="e.g., reducing plastic waste", value="unique mobile app ideas", lines=3)
|
308 |
+
brainstorm_params = create_parameter_ui()
|
309 |
+
# Removed gr.Spacer
|
310 |
+
brainstorm_btn = gr.Button("Generate Ideas", variant="primary")
|
311 |
+
with gr.Column(scale=1):
|
312 |
+
gr.Markdown("#### Generated Ideas")
|
313 |
+
brainstorm_output = gr.Textbox(label="Result", lines=25, interactive=False, show_copy_button=True)
|
314 |
+
|
315 |
+
def brainstorm_click(category, topic, *params):
|
316 |
+
prompt = generate_prompt("brainstorm", category=safe_value(category, "project"), topic=safe_value(topic, "ideas"))
|
317 |
+
return generate_text(prompt, *params)
|
318 |
+
brainstorm_btn.click(brainstorm_click, [brainstorm_category, brainstorm_topic, *brainstorm_params], brainstorm_output)
|
319 |
+
gr.Examples([ ["solution", "engaging online learning", 768, 0.8, 0.9],
|
320 |
+
["business", "eco-friendly subscription boxes", 768, 0.75, 0.9],
|
321 |
+
["creative", "fantasy novel themes", 512, 0.85, 0.95] ],
|
322 |
+
inputs=[brainstorm_category, brainstorm_topic, *brainstorm_params[:3]], outputs=brainstorm_output, label="Try examples...")
|
323 |
+
|
324 |
+
|
325 |
+
# --- Code Tab ---
|
326 |
+
with gr.TabItem("💻 Code"):
|
327 |
+
with gr.Tabs():
|
328 |
+
with gr.TabItem("Generate"):
|
329 |
+
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
330 |
with gr.Column(scale=1):
|
331 |
+
gr.Markdown("#### Setup")
|
332 |
+
code_lang_gen = gr.Dropdown(list(lang_map.keys())[:-1], label="Language", value="Python")
|
333 |
+
code_task = gr.Textbox(label="Task", placeholder="e.g., function for factorial", value="Python class for calculator", lines=4)
|
334 |
+
code_gen_params = create_parameter_ui()
|
335 |
+
# Removed gr.Spacer
|
336 |
+
code_gen_btn = gr.Button("Generate Code", variant="primary")
|
|
|
|
|
337 |
with gr.Column(scale=1):
|
338 |
+
gr.Markdown("#### Generated Code")
|
339 |
+
code_output = gr.Code(label="Result", language="python", lines=25, interactive=False)
|
340 |
+
|
341 |
+
def gen_code_click(lang, task, *params):
|
342 |
+
prompt = generate_prompt("code_generate", language=safe_value(lang, "Python"), task=safe_value(task, "hello world"))
|
343 |
+
result = generate_text(prompt, *params)
|
344 |
+
# Basic code block extraction
|
345 |
+
if "```" in result:
|
|
|
|
|
|
|
346 |
parts = result.split("```")
|
347 |
if len(parts) >= 2:
|
348 |
+
block = parts[1]
|
349 |
+
if '\n' in block: first_line, rest = block.split('\n', 1); return rest.strip() if first_line.strip().lower() == lang.lower() else block.strip()
|
350 |
+
else: return block.strip()
|
351 |
+
return result.strip()
|
352 |
+
def update_gen_lang_display(lang): return gr.Code.update(language=lang_map.get(lang, "plaintext"))
|
353 |
+
code_lang_gen.change(update_gen_lang_display, code_lang_gen, code_output, queue=False)
|
354 |
+
code_gen_btn.click(gen_code_click, [code_lang_gen, code_task, *code_gen_params], code_output)
|
355 |
+
gr.Examples([ ["JavaScript", "email validation regex function", 768, 0.6, 0.9],
|
356 |
+
["SQL", "select users > 30 yrs old", 512, 0.5, 0.8],
|
357 |
+
["HTML", "basic portfolio structure", 1024, 0.7, 0.9] ],
|
358 |
+
inputs=[code_lang_gen, code_task, *code_gen_params[:3]], outputs=code_output, label="Try examples...")
|
359 |
+
|
360 |
+
with gr.TabItem("Explain"):
|
361 |
+
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
362 |
with gr.Column(scale=1):
|
363 |
+
gr.Markdown("#### Setup")
|
364 |
+
code_lang_explain = gr.Dropdown(list(lang_map.keys()), label="Language", value="Python")
|
365 |
+
code_to_explain = gr.Code(label="Code to Explain", language="python", lines=15)
|
366 |
+
explain_code_params = create_parameter_ui()
|
367 |
+
# Removed gr.Spacer
|
368 |
+
explain_code_btn = gr.Button("Explain Code", variant="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
369 |
with gr.Column(scale=1):
|
370 |
+
gr.Markdown("#### Explanation")
|
371 |
+
code_explanation = gr.Textbox(label="Result", lines=25, interactive=False, show_copy_button=True)
|
372 |
|
373 |
+
def explain_code_click(lang, code, *params):
|
374 |
+
code_content = safe_value(code['code'] if isinstance(code, dict) else code, "# Empty code")
|
375 |
+
prompt = generate_prompt("code_explain", language=safe_value(lang, "code"), code=code_content)
|
376 |
+
return generate_text(prompt, *params)
|
377 |
+
def update_explain_lang_display(lang): return gr.Code.update(language=lang_map.get(lang, "plaintext"))
|
378 |
+
code_lang_explain.change(update_explain_lang_display, code_lang_explain, code_to_explain, queue=False)
|
379 |
+
explain_code_btn.click(explain_code_click, [code_lang_explain, code_to_explain, *explain_code_params], code_explanation)
|
380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
381 |
|
382 |
+
with gr.TabItem("Debug"):
|
383 |
+
with gr.Row():
|
384 |
+
with gr.Column(scale=1):
|
385 |
+
gr.Markdown("#### Setup")
|
386 |
+
code_lang_debug = gr.Dropdown(list(lang_map.keys()), label="Language", value="Python")
|
387 |
+
code_to_debug = gr.Code(label="Buggy Code", language="python", lines=15, value="def avg(nums):\n # Potential div by zero\n return sum(nums)/len(nums)")
|
388 |
+
debug_code_params = create_parameter_ui()
|
389 |
+
# Removed gr.Spacer
|
390 |
+
debug_code_btn = gr.Button("Debug Code", variant="primary")
|
391 |
+
with gr.Column(scale=1):
|
392 |
+
gr.Markdown("#### Debugging Analysis")
|
393 |
+
debug_result = gr.Textbox(label="Result", lines=25, interactive=False, show_copy_button=True)
|
394 |
+
|
395 |
+
def debug_code_click(lang, code, *params):
|
396 |
+
code_content = safe_value(code['code'] if isinstance(code, dict) else code, "# Empty code")
|
397 |
+
prompt = generate_prompt("code_debug", language=safe_value(lang, "code"), code=code_content)
|
398 |
+
return generate_text(prompt, *params)
|
399 |
+
def update_debug_lang_display(lang): return gr.Code.update(language=lang_map.get(lang, "plaintext"))
|
400 |
+
code_lang_debug.change(update_debug_lang_display, code_lang_debug, code_to_debug, queue=False)
|
401 |
+
debug_code_btn.click(debug_code_click, [code_lang_debug, code_to_debug, *debug_code_params], debug_result)
|
402 |
+
|
403 |
+
|
404 |
+
# --- Comprehension Tab ---
|
405 |
+
with gr.TabItem("📚 Comprehension"):
|
406 |
+
with gr.Tabs():
|
407 |
+
with gr.TabItem("Summarize"):
|
408 |
+
with gr.Row():
|
409 |
with gr.Column(scale=1):
|
410 |
+
gr.Markdown("#### Setup")
|
411 |
+
summarize_text = gr.Textbox(label="Text to Summarize", lines=15, placeholder="Paste long text...")
|
412 |
summarize_params = create_parameter_ui()
|
413 |
+
# Removed gr.Spacer
|
414 |
+
summarize_btn = gr.Button("Summarize Text", variant="primary")
|
|
|
415 |
with gr.Column(scale=1):
|
416 |
+
gr.Markdown("#### Summary")
|
417 |
+
summary_output = gr.Textbox(label="Result", lines=15, interactive=False, show_copy_button=True)
|
418 |
+
def summarize_click(text, *params):
|
419 |
+
prompt = generate_prompt("summarize", text=safe_value(text, "[empty text]"))
|
420 |
+
# Adjust max tokens for summary specifically if needed
|
421 |
+
p_list = list(params); p_list[0] = min(max(int(p_list[0]), 64), 512)
|
422 |
+
return generate_text(prompt, *p_list)
|
423 |
+
summarize_btn.click(summarize_click, [summarize_text, *summarize_params], summary_output)
|
424 |
+
|
425 |
+
|
426 |
+
with gr.TabItem("Q & A"):
|
427 |
+
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
428 |
with gr.Column(scale=1):
|
429 |
+
gr.Markdown("#### Setup")
|
430 |
+
qa_text = gr.Textbox(label="Context Text", lines=10, placeholder="Paste text containing answer...")
|
431 |
+
qa_question = gr.Textbox(label="Question", placeholder="Ask question about text...")
|
432 |
qa_params = create_parameter_ui()
|
433 |
+
# Removed gr.Spacer
|
434 |
+
qa_btn = gr.Button("Get Answer", variant="primary")
|
|
|
435 |
with gr.Column(scale=1):
|
436 |
+
gr.Markdown("#### Answer")
|
437 |
+
qa_output = gr.Textbox(label="Result", lines=10, interactive=False, show_copy_button=True)
|
438 |
+
def qa_click(text, question, *params):
|
439 |
+
prompt = generate_prompt("qa", text=safe_value(text, "[context]"), question=safe_value(question,"[question]"))
|
440 |
+
p_list = list(params); p_list[0] = min(max(int(p_list[0]), 32), 256)
|
441 |
+
return generate_text(prompt, *p_list)
|
442 |
+
qa_btn.click(qa_click, [qa_text, qa_question, *qa_params], qa_output)
|
443 |
+
|
444 |
+
|
445 |
+
with gr.TabItem("Translate"):
|
446 |
+
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
447 |
with gr.Column(scale=1):
|
448 |
+
gr.Markdown("#### Setup")
|
449 |
+
translate_text = gr.Textbox(label="Text to Translate", lines=8, placeholder="Enter text...")
|
450 |
+
target_lang = gr.Dropdown(["French", "Spanish", "German", "Japanese", "Chinese", "Russian", "Arabic", "Hindi", "Portuguese", "Italian"], label="Translate To", value="French")
|
|
|
|
|
|
|
451 |
translate_params = create_parameter_ui()
|
452 |
+
# Removed gr.Spacer
|
453 |
+
translate_btn = gr.Button("Translate Text", variant="primary")
|
|
|
454 |
with gr.Column(scale=1):
|
455 |
+
gr.Markdown("#### Translation")
|
456 |
+
translation_output = gr.Textbox(label="Result", lines=8, interactive=False, show_copy_button=True)
|
457 |
+
def translate_click(text, lang, *params):
|
458 |
+
prompt = generate_prompt("translate", text=safe_value(text,"[text]"), target_lang=safe_value(lang,"French"))
|
459 |
+
p_list = list(params); p_list[0] = max(int(p_list[0]), 64)
|
460 |
+
return generate_text(prompt, *p_list)
|
461 |
+
translate_btn.click(translate_click, [translate_text, target_lang, *translate_params], translation_output)
|
462 |
+
|
463 |
+
|
464 |
+
# --- More Tasks Tab ---
|
465 |
+
with gr.TabItem("🛠️ More Tasks"):
|
466 |
+
with gr.Tabs():
|
467 |
+
with gr.TabItem("Content Creation"):
|
468 |
+
with gr.Row():
|
469 |
+
with gr.Column(scale=1):
|
470 |
+
gr.Markdown("#### Setup")
|
471 |
+
content_type = gr.Dropdown(["blog post outline", "social media post (Twitter)", "social media post (LinkedIn)", "marketing email subject line", "product description", "press release intro"], label="Content Type", value="blog post outline")
|
472 |
+
content_topic = gr.Textbox(label="Topic", value="sustainable travel tips", lines=2)
|
473 |
+
content_audience = gr.Textbox(label="Audience", value="eco-conscious millennials")
|
474 |
+
content_params = create_parameter_ui()
|
475 |
+
# Removed gr.Spacer
|
476 |
+
content_btn = gr.Button("Generate Content", variant="primary")
|
477 |
+
with gr.Column(scale=1):
|
478 |
+
gr.Markdown("#### Generated Content")
|
479 |
+
content_output = gr.Textbox(label="Result", lines=20, interactive=False, show_copy_button=True)
|
480 |
+
def content_click(c_type, topic, audience, *params):
|
481 |
+
prompt = generate_prompt("content_creation", content_type=safe_value(c_type,"text"), topic=safe_value(topic,"[topic]"), audience=safe_value(audience,"[audience]"))
|
482 |
+
return generate_text(prompt, *params)
|
483 |
+
content_btn.click(content_click, [content_type, content_topic, content_audience, *content_params], content_output)
|
484 |
+
|
485 |
+
with gr.TabItem("Email Drafting"):
|
486 |
+
with gr.Row():
|
487 |
with gr.Column(scale=1):
|
488 |
+
gr.Markdown("#### Setup")
|
489 |
+
email_type = gr.Dropdown(["job inquiry", "meeting request", "follow-up", "thank you", "support response", "sales outreach"], label="Email Type", value="meeting request")
|
490 |
+
email_context = gr.Textbox(label="Context/Points", lines=5, value="Request meeting next week re: project X. Suggest Tue/Wed afternoon.")
|
491 |
+
email_params = create_parameter_ui()
|
492 |
+
# Removed gr.Spacer
|
493 |
+
email_btn = gr.Button("Generate Draft", variant="primary")
|
|
|
494 |
with gr.Column(scale=1):
|
495 |
+
gr.Markdown("#### Generated Draft")
|
496 |
+
email_output = gr.Textbox(label="Result", lines=20, interactive=False, show_copy_button=True)
|
497 |
+
def email_click(e_type, context, *params):
|
498 |
+
prompt = generate_prompt("email_draft", email_type=safe_value(e_type,"email"), context=safe_value(context,"[context]"))
|
499 |
+
return generate_text(prompt, *params)
|
500 |
+
email_btn.click(email_click, [email_type, email_context, *email_params], email_output)
|
501 |
+
|
502 |
+
with gr.TabItem("Doc Editing"):
|
503 |
+
with gr.Row():
|
504 |
+
with gr.Column(scale=1):
|
505 |
+
gr.Markdown("#### Setup")
|
506 |
+
edit_text = gr.Textbox(label="Text to Edit", lines=10, placeholder="Paste text...")
|
507 |
+
edit_type = gr.Dropdown(["improve clarity", "fix grammar/spelling", "make concise", "make formal", "make casual", "simplify"], label="Improve For", value="improve clarity")
|
508 |
+
edit_params = create_parameter_ui()
|
509 |
+
# Removed gr.Spacer
|
510 |
+
edit_btn = gr.Button("Edit Text", variant="primary")
|
511 |
+
with gr.Column(scale=1):
|
512 |
+
gr.Markdown("#### Edited Text")
|
513 |
+
edit_output = gr.Textbox(label="Result", lines=10, interactive=False, show_copy_button=True)
|
514 |
+
def edit_click(text, e_type, *params):
|
515 |
+
prompt = generate_prompt("document_edit", text=safe_value(text,"[text]"), edit_type=safe_value(e_type,"clarity"))
|
516 |
+
p_list = list(params); input_tokens = len(safe_value(text,"").split()); p_list[0] = max(int(p_list[0]), input_tokens + 64)
|
517 |
+
return generate_text(prompt, *p_list)
|
518 |
+
edit_btn.click(edit_click, [edit_text, edit_type, *edit_params], edit_output)
|
519 |
+
|
520 |
+
with gr.TabItem("Classification"):
|
521 |
+
with gr.Row():
|
522 |
with gr.Column(scale=1):
|
523 |
+
gr.Markdown("#### Setup")
|
524 |
+
classify_text = gr.Textbox(label="Text to Classify", lines=8, value="Sci-fi movie explores AI consciousness.")
|
525 |
+
classify_categories = gr.Textbox(label="Categories (comma-sep)", value="Tech, Entertainment, Science, Politics")
|
526 |
+
classify_params = create_parameter_ui()
|
527 |
+
# Removed gr.Spacer
|
528 |
+
classify_btn = gr.Button("Classify Text", variant="primary")
|
529 |
with gr.Column(scale=1):
|
530 |
+
gr.Markdown("#### Classification")
|
531 |
+
classify_output = gr.Textbox(label="Predicted Category", lines=2, interactive=False, show_copy_button=True)
|
532 |
+
def classify_click(text, cats, *params):
|
533 |
+
prompt = generate_prompt("classify", text=safe_value(text,"[text]"), categories=safe_value(cats,"cat1, cat2"))
|
534 |
+
p_list = list(params); p_list[0] = min(max(int(p_list[0]), 16), 128)
|
535 |
+
raw = generate_text(prompt, *p_list)
|
536 |
+
# Basic post-processing attempt
|
537 |
+
lines = raw.split('\n'); last = lines[-1].strip(); possible = [c.strip().lower() for c in cats.split(',')]; return last if last.lower() in possible else raw
|
538 |
+
classify_btn.click(classify_click, [classify_text, classify_categories, *classify_params], classify_output)
|
539 |
+
|
540 |
+
with gr.TabItem("Data Extraction"):
|
541 |
+
with gr.Row():
|
|
|
|
|
542 |
with gr.Column(scale=1):
|
543 |
+
gr.Markdown("#### Setup")
|
544 |
+
extract_text = gr.Textbox(label="Source Text", lines=10, value="Order #123 by Jane (j@ex.com). Total: $99. Shipped: 123 Main St.")
|
545 |
+
extract_data_points = gr.Textbox(label="Data Points (comma-sep)", value="order num, name, email, total, address")
|
546 |
+
extract_params = create_parameter_ui()
|
547 |
+
# Removed gr.Spacer
|
548 |
+
extract_btn = gr.Button("Extract Data", variant="primary")
|
549 |
with gr.Column(scale=1):
|
550 |
+
gr.Markdown("#### Extracted Data")
|
551 |
+
extract_output = gr.Textbox(label="Result (JSON or Key-Value)", lines=10, interactive=False, show_copy_button=True)
|
552 |
+
def extract_click(text, points, *params):
|
553 |
+
prompt = generate_prompt("data_extract", text=safe_value(text,"[text]"), data_points=safe_value(points,"info"))
|
554 |
+
return generate_text(prompt, *params)
|
555 |
+
extract_btn.click(extract_click, [extract_text, extract_data_points, *extract_params], extract_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
556 |
|
|
|
|
|
|
|
|
|
|
|
|
|
557 |
|
558 |
+
# --- Authentication Handler & Footer ---
|
559 |
+
footer_status = gr.Markdown(f"...", elem_id="footer-status-md") # Placeholder
|
560 |
|
|
|
561 |
def handle_auth(token):
|
562 |
+
yield "⏳ Authenticating & loading model...", gr.Tabs.update(visible=False)
|
|
|
|
|
563 |
status_message, tabs_update = load_model(token)
|
564 |
yield status_message, tabs_update
|
565 |
|
566 |
+
def update_footer_status(status_text): # Updates footer based on global state
|
567 |
+
return gr.Markdown.update(value=f"""
|
568 |
+
<hr><div style="text-align: center; font-size: 0.9em; color: #777;">
|
569 |
+
<p>Powered by Hugging Face 🤗 Transformers & Gradio. Model: <strong>{loaded_model_name if model_loaded else 'None'}</strong>.</p>
|
570 |
+
<p>Review outputs carefully. Models may generate inaccurate information.</p></div>""")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
571 |
|
572 |
+
auth_button.click(handle_auth, hf_token, [auth_status, tabs], queue=True)
|
573 |
+
# Update footer whenever auth status text changes
|
574 |
+
auth_status.change(update_footer_status, auth_status, footer_status, queue=False)
|
575 |
+
# Initial footer update on load
|
576 |
+
demo.load(update_footer_status, auth_status, footer_status, queue=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
577 |
|
578 |
|
579 |
# --- Launch App ---
|
|
|
|
|
580 |
demo.launch(share=False, allowed_themes=["light", "dark"])
|