# Hugging Face Space: Minecraft Language File Translator (Gradio app).
import gradio as gr
import json
import os
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
import torch
import spaces
# --- Model Cache ---
# A global dictionary to cache loaded models and tokenizers to avoid reloading them.
model_cache = {}


def get_model_and_tokenizer(model_name):
    """
    Return a ``(model, tokenizer)`` pair for *model_name*, loading on first use.

    Results are memoized in the module-level ``model_cache`` so repeated
    requests for the same checkpoint skip the slow download/initialization.

    Args:
        model_name (str): The name of the model to load.

    Returns:
        tuple: A tuple containing the loaded model and tokenizer.
    """
    # Guard clause: serve from the cache when we already loaded this checkpoint.
    cached_pair = model_cache.get(model_name)
    if cached_pair is not None:
        print(f"Using cached model: {model_name}")
        return cached_pair

    print(f"Loading model: {model_name}. This may take a moment...")
    translator = M2M100ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
    text_tokenizer = M2M100Tokenizer.from_pretrained(model_name)
    model_cache[model_name] = (translator, text_tokenizer)
    print("Model loaded successfully.")
    return translator, text_tokenizer
# A list of languages supported by the M2M100 model.
# These are the language codes accepted as translation targets by the
# tokenizer (passed to tokenizer.get_lang_id in translate_minecraft_json).
supported_languages = [
    "af", "am", "ar", "ast", "az", "ba", "be", "bg", "bn", "br", "bs", "ca", "ceb",
    "cs", "cy", "da", "de", "el", "en", "es", "et", "fa", "ff", "fi", "fr", "fy",
    "ga", "gd", "gl", "gu", "ha", "he", "hi", "hr", "ht", "hu", "hy", "id", "ig",
    "ilo", "is", "it", "ja", "jv", "ka", "kk", "km", "kn", "ko", "lb", "lg", "ln",
    "lo", "lt", "lv", "mg", "mk", "ml", "mn", "mr", "ms", "my", "ne", "nl", "no",
    "ns", "oc", "or", "pa", "pl", "ps", "pt", "ro", "ru", "sd", "si", "sk", "sl",
    "so", "sq", "sr", "ss", "su", "sv", "sw", "ta", "th", "tl", "tn", "tr", "uk",
    "ur", "uz", "vi", "wo", "xh", "yi", "yo", "zh", "zu"
]
@spaces.GPU
def translate_minecraft_json(json_file, target_lang, model_name, progress=gr.Progress()):
    """
    Translates the string values of a Minecraft language JSON file.

    Keys are preserved; only non-empty string values are translated from
    English to *target_lang*. Non-string or blank values are copied through
    unchanged.

    Args:
        json_file (gradio.File): The uploaded en_us.json file.
        target_lang (str): The target language code (e.g., "fr" for French).
        model_name (str): The name of the model to use for translation.
        progress (gradio.Progress): Gradio progress tracker (injected by Gradio;
            the default instance is the documented Gradio pattern).

    Returns:
        str | None: The file path of the translated JSON file, or None when
        no file was uploaded.

    Raises:
        gr.Error: If the file is not valid JSON, its top level is not an
        object, or translation fails for any other reason.
    """
    if json_file is None:
        return None
    try:
        # Get the selected model and tokenizer
        model, tokenizer = get_model_and_tokenizer(model_name)

        # Read the uploaded JSON file
        with open(json_file.name, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Minecraft language files are flat {key: string} objects. Validate
        # explicitly so a list/scalar top level yields a clear message instead
        # of an opaque AttributeError from data.items() below.
        if not isinstance(data, dict):
            raise gr.Error("The uploaded file must contain a JSON object (key/value pairs).")

        # Set the source language for the tokenizer
        tokenizer.src_lang = "en"

        translated_data = {}
        total_items = len(data)
        # Use the gradio progress tracker
        progress(0, desc=f"Starting translation to '{target_lang}' with {model_name}...")
        for i, (key, value) in enumerate(data.items()):
            if isinstance(value, str) and value.strip():  # Also check if string is not empty
                translated_data[key] = _translate_text(model, tokenizer, value, target_lang)
            else:
                # If the value is not a string or is empty, keep it as is
                translated_data[key] = value
            # Update progress
            progress((i + 1) / total_items, desc=f"Translated {i + 1}/{total_items} items...")

        # Create a new file for the translated data (written to the working
        # directory; Gradio serves it back via the output File component).
        output_filename = f"{target_lang}_{model_name.split('/')[-1]}.json"
        with open(output_filename, 'w', encoding='utf-8') as f:
            json.dump(translated_data, f, ensure_ascii=False, indent=4)
        return output_filename
    except gr.Error:
        # Already a user-facing error with a clear message; don't re-wrap it.
        raise
    except Exception as e:
        print(f"An error occurred: {e}")
        # Raise a Gradio error to show it in the UI
        raise gr.Error(f"An error occurred during translation: {e}")


def _translate_text(model, tokenizer, text, target_lang):
    """Translate a single string with the given M2M100 model and tokenizer."""
    # Tokenize the text and move it to the model's device.
    encoded_text = tokenizer(text, return_tensors="pt").to(model.device)
    # Generate translation; force the first generated token to the target
    # language id, as required by M2M100.
    with torch.no_grad():
        generated_tokens = model.generate(
            **encoded_text,
            forced_bos_token_id=tokenizer.get_lang_id(target_lang)
        )
    # Decode the translated text
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
# Create the Gradio interface.
# Layout (top to bottom): model selector row, upload + target-language row,
# translate button, downloadable result file.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Minecraft Language File Translator
        Upload your `en_us.json` file, select the model and target language, then click "Translate".
        """
    )
    with gr.Row():
        # Choice between the two M2M100 checkpoint sizes; 418M is the default.
        model_selector = gr.Radio(
            choices=["facebook/m2m100_418M", "facebook/m2m100_1.2B"],
            label="Select Model",
            value="facebook/m2m100_418M"
        )
    with gr.Row():
        json_upload = gr.File(label="Upload en_us.json")
        # Target language codes come from the supported_languages list above.
        language_dropdown = gr.Dropdown(
            choices=supported_languages, label="Select Target Language", value="fr"
        )
    translate_button = gr.Button("Translate")
    translated_file = gr.File(label="Download Translated File")
    # Wire the button to the translation function: three inputs map to the
    # first three parameters of translate_minecraft_json; the returned path
    # populates the download component.
    translate_button.click(
        translate_minecraft_json,
        inputs=[json_upload, language_dropdown, model_selector],
        outputs=translated_file
    )

if __name__ == "__main__":
    # Launch the Gradio app
    demo.launch()
|