import streamlit as st
import torch
from transformers import MarianMTModel, MarianTokenizer

# Language code -> (native name, English name)
LANGUAGES = {
    "en": ("English", "English"),
    "fr": ("Français", "French"),
    "es": ("Español", "Spanish"),
    "de": ("Deutsch", "German"),
    "hi": ("हिन्दी", "Hindi"),
    "zh": ("中文", "Chinese"),
    "ar": ("العربية", "Arabic"),
    "ru": ("Русский", "Russian"),
    "ja": ("日本語", "Japanese"),
}


@st.cache_resource
def _load_default_model():
    # Fallback model used when no direct pair model can be downloaded.
    model_name = "Helsinki-NLP/opus-mt-en-hi"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model


@st.cache_resource
def load_model(source_lang, target_lang):
    # Helsinki-NLP publishes one MarianMT checkpoint per language pair,
    # named opus-mt-<src>-<tgt>; fall back to en-hi if the pair is missing.
    try:
        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        return tokenizer, model
    except Exception as e:
        st.warning(f"No direct model for {source_lang} to {target_lang} ({e}). Using en-hi fallback.")
        return _load_default_model()


@st.cache_data(ttl=3600)
def translate_cached(text, source_lang, target_lang):
    tokenizer, model = load_model(source_lang, target_lang)
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
    with torch.no_grad():  # inference only; no gradients needed
        translated = model.generate(**inputs, max_length=500, num_beams=2, early_stopping=True)
    translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
    # Heuristic guard: if the output is empty or a single token, treat the
    # translation as degenerate and return the input unchanged.
    return translated_text if translated_text.strip() and len(translated_text.split()) >= 2 else text


def translate(text, source_lang, target_lang):
    if not text:
        return "No text provided."
    if source_lang == target_lang:
        return text  # nothing to translate; avoids mistranslating via the en-hi fallback
    try:
        return translate_cached(text, source_lang, target_lang)
    except Exception as e:
        st.error(f"Translation error: {str(e)}. Using input as fallback.")
        return text
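
# --- Usage sketch (not part of the original snippet): a minimal Streamlit UI
# wiring for the functions above. Widget labels, layout, and default indices
# are assumptions; only translate() and LANGUAGES come from the code above.
st.title("Text Translator")
col1, col2 = st.columns(2)
with col1:
    source_lang = st.selectbox(
        "Source language",
        options=list(LANGUAGES.keys()),
        format_func=lambda code: LANGUAGES[code][0],  # show the native name
        index=0,  # "en"
    )
with col2:
    target_lang = st.selectbox(
        "Target language",
        options=list(LANGUAGES.keys()),
        format_func=lambda code: LANGUAGES[code][0],
        index=4,  # "hi", matching the en-hi fallback model
    )
text = st.text_area("Text to translate")
if st.button("Translate"):
    st.write(translate(text, source_lang, target_lang))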