File size: 3,881 Bytes
f465936
34a146f
 
f465936
42d440d
 
34a146f
 
fd4909a
42d440d
f465936
34a146f
42d440d
34a146f
 
 
 
 
 
42d440d
34a146f
 
 
42d440d
34a146f
 
42d440d
34a146f
 
 
 
 
 
 
 
 
 
2ae028d
34a146f
 
 
2ae028d
34a146f
 
 
2ae028d
42d440d
 
34a146f
42d440d
 
34a146f
 
42d440d
 
 
 
34a146f
 
42d440d
34a146f
 
 
 
 
 
 
 
 
 
 
 
 
 
fd4909a
34a146f
 
 
 
 
42d440d
34a146f
 
42d440d
34a146f
 
 
 
 
 
 
 
 
 
 
 
f465936
34a146f
42d440d
34a146f
f465936
42d440d
fd4909a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# --- UI Configuration ---
# Page-level settings for the app: the browser-tab title and icon, plus a
# centered (rather than wide) content column.
st.set_page_config(
    layout="centered",
    page_icon="🌏",
    page_title="Indic Language Translator",
)

# --- Model and Tokenizer Caching ---
@st.cache_resource
def load_model_and_tokenizer(model_name="ai4bharat/indictrans2-indic-en-dist-200M"):
    """
    Load the IndicTrans2 Indic->English tokenizer and model.

    ``st.cache_resource`` caches the returned pair once per server process, so
    it is shared by every session/user of the app rather than reloaded per
    session.

    Args:
        model_name: Hugging Face Hub repository id to load. The default is the
            distilled 200M-parameter Indic->English IndicTrans2 checkpoint.

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` if loading
        raised (an error message is shown in the page in that case).

    Note:
        The ``(None, None)`` failure result is cached like any other return
        value; call ``load_model_and_tokenizer.clear()`` to allow a retry.
    """
    try:
        # trust_remote_code=True executes Python code shipped in the model
        # repository; only point model_name at repositories you trust.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
    except Exception as e:
        # UI boundary: report the failure in the page instead of crashing.
        st.error(f"Failed to load model. Please check your internet connection or the model name. Error: {e}")
        return None, None
    return tokenizer, model

# --- Translation Function ---
# IndicTrans2 tags languages with FLORES-200 style codes (e.g. "hin_Deva",
# "eng_Latn"). Older callers of translate_text passed "__xx__" codes, so those
# are mapped here; any code not in this table is passed through unchanged.
# NOTE(review): "__mni__" -> Bengali script and "__sd__" -> Arabic script are
# script choices made for the ambiguous legacy codes; confirm they fit users.
_LEGACY_TO_FLORES_TAG = {
    "__as__": "asm_Beng", "__bn__": "ben_Beng", "__brx__": "brx_Deva",
    "__doi__": "doi_Deva", "__en__": "eng_Latn", "__gom__": "gom_Deva",
    "__gu__": "guj_Gujr", "__hi__": "hin_Deva", "__kn__": "kan_Knda",
    "__ksa__": "kas_Arab", "__ksd__": "kas_Deva", "__mai__": "mai_Deva",
    "__ml__": "mal_Mlym", "__mni__": "mni_Beng", "__mr__": "mar_Deva",
    "__ne__": "npi_Deva", "__or__": "ory_Orya", "__pa__": "pan_Guru",
    "__sa__": "san_Deva", "__sat__": "sat_Olck", "__sd__": "snd_Arab",
    "__ta__": "tam_Taml", "__te__": "tel_Telu", "__ur__": "urd_Arab",
}


