# Page residue captured with this file (not part of the program): "Spaces: Sleeping"
import json
import os

import gradio as gr
import spaces
import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
# --- Model Cache ---
# A global dictionary mapping a model name to its loaded (model, tokenizer)
# pair, so that each model is loaded at most once per process.
model_cache = {}


def get_model_and_tokenizer(model_name):
    """
    Load the specified model and tokenizer, reusing the cached pair if present.

    Args:
        model_name (str): The name of the model to load.

    Returns:
        tuple: A tuple containing the loaded model and tokenizer.
    """
    # Fast path: this model was already loaded earlier in this process.
    if model_name in model_cache:
        print(f"Using cached model: {model_name}")
        return model_cache[model_name]

    print(f"Loading model: {model_name}. This may take a moment...")
    model = M2M100ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
    tokenizer = M2M100Tokenizer.from_pretrained(model_name)
    # Store the pair only after both loads succeeded, so a failed load never
    # leaves a partial entry in the cache.
    model_cache[model_name] = (model, tokenizer)
    print("Model loaded successfully.")
    return model, tokenizer
# A list of languages supported by the M2M100 model.
# These codes are offered as the choices of the target-language dropdown.
supported_languages = [
    "af", "am", "ar", "ast", "az", "ba", "be", "bg", "bn", "br", "bs", "ca", "ceb",
    "cs", "cy", "da", "de", "el", "en", "es", "et", "fa", "ff", "fi", "fr", "fy",
    "ga", "gd", "gl", "gu", "ha", "he", "hi", "hr", "ht", "hu", "hy", "id", "ig",
    "ilo", "is", "it", "ja", "jv", "ka", "kk", "km", "kn", "ko", "lb", "lg", "ln",
    "lo", "lt", "lv", "mg", "mk", "ml", "mn", "mr", "ms", "my", "ne", "nl", "no",
    "ns", "oc", "or", "pa", "pl", "ps", "pt", "ro", "ru", "sd", "si", "sk", "sl",
    "so", "sq", "sr", "ss", "su", "sv", "sw", "ta", "th", "tl", "tn", "tr", "uk",
    "ur", "uz", "vi", "wo", "xh", "yi", "yo", "zh", "zu",
]
def translate_minecraft_json(json_file, target_lang, model_name, progress=gr.Progress()):
    """
    Translate the string values of a JSON file to a specified language.

    Keys are kept unchanged. Values that are not strings, or that are empty or
    whitespace-only strings, are copied to the output as-is. The source
    language is always set to English ("en").

    Args:
        json_file (gradio.File | str): The uploaded en_us.json file, either as
            an object exposing its path through ``.name`` or as a path string.
        target_lang (str): The target language code (e.g., "fr" for French).
        model_name (str): The name of the model to use for translation.
        progress (gradio.Progress): Gradio progress tracker.

    Returns:
        str | None: The file path of the translated JSON file (written to the
        current working directory), or None when no file was uploaded.

    Raises:
        gr.Error: If reading the input, translating, or writing the output
            fails; the original exception is chained as the cause.
    """
    if json_file is None:
        return None

    try:
        # Accept both an upload object (path in ``.name``) and a bare path string.
        input_path = getattr(json_file, "name", json_file)

        # Read and validate the upload first, so a bad file is rejected before
        # the model is loaded.
        with open(input_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            raise ValueError("the uploaded JSON must be an object mapping keys to text")

        # Get the selected model and tokenizer
        model, tokenizer = get_model_and_tokenizer(model_name)

        # Set the source language for the tokenizer
        tokenizer.src_lang = "en"
        # The target language id is the same for every entry: resolve it once.
        target_lang_id = tokenizer.get_lang_id(target_lang)

        translated_data = {}
        total_items = len(data)

        # Use the gradio progress tracker
        progress(0, desc=f"Starting translation to '{target_lang}' with {model_name}...")

        for done, (key, value) in enumerate(data.items(), start=1):
            if isinstance(value, str) and value.strip():
                # Tokenize the text and move the tensors to the model's device
                encoded_text = tokenizer(value, return_tensors="pt").to(model.device)

                # Generate translation
                with torch.no_grad():
                    generated_tokens = model.generate(
                        **encoded_text,
                        forced_bos_token_id=target_lang_id,
                    )

                # Decode the translated text (one input, so take the first result)
                translated_data[key] = tokenizer.batch_decode(
                    generated_tokens, skip_special_tokens=True
                )[0]
            else:
                # If the value is not a string or is empty, keep it as is
                translated_data[key] = value

            # Update progress
            progress(done / total_items, desc=f"Translated {done}/{total_items} items...")

        # Create a new file for the translated data, named after the target
        # language and the last path component of the model name.
        output_filename = f"{target_lang}_{model_name.split('/')[-1]}.json"
        with open(output_filename, "w", encoding="utf-8") as f:
            json.dump(translated_data, f, ensure_ascii=False, indent=4)

        return output_filename
    except Exception as e:
        print(f"An error occurred: {e}")
        # Raise a Gradio error so the failure is shown in the UI; keep the
        # original exception attached as the cause.
        raise gr.Error(f"An error occurred during translation: {e}") from e
# Create the Gradio interface: a model selector, a file upload with a
# target-language dropdown, and a button that runs the translation and
# exposes the resulting file for download.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Minecraft Language File Translator
        Upload your `en_us.json` file, select the model and target language, then click "Translate".
        """
    )

    with gr.Row():
        model_selector = gr.Radio(
            choices=["facebook/m2m100_418M", "facebook/m2m100_1.2B"],
            label="Select Model",
            value="facebook/m2m100_418M",
        )

    with gr.Row():
        # Only JSON files are meaningful input for the translator.
        json_upload = gr.File(label="Upload en_us.json", file_types=[".json"])
        language_dropdown = gr.Dropdown(
            choices=supported_languages, label="Select Target Language", value="fr"
        )

    translate_button = gr.Button("Translate")
    translated_file = gr.File(label="Download Translated File")

    # The input order must match the parameter order of translate_minecraft_json
    # (json_file, target_lang, model_name).
    translate_button.click(
        translate_minecraft_json,
        inputs=[json_upload, language_dropdown, model_selector],
        outputs=translated_file,
    )

if __name__ == "__main__":
    # Launch the Gradio app
    demo.launch()