# Hugging Face Space: TechnoByte/Qwen2.5-7B-VNTL-JP-EN — Japanese-to-English translation demo.
# Commit note from the hosting page: "multi line support" (473c11d, verified).
import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Configuration ---
# Hugging Face model repo used for JP -> EN translation.
MODEL_NAME = "TechnoByte/Qwen2.5-7B-VNTL-JP-EN"
MAX_NEW_TOKENS = 512 # Max number of new tokens generated per translated line
# --- Load Model and Tokenizer ---
# Load the model and tokenizer only once, at module import, so every request
# reuses the same weights.
print(f"Loading model: {MODEL_NAME}...")
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype="auto", # Let transformers pick the dtype from the model's config
        device_map="auto" # Automatically distribute across available GPUs/CPU
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    # Report the failure in the Space logs and abort startup with a Gradio error.
    print(f"Error loading model or tokenizer: {e}")
    raise gr.Error(f"Failed to load model: {e}. Check Space logs and hardware.")
# --- Translation Function ---
@spaces.GPU(duration=20)
def translate_japanese_to_english(input_text):
    """
    Translate Japanese text to English, processing the input line by line.

    Each non-empty line is translated independently with the module-level
    model/tokenizer; blank lines are preserved so the output mirrors the
    input's line structure.

    Args:
        input_text: Japanese source text, possibly spanning multiple lines.

    Returns:
        The English translation joined with newlines, or a human-readable
        message when the input is empty or generation fails (Gradio displays
        the returned string either way).
    """
    # Fix: whitespace-only input previously slipped past the empty check and
    # produced a blank result; treat it the same as empty input.
    if not input_text or not input_text.strip():
        return "Please enter some Japanese text to translate."

    def _translate_line(line):
        # Translate one line: build a single-turn chat prompt and generate.
        prompt_text = tokenizer.apply_chat_template(
            [{"role": "user", "content": line}],
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([prompt_text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,  # Greedy decoding for deterministic translations
        )
        # Decode only the newly generated tokens, skipping the prompt part.
        prompt_len = model_inputs.input_ids.shape[1]
        return tokenizer.decode(
            generated_ids[0][prompt_len:], skip_special_tokens=True
        ).strip()

    print(f"Received input:\n{input_text}")
    translated_lines = []
    try:  # Boundary handler: surface any generation failure to the UI as text.
        for line in input_text.splitlines():
            if not line.strip():
                translated_lines.append("")  # Keep the empty-line structure
                continue
            print(f"Translating line: {line}")
            response = _translate_line(line)
            print(f"Generated response for line: {response}")
            translated_lines.append(response)
        # Join the translated lines back together with newline characters.
        final_translation = "\n".join(translated_lines)
        print(f"Final combined translation:\n{final_translation}")
        return final_translation
    except Exception as e:
        print(f"Error during translation: {e}")
        return f"An error occurred during translation: {e}"
# --- Gradio Interface ---
# Build the web UI: input box + button on the left, read-only output on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # NOTE(review): the header emoji and example strings below look
    # mojibake-garbled in this copy — verify the file's UTF-8 encoding.
    gr.Markdown(
        f"""
# Qwen2.5-7B-VNTL-JP-EN Demo ๐Ÿ‡ฏ๐Ÿ‡ตโžก๏ธ๐Ÿ‡ฌ๐Ÿ‡ง
Enter Japanese text below and click "Translate" to get the English translation.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_textbox = gr.Textbox(
                lines=5, # Initial height only; the box can still grow
                label="Japanese Input Text"
            )
            translate_button = gr.Button("Translate", variant="primary")
        with gr.Column(scale=1):
            output_textbox = gr.Textbox(
                lines=5,
                label="English Output Text",
                interactive=False # Output box should not be editable by user
            )
    # --- Event Listener ---
    # Wire the button to the translation function.
    translate_button.click(
        fn=translate_japanese_to_english,
        inputs=input_textbox,
        outputs=output_textbox,
        api_name="translate" # Expose as API endpoint /api/translate
    )
    # Clickable examples; cache_examples=True pre-computes their translations
    # when the Space builds, so clicking them is instant.
    gr.Examples(
        examples=[
            ["ๆ”พ่ชฒๅพŒใฏใƒžใƒณใ‚ฌๅ–ซ่ŒถใงใพใฃใŸใ‚Šใ€œโ™ก ใŠใ™ใ™ใ‚ใฎใƒžใƒณใ‚ฌๆ•™ใˆใฆ๏ผ"],
            ["ใ“ใฎใ‚ฝใƒ•ใƒˆใ‚ฆใ‚งใ‚ขใฎไฝฟใ„ๆ–นใŒใ‚ˆใใ‚ใ‹ใ‚Šใพใ›ใ‚“ใ€‚"],
            ["ๆ˜Žๆ—ฅใฎๅคฉๆฐ—ใฏใฉใ†ใชใ‚Šใพใ™ใ‹๏ผŸ"],
            ["ๆ—ฅๆœฌใฎๆ–‡ๅŒ–ใซใคใ„ใฆใ‚‚ใฃใจ็Ÿฅใ‚ŠใŸใ„ใงใ™ใ€‚"],
            ["ใ“ใ‚“ใซใกใฏใ€‚\nๅ…ƒๆฐ—ใงใ™ใ‹๏ผŸ\n็งใฏๅ…ƒๆฐ—ใงใ™ใ€‚"], # Multi-line example
            ["ใ“ใ‚Œใฏๆœ€ๅˆใฎ่กŒใงใ™ใ€‚\n\nใ“ใ‚Œใฏ๏ผ“่กŒ็›ฎใงใ™ใ€‚็ฉบ่กŒใ‚’ๆŒŸใฟใพใ™ใ€‚"] # Example with empty line
        ],
        inputs=input_textbox,
        outputs=output_textbox,
        fn=translate_japanese_to_english,
        cache_examples=True
    )
# --- Launch the App ---
# Only launch when run as a script (not when imported, e.g. by tests).
if __name__ == "__main__":
    demo.launch()