def translate_text(tokenizer, model, text, src_lang_code, tgt_lang_code="__en__",
                   *, num_beams=5, max_length=256):
    """
    Translate ``text`` from ``src_lang_code`` into ``tgt_lang_code``.

    Args:
        tokenizer: Tokenizer returned by ``load_model_and_tokenizer``.
        model: Seq2seq model returned by ``load_model_and_tokenizer``.
        text: Text to translate; surrounding whitespace is ignored.
        src_lang_code: Source language tag, either a FLORES-200 code such as
            ``"hin_Deva"`` or a legacy ``"__hi__"``-style code.
        tgt_lang_code: Target language tag, same formats as ``src_lang_code``.
        num_beams: Beam width used for generation.
        max_length: Maximum length (in tokens) of the generated translation.
            256 mirrors the model card example. NOTE(review): confirm against
            the model config's maximum target positions.

    Returns:
        The decoded translation, or ``""`` when ``text`` is empty/blank.
    """
    text = text.strip()
    if not text:
        return ""

    src_tag = _LEGACY_TO_FLORES_TAG.get(src_lang_code, src_lang_code)
    tgt_tag = _LEGACY_TO_FLORES_TAG.get(tgt_lang_code, tgt_lang_code)

    # The IndicTrans2 tokenizer expects "<src_tag> <tgt_tag> <sentence>"; the
    # tags must come first, not wrapped around the sentence.
    # NOTE(review): the model card also runs IndicTransToolkit's IndicProcessor
    # (normalization / postprocessing); that step is not applied here.
    input_text = f"{src_tag} {tgt_tag} {text}"

    # truncation=True keeps over-long inputs within the tokenizer's limit
    # instead of failing inside the model.
    inputs = tokenizer(input_text, truncation=True, return_tensors="pt")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            num_beams=num_beams,
            num_return_sequences=1,
            max_length=max_length,
        )

    # Only one sequence is requested, so the first row is the translation.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# --- Main Application ---
def _render_translator(tokenizer, model):
    """Render the source-language picker, input box and Translate button."""
    # Source languages only: the loaded model translates Indic -> English, so
    # English is not offered as a source. Values are the FLORES-200 tags used
    # by IndicTrans2. NOTE(review): "Manipuri" uses Bengali script and "Sindhi"
    # uses Arabic script; the model card also lists Meitei/Devanagari variants.
    languages = {
        "Assamese": "asm_Beng", "Bengali": "ben_Beng", "Bodo": "brx_Deva",
        "Dogri": "doi_Deva", "Goan Konkani": "gom_Deva",
        "Gujarati": "guj_Gujr", "Hindi": "hin_Deva", "Kannada": "kan_Knda",
        "Kashmiri (Arabic)": "kas_Arab", "Kashmiri (Devanagari)": "kas_Deva",
        "Maithili": "mai_Deva", "Malayalam": "mal_Mlym", "Manipuri": "mni_Beng",
        "Marathi": "mar_Deva", "Nepali": "npi_Deva", "Odia": "ory_Orya",
        "Punjabi": "pan_Guru", "Sanskrit": "san_Deva", "Santali": "sat_Olck",
        "Sindhi": "snd_Arab", "Tamil": "tam_Taml", "Telugu": "tel_Telu",
        "Urdu": "urd_Arab",
    }
    language_names = list(languages)

    source_language_name = st.selectbox(
        "Select Source Language:",
        options=language_names,
        # Look the default up by name so editing the dict cannot silently
        # change which language is preselected.
        index=language_names.index("Hindi"),
    )

    input_phrase = st.text_area("Enter phrase to translate:", height=100)

    if not st.button("Translate", type="primary"):
        return

    # Whitespace-only input is treated the same as empty input.
    if not input_phrase.strip():
        st.warning("Please enter a phrase to translate.")
        return

    with st.spinner("Translating..."):
        try:
            translated_text = translate_text(
                tokenizer,
                model,
                input_phrase,
                languages[source_language_name],
                tgt_lang_code="eng_Latn",
            )
        except Exception as exc:
            # UI boundary: show the failure instead of a raw traceback page.
            st.error(f"Translation failed: {exc}")
            return

    st.markdown("### Translation Result:")
    st.success(translated_text)


def main():
    """Build the page: header, translator UI (when the model loaded), footer."""
    st.title("Indic to English Translator 🌏")
    st.markdown(
        """
        Translate phrases from various Indian languages into English using the 
        `IndicTrans2` model from AI4Bharat.
        """
    )
    st.markdown("---")

    tokenizer, model = load_model_and_tokenizer()

    if tokenizer is None or model is None:
        # load_model_and_tokenizer returns (None, None) on failure, and
        # st.cache_resource would keep serving that cached failure on every
        # rerun; clearing the cache lets the next rerun try loading again.
        load_model_and_tokenizer.clear()
        st.error("Application could not be loaded. Please try again later.")
    else:
        _render_translator(tokenizer, model)

    # --- Footer ---
    st.markdown("---")
    st.markdown("App built by Gemini using a model from [AI4Bharat](https://ai4bharat.iitm.ac.in/).")


if __name__ == "__main__":
    main()