Commit 13af073 · Parent(s): f39fe5c

app.py CHANGED
@@ -25,10 +25,13 @@ def load_model(hf_token):
     """Load the model with the provided token"""
     global global_model, global_tokenizer, model_loaded, loaded_model_name
 
+    # Initially assume tabs should be hidden until successful load
+    initial_tabs_update = gr.Tabs.update(visible=False)
+
     if not hf_token:
         model_loaded = False
         loaded_model_name = "None"
-        return "⚠️ Please enter your Hugging Face token to use the model.",
+        return "⚠️ Please enter your Hugging Face token to use the model.", initial_tabs_update
 
     try:
         # Try different model versions from smallest to largest
@@ -63,7 +66,8 @@ def load_model(hf_token):
                 print(f"Loading model {model_name}...")
                 global_model = AutoModelForCausalLM.from_pretrained(
                     model_name,
-                    torch_dtype=torch.bfloat16, # Use bfloat16 for better performance/compatibility if available
+                    # torch_dtype=torch.bfloat16, # Use bfloat16 for better performance/compatibility if available - fallback to float16 if needed
+                    torch_dtype=torch.float16, # Using float16 for broader compatibility
                     device_map="auto", # Let HF decide device placement
                     token=current_token
                 )
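A note on the dtype switch above: hardcoding `torch.float16` trades speed and numeric stability on bf16-capable GPUs for broader compatibility. If you want both, the dtype can be probed at runtime. A minimal sketch, not part of this commit (`pick_dtype` is a hypothetical helper):

```python
import torch

def pick_dtype() -> torch.dtype:
    """Prefer bfloat16 where the GPU supports it, else float16 on CUDA, else float32 on CPU."""
    if torch.cuda.is_available():
        if torch.cuda.is_bf16_supported():  # Ampere and newer GPUs handle bfloat16 natively
            return torch.bfloat16
        return torch.float16
    return torch.float32
```

The result could then be passed as `torch_dtype=pick_dtype()` to `from_pretrained`.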
@@ -72,11 +76,16 @@ def load_model(hf_token):
                 model_loaded = True
                 loaded_model_name = model_name
                 loaded_successfully = True
+                tabs_update = gr.Tabs.update(visible=True) # Show tabs on success
                 if is_fallback:
-                    return f"✅ Fallback model '{model_name}' loaded successfully! Limited capabilities compared to Gemma.",
+                    return f"✅ Fallback model '{model_name}' loaded successfully! Limited capabilities compared to Gemma.", tabs_update
                 else:
-                    return f"✅ Model '{model_name}' loaded successfully!",
+                    return f"✅ Model '{model_name}' loaded successfully!", tabs_update
 
+            except ImportError as import_err:
+                # Handle potential missing dependencies like bitsandbytes if bfloat16 fails
+                print(f"Import Error loading {model_name}: {import_err}. Check dependencies.")
+                continue # Try next model
             except Exception as specific_e:
                 print(f"Failed to load {model_name}: {specific_e}")
                 # traceback.print_exc() # Keep for debugging if needed, but can be verbose
@@ -94,7 +103,7 @@ def load_model(hf_token):
         model_loaded = False
         loaded_model_name = "None"
         print("Could not load any model version.")
-        return "❌ Could not load any model. Please check your token (ensure it has read permissions and you've accepted Gemma's license on Hugging Face) and network connection.",
+        return "❌ Could not load any model. Please check your token (ensure it has read permissions and you've accepted Gemma's license on Hugging Face) and network connection.", initial_tabs_update
 
     except Exception as e:
         model_loaded = False
@@ -104,9 +113,9 @@ def load_model(hf_token):
         traceback.print_exc()
 
         if "401 Client Error" in error_msg or "requires you to be logged in" in error_msg :
-            return "❌ Authentication failed. Please check your Hugging Face token and ensure you have accepted the Gemma license agreement on the Hugging Face model page.",
+            return "❌ Authentication failed. Please check your Hugging Face token and ensure you have accepted the Gemma license agreement on the Hugging Face model page.", initial_tabs_update
         else:
-            return f"❌ An unexpected error occurred during model loading: {error_msg}",
+            return f"❌ An unexpected error occurred during model loading: {error_msg}", initial_tabs_update
 
 
 def generate_prompt(task_type, **kwargs):
@@ -133,12 +142,14 @@ def generate_prompt(task_type, **kwargs):
     prompt_template = prompts.get(task_type)
     if prompt_template:
         try:
-            #
-            #
-
-            final_kwargs = {key: kwargs.get(key, f"[{key}]") for key in
-
+            # Prepare kwargs safely for formatting
+            # Find placeholders like {key}
+            keys_in_template = [k[1:-1] for k in prompt_template.split('{') if '}' in k for k in [k.split('}')[0]]]
+            final_kwargs = {key: kwargs.get(key, f"[{key}]") for key in keys_in_template} # Use default if key missing in input kwargs
+
+            # Add any extra kwargs provided that weren't in the template (e.g., for 'custom' type)
             final_kwargs.update(kwargs)
+
             return prompt_template.format(**final_kwargs)
         except KeyError as e:
             print(f"Warning: Missing key for prompt template '{task_type}': {e}")
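The hand-rolled `split('{')`/`split('}')` comprehension added here is fragile (escaped braces like `{{` would confuse it). The standard library already knows how to parse format strings; a sketch of the same defaulting behaviour built on `string.Formatter` (an alternative, not the committed code; `fill_template` is a hypothetical helper):

```python
from string import Formatter

def fill_template(template: str, **kwargs) -> str:
    # Collect named placeholders such as {topic} or {style}
    keys = [field for _, field, _, _ in Formatter().parse(template) if field]
    # Default any missing key to a visible "[key]" marker instead of raising KeyError
    final_kwargs = {key: kwargs.get(key, f"[{key}]") for key in keys}
    final_kwargs.update(kwargs)
    return template.format(**final_kwargs)

assert fill_template("Write a {style} about {topic}.", style="poem") == "Write a poem about [topic]."
```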
@@ -166,9 +177,14 @@ def generate_text(prompt, max_new_tokens=1024, temperature=0.7, top_p=0.9):
 
     try:
         # Add role/turn indicators if using an instruction-tuned model
-
+        # Simple check based on model name conventions
+        if loaded_model_name and ("it" in loaded_model_name.lower() or "instruct" in loaded_model_name.lower() or "chat" in loaded_model_name.lower()):
             # Simple chat structure assumed by many instruction models
-
+            # Using Gemma's specific format if it's a Gemma IT model
+            if "gemma" in loaded_model_name.lower():
+                chat_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
+            else: # Generic instruction format
+                chat_prompt = f"User: {prompt}\nAssistant:"
         else:
             # Base models might not need specific turn indicators
             chat_prompt = prompt
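Hand-building `<start_of_turn>` markers matches Gemma's documented chat format, but each model family has its own; modern tokenizers ship a chat template that renders the right markers for you. A hedged sketch of that alternative (assumes a recent `transformers`; `build_chat_prompt` is a hypothetical helper):

```python
def build_chat_prompt(tokenizer, prompt: str) -> str:
    """Use the tokenizer's own chat template when it has one, else fall back to the raw prompt."""
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,  # appends the model-turn header, e.g. <start_of_turn>model
        )
    return prompt
```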
@@ -177,13 +193,9 @@ def generate_text(prompt, max_new_tokens=1024, temperature=0.7, top_p=0.9):
         input_length = inputs.input_ids.shape[1]
         print(f"Input token length: {input_length}")
 
-        # Adjust max_length based on input, prevent it from being too small
-        # max_length = max(input_length + 64, input_length + max_new_tokens) # Ensure at least some generation
-        # Use max_new_tokens directly as it's clearer for users
         # Ensure max_new_tokens isn't excessively large for the model context
-        #
-        #
-        effective_max_new_tokens = min(max_new_tokens, 2048) # Cap generation length
+        # Cap generation length for stability
+        effective_max_new_tokens = min(int(max_new_tokens), 2048) # Cap generation length and ensure int
 
         generation_args = {
             "input_ids": inputs.input_ids,
@@ -192,7 +204,7 @@ def generate_text(prompt, max_new_tokens=1024, temperature=0.7, top_p=0.9):
             "do_sample": True,
             "temperature": float(temperature), # Ensure float
             "top_p": float(top_p), # Ensure float
-            "pad_token_id": global_tokenizer.eos_token_id # Use EOS token for padding
+            "pad_token_id": global_tokenizer.eos_token_id if global_tokenizer.eos_token_id is not None else 50256 # Use EOS token for padding, provide fallback
         }
 
         print(f"Generation args: {generation_args}")
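One caveat about the new `pad_token_id` fallback: `50256` is GPT-2's end-of-text id and may be out of range for other vocabularies, including Gemma's. A safer fallback chain, sketched as a hypothetical helper rather than the committed code:

```python
def resolve_pad_token_id(tokenizer) -> int:
    # Prefer an explicit pad token, then EOS, and only then a last-resort 0
    if tokenizer.pad_token_id is not None:
        return tokenizer.pad_token_id
    if tokenizer.eos_token_id is not None:
        return tokenizer.eos_token_id
    return 0
```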
@@ -201,7 +213,6 @@ def generate_text(prompt, max_new_tokens=1024, temperature=0.7, top_p=0.9):
         with torch.no_grad(): # Disable gradient calculation for inference
             outputs = global_model.generate(**generation_args)
 
-        # Decode response, skipping special tokens and the prompt
         # Decode only the newly generated tokens
         generated_ids = outputs[0, input_length:]
         generated_text = global_tokenizer.decode(generated_ids, skip_special_tokens=True)
@@ -215,7 +226,13 @@ def generate_text(prompt, max_new_tokens=1024, temperature=0.7, top_p=0.9):
         print(f"Generation error: {error_msg}")
         print(f"Error type: {type(e)}")
         traceback.print_exc()
-
+        # Check for common CUDA errors
+        if "CUDA out of memory" in error_msg:
+            return f"❌ Error: CUDA out of memory. Try reducing 'Max New Tokens' or using a smaller model variant if possible."
+        elif "probability tensor contains nan" in error_msg:
+            return f"❌ Error: Generation failed (NaN probability). Try adjusting Temperature/Top-P or modifying the prompt."
+        else:
+            return f"❌ Error during text generation: {error_msg}\n\nPlease check the logs or try adjusting parameters."
 
 # Create parameters UI component (reusable function)
 def create_parameter_ui():
@@ -237,7 +254,7 @@ def create_parameter_ui():
             value=0.7,
             step=0.1,
             label="Temperature",
-            info="Controls randomness. Lower is
+            info="Controls randomness. Lower is focused, higher is diverse.",
             elem_id="temperature_slider"
         )
         top_p = gr.Slider(
@@ -253,7 +270,7 @@ def create_parameter_ui():
 
 # --- Gradio Interface ---
 # Use the soft theme for a clean look, allow light/dark switching
-with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
+with gr.Blocks(theme=gr.themes.Soft(), fill_height=True, title="Gemma Capabilities Demo") as demo:
 
     # Header
     gr.Markdown(
@@ -274,7 +291,9 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
     )
 
     # --- Authentication Section ---
-
+    # REMOVED variant="panel" from gr.Group for compatibility
+    with gr.Group(): # Use default Group appearance
+        gr.Markdown("### 🔑 Authentication") # Added heading inside group
         with gr.Row():
             with gr.Column(scale=4):
                 hf_token = gr.Textbox(
@@ -282,38 +301,31 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                     placeholder="Paste your HF token here (hf_...)",
                     type="password",
                     value=DEFAULT_HF_TOKEN,
+                    info="Get your token from https://huggingface.co/settings/tokens",
                     elem_id="hf_token_input"
                 )
             with gr.Column(scale=1, min_width=150):
-                # Add spacer for alignment if needed, or adjust scale
-                # gr.Spacer(height=10) # Add space above button if needed
                 auth_button = gr.Button("Load Model", variant="primary", elem_id="auth_button")
 
         auth_status = gr.Markdown("ℹ️ Enter your Hugging Face token and click 'Load Model'. This might take a minute.", elem_id="auth_status")
+        # Add instructions on getting token inside the auth group
+        gr.Markdown(
+            """
+            **How to get a token:**
+            1. Go to [Hugging Face Token Settings](https://huggingface.co/settings/tokens)
+            2. Create a new token with **read** access.
+            3. Ensure you've accepted the [Gemma model license](https://huggingface.co/google/gemma-7b-it) on the model page.
+            """
+        )
 
-    # Define authentication flow (simplified)
-    def handle_auth(token):
-        # Show loading message immediately
-        yield "⏳ Authenticating and loading model... Please wait.", gr.Tabs.update(visible=False)
-        # Call the actual model loading function
-        status_message, tabs_update = load_model(token)
-        yield status_message, tabs_update
-
-    # Link button click to the handler
-    auth_button.click(
-        fn=handle_auth,
-        inputs=[hf_token],
-        outputs=[auth_status, gr.get_component("main_tabs")], # Update status and hide/show main_tabs by element id
-        queue=True # Run in queue for potentially long operation
-    )
+
 
     # --- Main Content Tabs (Initially Hidden) ---
-    #
+    # Define the tabs variable here
     with gr.Tabs(elem_id="main_tabs", visible=False) as tabs:
 
         # --- Text Generation Tab ---
         with gr.TabItem("📝 Creative & Informational", id="tab_text_gen"):
-            with gr.Row():
+            with gr.Row(equal_height=False): # Allow columns to have different heights if needed
                 # Input Column
                 with gr.Column(scale=1):
                     gr.Markdown("### Configure Task")
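The `handle_auth` function deleted here reappears near the end of the file (see the final hunk), after `tabs` is defined. Its two `yield` statements are what make the status label update twice, once immediately and once after `load_model` returns, because Gradio treats a generator handler as a stream of successive output values. A minimal self-contained sketch of the pattern, with a sleep standing in for the model load:

```python
import time
import gradio as gr

def slow_task(name):
    yield f"⏳ Working on {name}..."  # shown immediately
    time.sleep(2)                     # stand-in for the real model load
    yield f"✅ Finished {name}!"      # replaces the first message

with gr.Blocks() as demo:
    box = gr.Textbox(label="Name", value="demo")
    status = gr.Markdown()
    gr.Button("Run").click(slow_task, inputs=box, outputs=status)

demo.launch()
```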
@@ -336,12 +348,15 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                     with gr.Group(visible=False, elem_id="custom_prompt_group") as custom_prompt_group:
                         custom_prompt = gr.Textbox(label="Custom Prompt", placeholder="Enter your full prompt here...", lines=5, elem_id="custom_prompt")
 
-                    # Show/hide logic
+                    # Show/hide logic (using gr.update for better practice)
                     def update_text_gen_visibility(choice):
+                        is_creative = choice == "Creative Writing"
+                        is_info = choice == "Informational Writing"
+                        is_custom = choice == "Custom Prompt"
                         return {
-                            creative_options: gr.update(visible=
-                            info_options: gr.update(visible=
-                            custom_prompt_group: gr.update(visible=
+                            creative_options: gr.update(visible=is_creative),
+                            info_options: gr.update(visible=is_info),
+                            custom_prompt_group: gr.update(visible=is_custom)
                         }
                     text_gen_type.change(update_text_gen_visibility, inputs=text_gen_type, outputs=[creative_options, info_options, custom_prompt_group], queue=False)
 
@@ -353,19 +368,28 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                 # Output Column
                 with gr.Column(scale=1):
                     gr.Markdown("### Generated Output")
-                    text_output = gr.Textbox(label="Result", lines=25, interactive=False, elem_id="text_output")
+                    text_output = gr.Textbox(label="Result", lines=25, interactive=False, elem_id="text_output", show_copy_button=True) # Added copy button
 
             # Handler
             def text_generation_handler(gen_type, style, creative_topic, format_type, info_topic, custom_prompt_text, max_tokens, temp, top_p_val):
                 task_map = {
                     "Creative Writing": ("creative", {"style": style, "topic": creative_topic}),
                     "Informational Writing": ("informational", {"format_type": format_type, "topic": info_topic}),
-                    "Custom Prompt": ("custom", {"prompt": custom_prompt_text})
+                    "Custom Prompt": ("custom", {"prompt": custom_prompt_text}) # Use 'custom' as type, 'prompt' as key
                 }
+                # Default to custom if type not found (shouldn't happen with Radio)
                 task_type, kwargs = task_map.get(gen_type, ("custom", {"prompt": custom_prompt_text}))
-
-                for
-
+
+                # Ensure safe values for specific task types
+                if task_type == "creative":
+                    kwargs["style"] = safe_value(style, "story")
+                    kwargs["topic"] = safe_value(creative_topic, "a default topic")
+                elif task_type == "informational":
+                    kwargs["format_type"] = safe_value(format_type, "article")
+                    kwargs["topic"] = safe_value(info_topic, "a default topic")
+                elif task_type == "custom":
+                    kwargs["prompt"] = safe_value(custom_prompt_text, "Write something interesting.")
+
 
                 final_prompt = generate_prompt(task_type, **kwargs)
                 return generate_text(final_prompt, max_tokens, temp, top_p_val)
@@ -377,21 +401,24 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
             )
 
             # Examples
+            # Simplified examples list for clarity
             gr.Examples(
                 examples=[
                     ["Creative Writing", "poem", "the sound of rain on a tin roof", "", "", "", 512, 0.7, 0.9],
                     ["Informational Writing", "", "", "explanation", "how photosynthesis works", "", 768, 0.6, 0.9],
                     ["Custom Prompt", "", "", "", "", "Write a short dialogue between a cat and a dog discussing their humans.", 512, 0.8, 0.95],
                 ],
-
+                # Ensure the order matches the handler's inputs
+                inputs=[text_gen_type, style, creative_topic, format_type, info_topic, custom_prompt, *text_gen_params[:3]], # Pass only the UI elements needed
                 outputs=text_output,
                 label="Try these examples...",
-                fn=text_generation_handler #
+                #fn=text_generation_handler # fn is deprecated, click event handles execution
             )
 
+
         # --- Brainstorming Tab ---
         with gr.TabItem("🧠 Brainstorming", id="tab_brainstorm"):
-            with gr.Row():
+            with gr.Row(equal_height=False):
                 # Input Column
                 with gr.Column(scale=1):
                     gr.Markdown("### Brainstorming Setup")
@@ -404,7 +431,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                 # Output Column
                 with gr.Column(scale=1):
                     gr.Markdown("### Generated Ideas")
-                    brainstorm_output = gr.Textbox(label="Result", lines=25, interactive=False, elem_id="brainstorm_output")
+                    brainstorm_output = gr.Textbox(label="Result", lines=25, interactive=False, elem_id="brainstorm_output", show_copy_button=True)
 
             # Handler
             def brainstorm_handler(category, topic, max_tokens, temp, top_p_val):
@@ -421,22 +448,24 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                     ["business", "eco-friendly subscription boxes", 768, 0.75, 0.9],
                     ["creative", "themes for a fantasy novel", 512, 0.85, 0.95],
                 ],
-                inputs=[brainstorm_category, brainstorm_topic, *brainstorm_params],
+                inputs=[brainstorm_category, brainstorm_topic, *brainstorm_params[:3]],
                 outputs=brainstorm_output,
                 label="Try these examples...",
-                fn=brainstorm_handler
             )
 
         # --- Code Capabilities Tab ---
         with gr.TabItem("💻 Code", id="tab_code"):
+            # Language mapping for syntax highlighting (defined once)
+            lang_map = {"Python": "python", "JavaScript": "javascript", "Java": "java", "C++": "cpp", "HTML": "html", "CSS": "css", "SQL": "sql", "Bash": "bash", "Rust": "rust", "Other": "plaintext"}
+
             with gr.Tabs() as code_tabs:
                 # --- Code Generation ---
                 with gr.TabItem("Generate Code", id="subtab_code_gen"):
-                    with gr.Row():
+                    with gr.Row(equal_height=False):
                         # Input Column
                         with gr.Column(scale=1):
                             gr.Markdown("### Code Generation Setup")
-                            code_language_gen = gr.Dropdown([
+                            code_language_gen = gr.Dropdown(list(lang_map.keys())[:-1], label="Language", value="Python", elem_id="code_language_gen") # Exclude 'Other'
                             code_task = gr.Textbox(label="Task Description", placeholder="e.g., function to calculate factorial", value="create a Python class for a basic calculator", lines=4, elem_id="code_task")
                             code_gen_params = create_parameter_ui()
                             gr.Spacer(height=15)
@@ -445,9 +474,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                         # Output Column
                         with gr.Column(scale=1):
                             gr.Markdown("### Generated Code")
-
-                            lang_map = {"Python": "python", "JavaScript": "javascript", "Java": "java", "C++": "cpp", "HTML": "html", "CSS": "css", "SQL": "sql", "Bash": "bash", "Rust": "rust"}
-                            code_output = gr.Code(label="Result", language="python", lines=25, interactive=False, elem_id="code_output")
+                            code_output = gr.Code(label="Result", language="python", lines=25, interactive=False, elem_id="code_output") # No copy button needed for gr.Code
 
                     # Handler
                     def code_gen_handler(language, task, max_tokens, temp, top_p_val):
@@ -457,19 +484,27 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                         result = generate_text(prompt, max_tokens, temp, top_p_val)
                         # Try to extract code block if markdown is used
                         if "```" in result:
-
-                            if len(
-
-
-                            if
-
-
-
+                            parts = result.split("```")
+                            if len(parts) >= 2:
+                                code_block = parts[1]
+                                # Remove potential language hint (e.g., ```python)
+                                if '\n' in code_block:
+                                    first_line, rest_of_code = code_block.split('\n', 1)
+                                    if first_line.strip().lower() == language.lower():
+                                        return rest_of_code.strip()
+                                    else:
+                                        # Language hint might be missing or different
+                                        return code_block.strip()
+                                else:
+                                    # Code block might be single line without language hint after ```
+                                    return code_block.strip()
+                        # Return full result if no markdown block found or extraction failed
+                        return result.strip()
 
 
                     # Update output language display based on dropdown
                     def update_code_language_display(lang):
-                        return gr.Code(language=lang_map.get(lang, "plaintext")) #
+                        return gr.Code.update(language=lang_map.get(lang, "plaintext")) # Use update method
 
                     code_language_gen.change(update_code_language_display, inputs=code_language_gen, outputs=code_output, queue=False)
                     code_gen_btn.click(code_gen_handler, inputs=[code_language_gen, code_task, *code_gen_params], outputs=code_output)
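The triple-backtick bookkeeping added to `code_gen_handler` can also be done with a single regex that tolerates a missing or mismatched language hint; a sketch of that variant (not the committed code; `extract_code` is a hypothetical helper):

```python
import re

FENCE = re.compile(r"```([\w+-]*)\n?(.*?)```", re.DOTALL)

def extract_code(result: str) -> str:
    """Return the first fenced block's body, or the whole text when no fence is present."""
    match = FENCE.search(result)
    if match:
        return match.group(2).strip()
    return result.strip()

assert extract_code("Here you go:\n```python\nprint('hi')\n```") == "print('hi')"
```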
@@ -480,20 +515,18 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                             ["SQL", "query to select users older than 30 from a 'users' table", 512, 0.5, 0.8],
                             ["HTML", "basic structure for a personal portfolio website", 1024, 0.7, 0.9],
                         ],
-                        inputs=[code_language_gen, code_task, *code_gen_params],
+                        inputs=[code_language_gen, code_task, *code_gen_params[:3]],
                         outputs=code_output,
                         label="Try these examples...",
-                        fn=code_gen_handler
                     )
 
                 # --- Code Explanation ---
                 with gr.TabItem("Explain Code", id="subtab_code_explain"):
-                    with gr.Row():
+                    with gr.Row(equal_height=False):
                         # Input Column
                         with gr.Column(scale=1):
                             gr.Markdown("### Code Explanation Setup")
-
-                            code_language_explain = gr.Dropdown(["Python", "JavaScript", "Java", "C++", "HTML", "CSS", "SQL", "Bash", "Rust", "Other"], label="Code Language (for context)", value="Python", elem_id="code_language_explain")
+                            code_language_explain = gr.Dropdown(list(lang_map.keys()), label="Code Language (for context)", value="Python", elem_id="code_language_explain")
                             code_to_explain = gr.Code(label="Paste Code Here", language="python", lines=15, elem_id="code_to_explain")
                             explain_code_params = create_parameter_ui()
                             gr.Spacer(height=15)
@@ -502,29 +535,29 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                         # Output Column
                         with gr.Column(scale=1):
                             gr.Markdown("### Explanation")
-                            code_explanation = gr.Textbox(label="Result", lines=25, interactive=False, elem_id="code_explanation")
+                            code_explanation = gr.Textbox(label="Result", lines=25, interactive=False, elem_id="code_explanation", show_copy_button=True)
 
                     # Update code input language display
                     def update_explain_language_display(lang):
-                        return gr.Code(language=lang_map.get(lang, "plaintext"))
+                        return gr.Code.update(language=lang_map.get(lang, "plaintext"))
                     code_language_explain.change(update_explain_language_display, inputs=code_language_explain, outputs=code_to_explain, queue=False)
 
                     # Handler
                     def explain_code_handler(language, code, max_tokens, temp, top_p_val):
-
+                        code_content = safe_value(code['code'] if isinstance(code, dict) else code, "# Add code here") # Handle potential dict input from gr.Code
                         language = safe_value(language, "code") # Use selected language in prompt
-                        prompt = generate_prompt("code_explain", language=language, code=
+                        prompt = generate_prompt("code_explain", language=language, code=code_content)
                         return generate_text(prompt, max_tokens, temp, top_p_val)
 
                     explain_code_btn.click(explain_code_handler, inputs=[code_language_explain, code_to_explain, *explain_code_params], outputs=code_explanation)
 
                 # --- Code Debugging ---
                 with gr.TabItem("Debug Code", id="subtab_code_debug"):
-                    with gr.Row():
+                    with gr.Row(equal_height=False):
                         # Input Column
                         with gr.Column(scale=1):
                             gr.Markdown("### Code Debugging Setup")
-                            code_language_debug = gr.Dropdown(
+                            code_language_debug = gr.Dropdown(list(lang_map.keys()), label="Code Language (for context)", value="Python", elem_id="code_language_debug")
                             code_to_debug = gr.Code(
                                 label="Paste Potentially Buggy Code Here",
                                 language="python",
@@ -539,18 +572,18 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                         # Output Column
                         with gr.Column(scale=1):
                             gr.Markdown("### Debugging Analysis & Fix")
-                            debug_result = gr.Textbox(label="Result", lines=25, interactive=False, elem_id="debug_result")
+                            debug_result = gr.Textbox(label="Result", lines=25, interactive=False, elem_id="debug_result", show_copy_button=True)
 
                     # Update code input language display
                     def update_debug_language_display(lang):
-                        return gr.Code(language=lang_map.get(lang, "plaintext"))
+                        return gr.Code.update(language=lang_map.get(lang, "plaintext"))
                     code_language_debug.change(update_debug_language_display, inputs=code_language_debug, outputs=code_to_debug, queue=False)
 
                     # Handler
                     def debug_code_handler(language, code, max_tokens, temp, top_p_val):
-
+                        code_content = safe_value(code['code'] if isinstance(code, dict) else code, "# Add potentially buggy code here")
                         language = safe_value(language, "code")
-                        prompt = generate_prompt("code_debug", language=language, code=
+                        prompt = generate_prompt("code_debug", language=language, code=code_content)
                         return generate_text(prompt, max_tokens, temp, top_p_val)
 
                     debug_code_btn.click(debug_code_handler, inputs=[code_language_debug, code_to_debug, *debug_code_params], outputs=debug_result)
@@ -562,7 +595,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
 
                 # --- Summarization ---
                 with gr.TabItem("Summarize", id="subtab_summarize"):
-                    with gr.Row():
+                    with gr.Row(equal_height=False):
                         # Input Column
                         with gr.Column(scale=1):
                             gr.Markdown("### Summarization Setup")
@@ -573,13 +606,13 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                         # Output Column
                         with gr.Column(scale=1):
                             gr.Markdown("### Summary")
-                            summary_output = gr.Textbox(label="Result", lines=15, interactive=False, elem_id="summary_output")
+                            summary_output = gr.Textbox(label="Result", lines=15, interactive=False, elem_id="summary_output", show_copy_button=True)
 
                     # Handler
                     def summarize_handler(text, max_tokens, temp, top_p_val):
                         text = safe_value(text, "Please provide text to summarize.")
-                        # Use shorter max_tokens default for summary
-                        max_tokens = min(max_tokens, 512)
+                        # Use shorter max_tokens default for summary, but ensure it's reasonable
+                        max_tokens = min(max(int(max_tokens), 64), 512) # Ensure int, set min/max bounds
                         prompt = generate_prompt("summarize", text=text)
                         return generate_text(prompt, max_tokens, temp, top_p_val)
 
@@ -587,7 +620,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
 
                 # --- Question Answering ---
                 with gr.TabItem("Q & A", id="subtab_qa"):
-                    with gr.Row():
+                    with gr.Row(equal_height=False):
                         # Input Column
                         with gr.Column(scale=1):
                             gr.Markdown("### Question Answering Setup")
@@ -599,14 +632,14 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                         # Output Column
                         with gr.Column(scale=1):
                             gr.Markdown("### Answer")
-                            qa_output = gr.Textbox(label="Result", lines=10, interactive=False, elem_id="qa_output")
+                            qa_output = gr.Textbox(label="Result", lines=10, interactive=False, elem_id="qa_output", show_copy_button=True)
 
                     # Handler
                     def qa_handler(text, question, max_tokens, temp, top_p_val):
                         text = safe_value(text, "Please provide context text.")
                         question = safe_value(question, "What is the main point?")
                         # Use shorter max_tokens default for QA
-                        max_tokens = min(max_tokens, 256)
+                        max_tokens = min(max(int(max_tokens), 32), 256) # Ensure int, set min/max bounds
                         prompt = generate_prompt("qa", text=text, question=question)
                         return generate_text(prompt, max_tokens, temp, top_p_val)
 
@@ -614,7 +647,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
 
                 # --- Translation ---
                 with gr.TabItem("Translate", id="subtab_translate"):
-                    with gr.Row():
+                    with gr.Row(equal_height=False):
                         # Input Column
                         with gr.Column(scale=1):
                             gr.Markdown("### Translation Setup")
@@ -629,13 +662,15 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                         # Output Column
                         with gr.Column(scale=1):
                             gr.Markdown("### Translation")
-                            translation_output = gr.Textbox(label="Result", lines=8, interactive=False, elem_id="translation_output")
+                            translation_output = gr.Textbox(label="Result", lines=8, interactive=False, elem_id="translation_output", show_copy_button=True)
 
                     # Handler
                     def translate_handler(text, lang, max_tokens, temp, top_p_val):
                         text = safe_value(text, "Please enter text to translate.")
                         lang = safe_value(lang, "French")
                         prompt = generate_prompt("translate", text=text, target_lang=lang)
+                        # Translation length is often similar to input, allow reasonable max_tokens
+                        max_tokens = max(int(max_tokens), 64) # Ensure int, set min bound
                         return generate_text(prompt, max_tokens, temp, top_p_val)
 
                     translate_btn.click(translate_handler, inputs=[translate_text, target_lang, *translate_params], outputs=translation_output)
@@ -647,7 +682,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
 
             # --- Content Creation ---
             with gr.TabItem("Content Creation", id="tab_content"):
-                with gr.Row():
+                with gr.Row(equal_height=False):
                     with gr.Column(scale=1):
                         gr.Markdown("### Content Setup")
                         content_type = gr.Dropdown(["blog post outline", "social media post (Twitter)", "social media post (LinkedIn)", "marketing email subject line", "product description", "press release intro"], label="Content Type", value="blog post outline", elem_id="content_type")
@@ -658,7 +693,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                         content_btn = gr.Button("Generate Content", variant="primary", elem_id="content_btn")
                     with gr.Column(scale=1):
                         gr.Markdown("### Generated Content")
-                        content_output = gr.Textbox(label="Result", lines=20, interactive=False, elem_id="content_output")
+                        content_output = gr.Textbox(label="Result", lines=20, interactive=False, elem_id="content_output", show_copy_button=True)
 
                 def content_handler(c_type, topic, audience, max_tok, temp, top_p_val):
                     c_type = safe_value(c_type, "text")
@@ -671,7 +706,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
 
             # --- Email Drafting ---
             with gr.TabItem("Email Drafting", id="tab_email"):
-                with gr.Row():
+                with gr.Row(equal_height=False):
                    with gr.Column(scale=1):
                         gr.Markdown("### Email Setup")
                         email_type = gr.Dropdown(["job inquiry", "meeting request", "follow-up", "thank you note", "customer support response", "sales outreach"], label="Email Type", value="meeting request", elem_id="email_type")
@@ -681,7 +716,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                         email_btn = gr.Button("Generate Email Draft", variant="primary", elem_id="email_btn")
                     with gr.Column(scale=1):
                         gr.Markdown("### Generated Email")
-                        email_output = gr.Textbox(label="Result", lines=20, interactive=False, elem_id="email_output")
+                        email_output = gr.Textbox(label="Result", lines=20, interactive=False, elem_id="email_output", show_copy_button=True)
 
                 def email_handler(e_type, context, max_tok, temp, top_p_val):
                     e_type = safe_value(e_type, "professional")
@@ -693,7 +728,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
 
             # --- Document Editing ---
            with gr.TabItem("Document Editing", id="tab_edit"):
-                with gr.Row():
+                with gr.Row(equal_height=False):
                    with gr.Column(scale=1):
                         gr.Markdown("### Editing Setup")
                         edit_text = gr.Textbox(label="Text to Edit", placeholder="Paste text here...", lines=10, elem_id="edit_text")
@@ -703,20 +738,22 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                         edit_btn = gr.Button("Edit Text", variant="primary", elem_id="edit_btn")
                     with gr.Column(scale=1):
                         gr.Markdown("### Edited Text")
-                        edit_output = gr.Textbox(label="Result", lines=10, interactive=False, elem_id="edit_output")
+                        edit_output = gr.Textbox(label="Result", lines=10, interactive=False, elem_id="edit_output", show_copy_button=True)
 
                 def edit_handler(text, e_type, max_tok, temp, top_p_val):
                     text = safe_value(text, "Provide text to edit.")
                     e_type = safe_value(e_type, "clarity and grammar")
                     prompt = generate_prompt("document_edit", text=text, edit_type=e_type)
-                    #
-
+                    # Editing might expand text, give it reasonable token count based on input + max_new
+                    input_tokens_estimate = len(text.split()) # Rough estimate
+                    max_tok = max(int(max_tok), input_tokens_estimate + 64) # Ensure enough room
+                    return generate_text(prompt, max_tok, temp, top_p_val)
                 edit_btn.click(edit_handler, inputs=[edit_text, edit_type, *edit_params], outputs=edit_output)
 
 
             # --- Classification ---
             with gr.TabItem("Classification", id="tab_classify"):
-                with gr.Row():
+                with gr.Row(equal_height=False):
                     with gr.Column(scale=1):
                         gr.Markdown("### Classification Setup")
                         classify_text = gr.Textbox(label="Text to Classify", placeholder="Enter text...", lines=8, value="This new sci-fi movie explores themes of AI consciousness and interstellar travel.")
@@ -726,21 +763,33 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                         classify_btn = gr.Button("Classify Text", variant="primary")
                     with gr.Column(scale=1):
                         gr.Markdown("### Classification Result")
-                        classify_output = gr.Textbox(label="Predicted Category", lines=2, interactive=False)
+                        classify_output = gr.Textbox(label="Predicted Category", lines=2, interactive=False, show_copy_button=True)
 
                 def classify_handler(text, cats, max_tok, temp, top_p_val):
                     text = safe_value(text, "Text to classify needed.")
                     cats = safe_value(cats, "category1, category2")
                     # Classification usually needs short output
-                    max_tok = min(max_tok,
+                    max_tok = min(max(int(max_tok), 16), 128) # Ensure int, constrain tightly
                     prompt = generate_prompt("classify", text=text, categories=cats)
-
+                    # Often the model just outputs the category, so we might not need the prompt structure removal
+                    raw_output = generate_text(prompt, max_tok, temp, top_p_val)
+                    # Post-process to get just the category if possible
+                    lines = raw_output.split('\n')
+                    if lines:
+                        last_line = lines[-1].strip()
+                        # Check if the last line seems like one of the categories
+                        possible_cats = [c.strip().lower() for c in cats.split(',')]
+                        if last_line.lower() in possible_cats:
+                            return last_line
+                    # Fallback to raw output
+                    return raw_output
+
                 classify_btn.click(classify_handler, inputs=[classify_text, classify_categories, *classify_params], outputs=classify_output)
 
 
             # --- Data Extraction ---
             with gr.TabItem("Data Extraction", id="tab_extract"):
-                with gr.Row():
+                with gr.Row(equal_height=False):
                     with gr.Column(scale=1):
                         gr.Markdown("### Extraction Setup")
                         extract_text = gr.Textbox(label="Source Text", placeholder="Paste text containing data...", lines=10, value="Order #12345 placed on 2024-03-15 by Jane Doe (jane.d@email.com). Total amount: $99.95. Shipping to 123 Main St, Anytown, USA.")
@@ -750,7 +799,7 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                         extract_btn = gr.Button("Extract Data", variant="primary")
                     with gr.Column(scale=1):
                         gr.Markdown("### Extracted Data")
-                        extract_output = gr.Textbox(label="Result (e.g., JSON or key-value pairs)", lines=10, interactive=False)
+                        extract_output = gr.Textbox(label="Result (e.g., JSON or key-value pairs)", lines=10, interactive=False, show_copy_button=True)
 
                 def extract_handler(text, points, max_tok, temp, top_p_val):
                     text = safe_value(text, "Provide text for extraction.")
@@ -760,27 +809,49 @@ with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
                 extract_btn.click(extract_handler, inputs=[extract_text, extract_data_points, *extract_params], outputs=extract_output)
 
 
+    # Define authentication handler AFTER tabs is defined
+    def handle_auth(token):
+        # Show loading message immediately
+        yield "⏳ Authenticating and loading model... Please wait.", gr.Tabs.update(visible=False)
+        # Call the actual model loading function
+        status_message, tabs_update = load_model(token)
+        yield status_message, tabs_update
+
+    # Link button click to the handler
+    auth_button.click(
+        fn=handle_auth,
+        inputs=[hf_token],
+        outputs=[auth_status, tabs], # Use the defined 'tabs' variable here
+        queue=True # Run in queue for potentially long operation
+    )
+
     # --- Footer ---
-    gr.Markdown(
-        """
+    footer_status = gr.Markdown( # Use a separate Markdown for dynamic updates
+        f"""
         ---
         <div style="text-align: center; font-size: 0.9em; color: #777;">
         <p>Powered by Google's Gemma models via Hugging Face 🤗 Transformers & Gradio.</p>
         <p>Remember to review generated content. Model outputs may be inaccurate or incomplete.</p>
-        <p>Model Loaded: <
+        <p>Model Loaded: <strong>{loaded_model_name if model_loaded else 'None'}</strong></p>
         </div>
         """
     )
-
-
-
-
-
-
-    <
-
-
+
+    # Update footer when authentication status changes
+    def update_footer_status(status_text):
+        # You could parse status_text, but easier to just use global state here
+        return gr.Markdown.update(value=f"""
+        ---
+        <div style="text-align: center; font-size: 0.9em; color: #777;">
+        <p>Powered by Google's Gemma models via Hugging Face 🤗 Transformers & Gradio.</p>
+        <p>Remember to review generated content. Model outputs may be inaccurate or incomplete.</p>
+        <p>Model Loaded: <strong>{loaded_model_name if model_loaded else 'None'}</strong></p>
+        </div>
+        """)
+    auth_status.change(fn=update_footer_status, inputs=auth_status, outputs=footer_status, queue=False)
+
 
 # --- Launch App ---
 # Allow built-in theme switching
+# Use queue() to handle multiple requests better
 demo.launch(share=False, allowed_themes=["light", "dark"])
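A portability note on this commit's additions: `gr.Tabs.update(...)`, `gr.Code.update(...)`, and `gr.Markdown.update(...)` follow the Gradio 3.x component-update API; Gradio 4 removed the per-class `update` methods in favour of the generic `gr.update(...)` (or returning a new component). If the Space is ever bumped to Gradio 4, the handlers would need the equivalent form, sketched here with stand-in values:

```python
import gradio as gr

def handle_auth_v4(token):
    # Gradio 4 style: yield gr.update(...) instead of gr.Tabs.update(...)
    yield "⏳ Authenticating and loading model... Please wait.", gr.update(visible=False)
    status_message = "✅ Model loaded!"  # stand-in for load_model(token)
    yield status_message, gr.update(visible=True)
```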