# Copied from the Hugging Face file view: author rafaaa2105, commit 8a629ad ("Update app.py", verified).
import gradio as gr
import json
import os
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
import torch
import spaces
# --- Model Cache ---
# Process-wide memo of already-loaded checkpoints: maps a model name to its
# (model, tokenizer) tuple so later requests skip from_pretrained entirely.
model_cache = dict()
def get_model_and_tokenizer(model_name):
    """
    Return the (model, tokenizer) pair for ``model_name``, loading it on first use.

    A pair that is already present in the module-level ``model_cache`` is
    returned directly. Otherwise the M2M100 model (placed with
    ``device_map="auto"``) and its tokenizer are loaded via ``from_pretrained``,
    stored in the cache, and returned.

    Args:
        model_name (str): Name of the M2M100 checkpoint to load.
    Returns:
        tuple: ``(model, tokenizer)``.
    """
    cached = model_cache.get(model_name)
    if cached is not None:
        print(f"Using cached model: {model_name}")
        return cached

    print(f"Loading model: {model_name}. This may take a moment...")
    loaded_model = M2M100ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
    loaded_tokenizer = M2M100Tokenizer.from_pretrained(model_name)
    model_cache[model_name] = (loaded_model, loaded_tokenizer)
    print("Model loaded successfully.")
    return loaded_model, loaded_tokenizer
# Language codes supported by the M2M100 model, in alphabetical order.
# Kept as one whitespace-separated table for easy scanning/editing;
# .split() turns it back into a plain list of code strings.
supported_languages = (
    "af am ar ast az ba be bg bn br bs ca ceb "
    "cs cy da de el en es et fa ff fi fr fy "
    "ga gd gl gu ha he hi hr ht hu hy id ig "
    "ilo is it ja jv ka kk km kn ko lb lg ln "
    "lo lt lv mg mk ml mn mr ms my ne nl no "
    "ns oc or pa pl ps pt ro ru sd si sk sl "
    "so sq sr ss su sv sw ta th tl tn tr uk "
    "ur uz vi wo xh yi yo zh zu"
).split()
def _resolve_upload_path(json_file):
    """Return the filesystem path of a ``gr.File`` upload value.

    Gradio may deliver the upload either as a plain path (string / path-like)
    or as a tempfile-like object exposing the path via ``.name``; both forms
    are accepted here.

    Args:
        json_file: The upload value passed in by Gradio.
    Returns:
        str: Path to the uploaded file on disk.
    """
    if isinstance(json_file, (str, os.PathLike)):
        return os.fspath(json_file)
    return json_file.name


@spaces.GPU
def translate_minecraft_json(json_file, target_lang, model_name, progress=gr.Progress()):
    """
    Translate the string values of an uploaded Minecraft language JSON file.

    The file is read as a JSON object (e.g. ``en_us.json``). Each value that
    is a non-blank string is translated from English into ``target_lang``
    with the selected M2M100 checkpoint. Keys are kept as-is, and values that
    are not strings (or are blank) are copied through unchanged.

    Args:
        json_file: The uploaded ``en_us.json`` from ``gr.File``, either a path
            string or a tempfile-like object with a ``.name`` attribute.
        target_lang (str): Target language code (e.g. "fr" for French).
        model_name (str): Name of the M2M100 checkpoint to use.
        progress (gradio.Progress): Gradio progress tracker.
    Returns:
        str | None: Path of the translated JSON file, or None when no file
        was uploaded.
    Raises:
        gradio.Error: If loading the model, reading/parsing the file, or
            translating fails. The original exception is chained as the cause.
    """
    if json_file is None:
        return None
    try:
        # Get the selected model and tokenizer (cached after the first load).
        model, tokenizer = get_model_and_tokenizer(model_name)
        # Read the uploaded JSON file.
        with open(_resolve_upload_path(json_file), 'r', encoding='utf-8') as f:
            data = json.load(f)
        if not isinstance(data, dict):
            # data.items() below needs a JSON object; give a readable message
            # instead of an AttributeError for a list or scalar payload.
            raise ValueError(
                f"expected a JSON object of translation keys, got {type(data).__name__}"
            )
        # The tokenizer comes from the shared model cache. The source is always
        # English, so setting it on every call is harmless.
        tokenizer.src_lang = "en"
        # Resolve the target-language token once instead of per item. An
        # unknown code fails here, before any generation work is done.
        target_lang_id = tokenizer.get_lang_id(target_lang)
        translated_data = {}
        total_items = len(data)
        progress(0, desc=f"Starting translation to '{target_lang}' with {model_name}...")
        for done, (key, value) in enumerate(data.items(), start=1):
            if isinstance(value, str) and value.strip():  # skip empty/blank strings
                # NOTE(review): the whole value is sent to the model verbatim, so
                # format placeholders (e.g. %s) may come back altered; spot-check.
                encoded_text = tokenizer(value, return_tensors="pt").to(model.device)
                with torch.no_grad():
                    generated_tokens = model.generate(
                        **encoded_text,
                        forced_bos_token_id=target_lang_id,
                    )
                translated_data[key] = tokenizer.batch_decode(
                    generated_tokens, skip_special_tokens=True
                )[0]
            else:
                # Non-string or blank values are kept as-is.
                translated_data[key] = value
            progress(done / total_items, desc=f"Translated {done}/{total_items} items...")
        # NOTE(review): this relative path is written to the working directory
        # under a fixed name, so concurrent requests for the same language and
        # model overwrite each other's output.
        output_filename = f"{target_lang}_{model_name.split('/')[-1]}.json"
        with open(output_filename, 'w', encoding='utf-8') as f:
            json.dump(translated_data, f, ensure_ascii=False, indent=4)
        return output_filename
    except Exception as e:
        print(f"An error occurred: {e}")
        # Surface the failure in the UI while preserving the traceback chain.
        raise gr.Error(f"An error occurred during translation: {e}") from e
# Create the Gradio interface: model picker, file upload + target language,
# a Translate button, and a download slot for the translated file.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Minecraft Language File Translator
        Upload your `en_us.json` file, select the model and target language, then click "Translate".
        """
    )
    with gr.Row():
        model_selector = gr.Radio(
            choices=["facebook/m2m100_418M", "facebook/m2m100_1.2B"],
            label="Select Model",
            value="facebook/m2m100_418M",
        )
    with gr.Row():
        # The handler parses the upload with json.load, so only offer .json files.
        json_upload = gr.File(label="Upload en_us.json", file_types=[".json"])
        language_dropdown = gr.Dropdown(
            choices=supported_languages, label="Select Target Language", value="fr"
        )
    translate_button = gr.Button("Translate")
    translated_file = gr.File(label="Download Translated File")
    # Inputs are passed positionally as (json_file, target_lang, model_name).
    translate_button.click(
        translate_minecraft_json,
        inputs=[json_upload, language_dropdown, model_selector],
        outputs=translated_file,
    )

if __name__ == "__main__":
    # Launch the Gradio app
    demo.launch()