import gradio as gr
import json
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
import torch
import spaces
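
# NOTE: `import spaces` and the @spaces.GPU decorator below assume the app runs on
# Hugging Face Spaces (ZeroGPU); for a plain local run they can be removed.
# `device_map="auto"` additionally requires the `accelerate` package alongside torch.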

# --- Model Cache ---
# A global dictionary to cache loaded models and tokenizers to avoid reloading them.
model_cache = {}

def get_model_and_tokenizer(model_name):
    """
    Loads and caches the specified model and tokenizer.
    
    Args:
        model_name (str): The name of the model to load.
        
    Returns:
        tuple: A tuple containing the loaded model and tokenizer.
    """
    if model_name in model_cache:
        print(f"Using cached model: {model_name}")
        return model_cache[model_name]
    else:
        print(f"Loading model: {model_name}. This may take a moment...")
        model = M2M100ForConditionalGeneration.from_pretrained(model_name, device_map="auto")
        tokenizer = M2M100Tokenizer.from_pretrained(model_name)
        model_cache[model_name] = (model, tokenizer)
        print("Model loaded successfully.")
        return model, tokenizer
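
# Minimal usage sketch (illustrative only): the second call should hit the cache
# and print "Using cached model" instead of loading the weights again.
#   model, tokenizer = get_model_and_tokenizer("facebook/m2m100_418M")
#   model, tokenizer = get_model_and_tokenizer("facebook/m2m100_418M")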

# A list of languages supported by the M2M100 model
supported_languages = [
    "af", "am", "ar", "ast", "az", "ba", "be", "bg", "bn", "br", "bs", "ca", "ceb",
    "cs", "cy", "da", "de", "el", "en", "es", "et", "fa", "ff", "fi", "fr", "fy",
    "ga", "gd", "gl", "gu", "ha", "he", "hi", "hr", "ht", "hu", "hy", "id", "ig",
    "ilo", "is", "it", "ja", "jv", "ka", "kk", "km", "kn", "ko", "lb", "lg", "ln",
    "lo", "lt", "lv", "mg", "mk", "ml", "mn", "mr", "ms", "my", "ne", "nl", "no",
    "ns", "oc", "or", "pa", "pl", "ps", "pt", "ro", "ru", "sd", "si", "sk", "sl",
    "so", "sq", "sr", "ss", "su", "sv", "sw", "ta", "th", "tl", "tn", "tr", "uk",
    "ur", "uz", "vi", "wo", "xh", "yi", "yo", "zh", "zu"
]
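# These codes are passed to M2M100Tokenizer.get_lang_id() in the translation
# function below; the dropdown exposes them directly to the user.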

@spaces.GPU
def translate_minecraft_json(json_file, target_lang, model_name, progress=gr.Progress()):
    """
    Translates the values of a JSON file to a specified language.

    Args:
        json_file (str | gradio.File): The uploaded en_us.json file (a path or a
            file object, depending on the Gradio version).
        target_lang (str): The target language code (e.g., "fr" for French).
        model_name (str): The name of the model to use for translation.
        progress (gradio.Progress): Gradio progress tracker.

    Returns:
        str: The file path of the translated JSON file.
    """
    if json_file is None:
        return None

    try:
        # Get the selected model and tokenizer
        model, tokenizer = get_model_and_tokenizer(model_name)

        # Read the uploaded JSON file. Depending on the Gradio version, gr.File may
        # pass either a file path string or a tempfile-like object with a .name attribute.
        file_path = json_file if isinstance(json_file, str) else json_file.name
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Set the source language for the tokenizer
        tokenizer.src_lang = "en"

        translated_data = {}
        total_items = len(data)
        
        # Use the gradio progress tracker
        progress(0, desc=f"Starting translation to '{target_lang}' with {model_name}...")
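        # Each value is translated with its own generate() call, which keeps memory
        # use low but is slow for large files; batching several strings per call
        # (e.g. tokenizer(batch, return_tensors="pt", padding=True)) would be a
        # possible speed-up at the cost of more GPU memory.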

        for i, (key, value) in enumerate(data.items()):
            if isinstance(value, str) and value.strip(): # Also check if string is not empty
                # Tokenize the text
                encoded_text = tokenizer(value, return_tensors="pt").to(model.device)

                # Generate translation
                with torch.no_grad():
                    generated_tokens = model.generate(
                        **encoded_text,
                        forced_bos_token_id=tokenizer.get_lang_id(target_lang)
                    )

                # Decode the translated text
                translated_text = tokenizer.batch_decode(
                    generated_tokens, skip_special_tokens=True
                )[0]
                translated_data[key] = translated_text
            else:
                # If the value is not a string or is empty, keep it as is
                translated_data[key] = value

            # Update progress
            progress((i + 1) / total_items, desc=f"Translated {i + 1}/{total_items} items...")


        # Create a new file for the translated data
        output_filename = f"{target_lang}_{model_name.split('/')[-1]}.json"
        with open(output_filename, 'w', encoding='utf-8') as f:
            json.dump(translated_data, f, ensure_ascii=False, indent=4)

        return output_filename

    except Exception as e:
        print(f"An error occurred: {e}")
        # Optionally, raise a Gradio error to show it in the UI
        raise gr.Error(f"An error occurred during translation: {e}")


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Minecraft Language File Translator
        Upload your `en_us.json` file, select the model and target language, then click "Translate".
        """
    )
    with gr.Row():
        model_selector = gr.Radio(
            choices=["facebook/m2m100_418M", "facebook/m2m100_1.2B"],
            label="Select Model",
            value="facebook/m2m100_418M"
        )
    with gr.Row():
        json_upload = gr.File(label="Upload en_us.json")
        language_dropdown = gr.Dropdown(
            choices=supported_languages, label="Select Target Language", value="fr"
        )
    translate_button = gr.Button("Translate")
    translated_file = gr.File(label="Download Translated File")

    translate_button.click(
        translate_minecraft_json,
        inputs=[json_upload, language_dropdown, model_selector],
        outputs=translated_file
    )

if __name__ == "__main__":
    # Launch the Gradio app
    demo.launch()
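
# Running `python app.py` starts the app; passing share=True to demo.launch() would
# additionally create a temporary public link (a standard Gradio option, not specific
# to this app).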