# Scraped page header (not part of the program): "Spaces: Runtime error"
import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# --- UI Configuration ---
# Sets the browser-tab title, the page icon and the page layout.
# NOTE(review): the icon glyph may be a mis-encoded emoji — confirm against the
# original file before changing it.
st.set_page_config(page_title="Indic Language Translator", page_icon="π", layout="centered")
# --- Model and Tokenizer Caching ---
@st.cache_resource
def _load_pretrained(model_name):
    """Load the tokenizer and seq2seq model for ``model_name``.

    Decorated with ``st.cache_resource`` so the pair is built once and reused
    across Streamlit reruns. Errors are deliberately allowed to propagate:
    a call that raises is not cached, so a later rerun can retry the load.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
    return tokenizer, model


def load_model_and_tokenizer():
    """Return the cached ``(tokenizer, model)`` pair for the translation model.

    Returns:
        ``(tokenizer, model)`` on success. If loading fails, an error is shown
        in the Streamlit UI and ``(None, None)`` is returned instead of
        raising, so the caller can render a fallback page.
    """
    # NOTE(review): the hub id is kept exactly as it was; confirm its casing
    # against the published repository name.
    model_name = "ai4bharat/IndicTrans2-indic-en-dist-200M"
    try:
        return _load_pretrained(model_name)
    except Exception as e:  # UI boundary: report the failure instead of crashing.
        st.error(f"Failed to load model. Please check your internet connection or the model name. Error: {e}")
        return None, None
# --- Translation Function ---
# Legacy ``__xx__`` codes used by this app -> FLORES-style tags
# (``<language>_<script>``) listed on the IndicTrans2 model card.
_FLORES_TAGS = {
    "__as__": "asm_Beng",
    "__bn__": "ben_Beng",
    "__brx__": "brx_Deva",
    "__doi__": "doi_Deva",
    "__en__": "eng_Latn",
    "__gom__": "gom_Deva",
    "__gu__": "guj_Gujr",
    "__hi__": "hin_Deva",
    "__kn__": "kan_Knda",
    "__ksa__": "kas_Arab",
    "__ksd__": "kas_Deva",
    "__mai__": "mai_Deva",
    "__ml__": "mal_Mlym",
    # NOTE(review): Manipuri also has a Meitei-script tag (mni_Mtei); the
    # Bengali-script tag is assumed here — confirm the intended script.
    "__mni__": "mni_Beng",
    "__mr__": "mar_Deva",
    "__ne__": "npi_Deva",
    "__or__": "ory_Orya",
    "__pa__": "pan_Guru",
    "__sa__": "san_Deva",
    "__sat__": "sat_Olck",
    # NOTE(review): Sindhi also has a Devanagari tag (snd_Deva); the
    # Perso-Arabic tag is assumed here — confirm the intended script.
    "__sd__": "snd_Arab",
    "__ta__": "tam_Taml",
    "__te__": "tel_Telu",
    "__ur__": "urd_Arab",
}


def translate_text(tokenizer, model, text, src_lang_code, tgt_lang_code="__en__"):
    """Translate ``text`` from the source language into the target language.

    Args:
        tokenizer: Callable tokenizer that also provides ``decode``.
        model: Seq2seq model providing ``generate``.
        text: Phrase to translate. Runs of whitespace (including newlines)
            are collapsed to single spaces before tokenising.
        src_lang_code: Source language, either a legacy ``__xx__`` code or a
            FLORES-style tag; values not in the mapping are passed through
            unchanged.
        tgt_lang_code: Target language, same conventions as ``src_lang_code``.

    Returns:
        The decoded translation with surrounding whitespace removed, or an
        empty string when ``text`` is empty or whitespace-only (the model is
        not invoked in that case).
    """
    phrase = " ".join((text or "").split())
    if not phrase:
        return ""

    src_tag = _FLORES_TAGS.get(src_lang_code, src_lang_code)
    tgt_tag = _FLORES_TAGS.get(tgt_lang_code, tgt_lang_code)

    # The model card feeds the tokenizer "<src_tag> <tgt_tag> <sentence>";
    # both tags come first, followed by the sentence.
    # NOTE(review): the model card pipeline also normalises text with
    # IndicTransToolkit's IndicProcessor before tokenising. That package is
    # not a dependency of this file, so no such normalisation happens here —
    # confirm whether output quality is acceptable without it.
    model_input = f"{src_tag} {tgt_tag} {phrase}"
    inputs = tokenizer(model_input, truncation=True, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model.generate(**inputs, num_beams=5, num_return_sequences=1, max_length=1024)

    # A single sequence is requested, so only the first output is decoded.
    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
# --- Main Application ---
# Language selection dictionary (user-friendly name -> code passed to
# translate_text). Sourced from the model card.
_LANGUAGES = {
    "Assamese": "__as__", "Bengali": "__bn__", "Bodo": "__brx__",
    "Dogri": "__doi__", "English": "__en__", "Goan Konkani": "__gom__",
    "Gujarati": "__gu__", "Hindi": "__hi__", "Kannada": "__kn__",
    "Kashmiri (Arabic)": "__ksa__", "Kashmiri (Devanagari)": "__ksd__",
    "Maithili": "__mai__", "Malayalam": "__ml__", "Manipuri": "__mni__",
    "Marathi": "__mr__", "Nepali": "__ne__", "Odia": "__or__",
    "Punjabi": "__pa__", "Sanskrit": "__sa__", "Santali": "__sat__",
    "Sindhi": "__sd__", "Tamil": "__ta__", "Telugu": "__te__",
    "Urdu": "__ur__",
}
_DEFAULT_SOURCE_LANGUAGE = "Hindi"


def _render_translator(tokenizer, model):
    """Render the language picker, input box and translate button."""
    language_names = list(_LANGUAGES)
    source_language_name = st.selectbox(
        "Select Source Language:",
        options=language_names,
        # Look the default up by name so reordering the table cannot
        # silently change which language is preselected.
        index=language_names.index(_DEFAULT_SOURCE_LANGUAGE),
    )
    input_phrase = st.text_area("Enter phrase to translate:", height=100)

    if not st.button("Translate", type="primary"):
        return

    # Treat whitespace-only input the same as empty input.
    phrase = (input_phrase or "").strip()
    if not phrase:
        st.warning("Please enter a phrase to translate.")
        return

    src_lang_code = _LANGUAGES[source_language_name]
    try:
        with st.spinner("Translating..."):
            translated_text = translate_text(tokenizer, model, phrase, src_lang_code)
    except Exception as exc:  # UI boundary: show the failure in the page.
        st.error(f"Translation failed: {exc}")
        return

    st.markdown("### Translation Result:")
    st.success(translated_text)


def main():
    """Render the app: header, translator (or a load-failure notice), footer."""
    st.title("Indic to English Translator π")
    st.markdown(
        """
        Translate phrases from various Indian languages into English using the
        `IndicTrans2` model from AI4Bharat.
        """
    )
    st.markdown("---")

    # Load the model and tokenizer; both are None when loading failed.
    tokenizer, model = load_model_and_tokenizer()
    if tokenizer is None or model is None:
        st.error("Application could not be loaded. Please try again later.")
    else:
        _render_translator(tokenizer, model)

    # --- Footer --- (rendered whether or not the model loaded)
    st.markdown("---")
    st.markdown("App built by Gemini using a model from [AI4Bharat](https://ai4bharat.iitm.ac.in/).")
# Script entry point: build the page only when this file is run directly.
if __name__ == "__main__":
    main()