# Author: Syed-Adnan
# Commit 34a146f: "Update app.py" (verified)
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# --- UI Configuration ---
# Browser-tab title and icon, plus a narrow centered content column.
_PAGE_SETTINGS = dict(
    page_title="Indic Language Translator",
    page_icon="🌏",
    layout="centered",
)
st.set_page_config(**_PAGE_SETTINGS)
# --- Model and Tokenizer Caching ---
DEFAULT_MODEL_NAME = "ai4bharat/IndicTrans2-indic-en-dist-200M"


@st.cache_resource
def _load_model_and_tokenizer_cached(model_name):
    """Load the tokenizer and model for ``model_name`` and cache the pair.

    Exceptions are deliberately NOT caught here. ``st.cache_resource`` stores
    nothing when the wrapped call raises, so a failed load (for example a
    transient network error) is retried on the next rerun. Catching the error
    here would cache the failure.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
    return tokenizer, model


def load_model_and_tokenizer(model_name=DEFAULT_MODEL_NAME):
    """Return the cached ``(tokenizer, model)`` pair for the translation model.

    The pair is cached with ``st.cache_resource``. That cache is shared by all
    sessions served by this Streamlit process, so the model is loaded once per
    process rather than once per session.

    Args:
        model_name: Hugging Face model id to load. Defaults to
            ``DEFAULT_MODEL_NAME``.

    Returns:
        ``(tokenizer, model)`` on success. If loading fails, an error message is
        shown in the UI and ``(None, None)`` is returned; the next call tries
        again.
    """
    try:
        return _load_model_and_tokenizer_cached(model_name)
    except Exception as e:  # UI boundary: report any load failure to the user
        st.error(f"Failed to load model. Please check your internet connection or the model name. Error: {e}")
        return None, None


# Keep the cache-clearing hook the decorated function used to expose.
load_model_and_tokenizer.clear = _load_model_and_tokenizer_cached.clear
# --- Translation Function ---
# Per the IndicTrans2 model card, languages are identified by FLORES-200 style
# tags ("hin_Deva", "eng_Latn", ...). Callers may also pass short "__xx__"
# codes; those are mapped to tags here. Any other value is passed through
# unchanged, so callers can supply tags directly.
# NOTE(review): Manipuri and Sindhi each have two script variants in
# IndicTrans2 (mni_Beng/mni_Mtei, snd_Arab/snd_Deva). The choices below are
# assumptions; confirm which script the UI intends.
_LEGACY_LANG_TAGS = {
    "__as__": "asm_Beng", "__bn__": "ben_Beng", "__brx__": "brx_Deva",
    "__doi__": "doi_Deva", "__en__": "eng_Latn", "__gom__": "gom_Deva",
    "__gu__": "guj_Gujr", "__hi__": "hin_Deva", "__kn__": "kan_Knda",
    "__ksa__": "kas_Arab", "__ksd__": "kas_Deva", "__mai__": "mai_Deva",
    "__ml__": "mal_Mlym", "__mni__": "mni_Beng", "__mr__": "mar_Deva",
    "__ne__": "npi_Deva", "__or__": "ory_Orya", "__pa__": "pan_Guru",
    "__sa__": "san_Deva", "__sat__": "sat_Olck", "__sd__": "snd_Arab",
    "__ta__": "tam_Taml", "__te__": "tel_Telu", "__ur__": "urd_Arab",
}


def _to_model_lang_tag(code):
    """Map a short ``__xx__`` code to its IndicTrans2 tag; return other codes unchanged."""
    return _LEGACY_LANG_TAGS.get(code, code)


def translate_text(tokenizer, model, text, src_lang_code, tgt_lang_code="__en__"):
    """Translate ``text`` from ``src_lang_code`` into ``tgt_lang_code``.

    The model input is built as ``"<src_tag> <tgt_tag> <sentence>"``, the
    format the IndicTrans2 tokenizer parses. The generated ids are decoded
    inside the tokenizer's target-vocabulary context when the tokenizer
    provides one.

    Args:
        tokenizer: Tokenizer returned by ``AutoTokenizer.from_pretrained``.
        model: Seq2seq model returned by ``AutoModelForSeq2SeqLM.from_pretrained``.
        text: Sentence to translate. Leading and trailing whitespace is ignored.
        src_lang_code: Source language, given as a short ``"__xx__"`` code or
            an IndicTrans2 tag such as ``"hin_Deva"``.
        tgt_lang_code: Target language in either format. Defaults to English.

    Returns:
        The decoded translation with special tokens removed.
    """
    import contextlib  # local import: only needed for the decode context below

    src_tag = _to_model_lang_tag(src_lang_code)
    tgt_tag = _to_model_lang_tag(tgt_lang_code)
    input_text = f"{src_tag} {tgt_tag} {text.strip()}"

    inputs = tokenizer(input_text, return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd bookkeeping needed
        outputs = model.generate(**inputs, num_beams=5, num_return_sequences=1, max_length=1024)

    # Generated ids belong to the target vocabulary. Switch the tokenizer to it
    # when the tokenizer supports that; otherwise decode as-is.
    as_target = getattr(tokenizer, "as_target_tokenizer", None)
    with as_target() if callable(as_target) else contextlib.nullcontext():
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
# --- Main Application ---
def main():
    """Render the translator page.

    Shows the title and intro, then loads the cached tokenizer and model. If
    loading succeeded, it shows a source-language picker, a text box, and a
    Translate button. Otherwise it shows an error. The footer is rendered in
    every case.
    """
    st.title("Indic to English Translator 🌏")
    st.markdown(
        """
