import streamlit as st
import torch
from transformers import MarianMTModel, MarianTokenizer

# Supported languages: code -> (native name, English name)
LANGUAGES = {
    "en": ("English", "English"),
    "fr": ("Français", "French"),
    "es": ("Español", "Spanish"),
    "de": ("Deutsch", "German"),
    "hi": ("हिन्दी", "Hindi"),
    "zh": ("中文", "Chinese"),
    "ar": ("العربية", "Arabic"),
    "ru": ("Русский", "Russian"),
    "ja": ("日本語", "Japanese"),
}


# Cache resource to load a specific translation model pair
@st.cache_resource
def _load_model_pair(source_lang, target_lang):
    try:
        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        return tokenizer, model
    except Exception:
        # Not every language pair has a direct OPUS-MT model
        return None, None


# Cache resource to load all possible model combinations.
# Note: this eagerly tries to download every pair on first run, which is slow
# and disk-hungry; pairs without a direct model are stored as (None, None).
@st.cache_resource
def _load_all_models():
    models = {}
    for src in LANGUAGES:
        for tgt in LANGUAGES:
            if src != tgt:
                models[(src, tgt)] = _load_model_pair(src, tgt)
    return models


all_models = _load_all_models()


def _get_pair(key, default_tokenizer, default_model):
    """Return the cached (tokenizer, model) for key, or the defaults if missing."""
    tokenizer, model = all_models.get(key, (None, None))
    if tokenizer is None or model is None:
        return default_tokenizer, default_model
    return tokenizer, model


# Defined outside load_model with explicit parameters so it can be reused
def combined_translate(text, source_lang, target_lang, default_tokenizer, default_model):
    """Translate source -> English -> target when no direct model exists."""
    with torch.no_grad():
        if source_lang != "en":
            tok, mod = _get_pair((source_lang, "en"), default_tokenizer, default_model)
            inputs = tok(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
            en_text = tok.decode(mod.generate(**inputs)[0], skip_special_tokens=True)
        else:
            en_text = text
        if target_lang != "en":
            tok, mod = _get_pair(("en", target_lang), default_tokenizer, default_model)
            inputs = tok(en_text, return_tensors="pt", padding=True, truncation=True, max_length=1000)
            return tok.decode(mod.generate(**inputs)[0], skip_special_tokens=True)
        return en_text


# Class to handle combined translation through the English pivot while
# exposing the same generate() interface as a MarianMTModel
class CombinedModel:
    def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.default_tokenizer = default_tokenizer
        self.default_model = default_model

    def generate(self, **kwargs):
        input_ids = kwargs.get("input_ids")
        if input_ids is None or input_ids.numel() == 0:
            return torch.tensor([])
        translations = [
            combined_translate(
                self.default_tokenizer.decode(ids, skip_special_tokens=True),
                self.source_lang,
                self.target_lang,
                self.default_tokenizer,
                self.default_model,
            )
            for ids in input_ids
        ]
        # Re-encode so callers can decode the result like a normal model output
        return self.default_tokenizer(
            translations, return_tensors="pt", padding=True, truncation=True
        ).input_ids


# Function to load the appropriate translation model with caching
@st.cache_resource
def load_model(source_lang, target_lang):
    if source_lang == target_lang:
        return _load_default_model()
    pair = all_models.get((source_lang, target_lang))
    if pair and pair[0] is not None and pair[1] is not None:
        return pair
    # No direct model: pivot through English with CombinedModel
    default_tokenizer, default_model = _load_default_model()
    return default_tokenizer, CombinedModel(source_lang, target_lang, default_tokenizer, default_model)


# Cache resource to load the default translation model
@st.cache_resource
def _load_default_model():
    model_name = "Helsinki-NLP/opus-mt-en-hi"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model


# Cache translation results to improve speed
@st.cache_data
def translate(text, source_lang, target_lang):
    if not text:
        return ""
    if source_lang == target_lang:
        return text
    try:
        tokenizer, model = load_model(source_lang, target_lang)
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
        with torch.no_grad():
            translated = model.generate(
                **inputs,
                max_length=1000 if target_lang == "hi" else 500,
                num_beams=6 if target_lang == "hi" else 4,
                early_stopping=True,
            )
        result = tokenizer.decode(translated[0], skip_special_tokens=True)
        return result if result.strip() else text
    except Exception as e:
        st.error(f"Translation error: {e}")
        return text
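

# --- Example UI wiring (illustrative sketch, not part of the original module) ---
# A minimal front end for the helpers above, assuming the file is run with
# `streamlit run app.py`. The widget labels, layout, and the main() entry point
# are assumptions for illustration; only LANGUAGES and translate() come from
# the code above.
def _language_label(code):
    native, english = LANGUAGES[code]
    return f"{native} ({english})"


def main():
    st.title("Multilingual Translator")
    codes = list(LANGUAGES)
    source_lang = st.selectbox("From", codes, index=codes.index("en"), format_func=_language_label)
    target_lang = st.selectbox("To", codes, index=codes.index("hi"), format_func=_language_label)
    text = st.text_area("Text to translate")
    if st.button("Translate") and text.strip():
        st.write(translate(text, source_lang, target_lang))


if __name__ == "__main__":
    main()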