Translate phrases from various Indian languages into English using the
`IndicTrans2` model from AI4Bharat.
"""
    )
    st.markdown("---")

    # Load the model and tokenizer. The loader signals failure with
    # (None, None), so test for exactly that instead of relying on the truth
    # value of library objects.
    tokenizer, model = load_model_and_tokenizer()
    if tokenizer is not None and model is not None:
        # Language selection dictionary (user-friendly name -> language code
        # passed to translate_text). Sourced from the model card. English is
        # intentionally not offered: this checkpoint translates Indic
        # languages *into* English.
        LANGUAGES = {
            "Assamese": "__as__", "Bengali": "__bn__", "Bodo": "__brx__",
            "Dogri": "__doi__", "Goan Konkani": "__gom__",
            "Gujarati": "__gu__", "Hindi": "__hi__", "Kannada": "__kn__",
            "Kashmiri (Arabic)": "__ksa__", "Kashmiri (Devanagari)": "__ksd__",
            "Maithili": "__mai__", "Malayalam": "__ml__", "Manipuri": "__mni__",
            "Marathi": "__mr__", "Nepali": "__ne__", "Odia": "__or__",
            "Punjabi": "__pa__", "Sanskrit": "__sa__", "Santali": "__sat__",
            "Sindhi": "__sd__", "Tamil": "__ta__", "Telugu": "__te__",
            "Urdu": "__ur__",
        }
        options = list(LANGUAGES)

        # UI Components
        source_language_name = st.selectbox(
            "Select Source Language:",
            options=options,
            index=options.index("Hindi"),  # default by name, not by position
        )
        input_phrase = st.text_area("Enter phrase to translate:", height=100)

        if st.button("Translate", type="primary"):
            phrase = input_phrase.strip()  # treat whitespace-only input as empty
            if phrase:
                with st.spinner("Translating..."):
                    src_lang_code = LANGUAGES[source_language_name]
                    try:
                        translated_text = translate_text(tokenizer, model, phrase, src_lang_code)
                    except Exception as e:  # UI boundary: show errors instead of a traceback
                        st.error(f"Translation failed. Error: {e}")
                    else:
                        st.markdown("### Translation Result:")
                        st.success(translated_text)
            else:
                st.warning("Please enter a phrase to translate.")
    else:
        st.error("Application could not be loaded. Please try again later.")

    # --- Footer ---
    st.markdown("---")
    st.markdown("App built by Gemini using a model from [AI4Bharat](https://ai4bharat.iitm.ac.in/).")
# Entry point: render the app when this file is executed as the main module.
if __name__ == "__main__":
    main()