import streamlit as st
from transformers import MarianTokenizer, MarianMTModel
import torch
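
# Supported languages: ISO 639-1 code -> (native name, English name).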
LANGUAGES = {
"en": ("English", "English"), "fr": ("Français", "French"), "es": ("Español", "Spanish"),
"de": ("Deutsch", "German"), "hi": ("हिन्दी", "Hindi"), "zh": ("中文", "Chinese"),
"ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese")
}


@st.cache_resource
def _load_default_model():
    """Load the fallback English->Hindi model; cached once by Streamlit."""
    model_name = "Helsinki-NLP/opus-mt-en-hi"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model


@st.cache_resource
def load_model(source_lang, target_lang):
    """Load a direct MarianMT model pair, falling back to an English pivot."""
    try:
        if source_lang == target_lang:
            return _load_default_model()
        # Try a direct Helsinki-NLP model first.
        model_name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
        try:
            tokenizer = MarianTokenizer.from_pretrained(model_name)
            model = MarianMTModel.from_pretrained(model_name)
            return tokenizer, model
        except Exception:
            # No direct model available: pivot source -> English -> target.
            if source_lang != "en" and target_lang != "en":
                source_to_en_tokenizer, source_to_en_model = load_model(source_lang, "en")
                en_to_target_tokenizer, en_to_target_model = load_model("en", target_lang)

                def combined_translate(text):
                    # Source -> English.
                    en_inputs = source_to_en_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
                    en_text = source_to_en_tokenizer.decode(source_to_en_model.generate(**en_inputs)[0], skip_special_tokens=True)
                    # English -> target.
                    target_inputs = en_to_target_tokenizer(en_text, return_tensors="pt", padding=True, truncation=True, max_length=500)
                    return en_to_target_tokenizer.decode(en_to_target_model.generate(**target_inputs)[0], skip_special_tokens=True)

                class CombinedTokenizer:
                    """Encodes with the source-side tokenizer, decodes with the target-side one."""
                    def __call__(self, text, **kwargs):
                        return source_to_en_tokenizer(text, **kwargs)

                    def decode(self, token_ids, **kwargs):
                        return en_to_target_tokenizer.decode(token_ids, **kwargs)

                class CombinedModel:
                    """Duck-types MarianMTModel.generate() so translate() works unchanged."""
                    def generate(self, input_ids, **kwargs):
                        translations = [combined_translate(source_to_en_tokenizer.decode(ids, skip_special_tokens=True)) for ids in input_ids]
                        # Re-encode the translated text so the caller can decode it again.
                        return en_to_target_tokenizer(text_target=translations, return_tensors="pt", padding=True).input_ids

                return CombinedTokenizer(), CombinedModel()
            return _load_default_model()
    except Exception:
        return _load_default_model()
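

# Illustrative example of the fallback path (assumes no direct fr-hi
# checkpoint exists on the Hub):
#   load_model("fr", "hi")
#     -> tries Helsinki-NLP/opus-mt-fr-hi, which fails to load
#     -> pivots via load_model("fr", "en") and load_model("en", "hi")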


def translate(text, source_lang, target_lang):
    """Translate text between the given language codes; return the input unchanged on failure."""
    if not text:
        return ""
    try:
        tokenizer, model = load_model(source_lang, target_lang)
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
        with torch.no_grad():
            translated = model.generate(**inputs, max_length=500, num_beams=2, early_stopping=True)
        return tokenizer.decode(translated[0], skip_special_tokens=True)
    except Exception as e:
        st.error(f"Translation error: {e}")
        return text
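

# Minimal usage sketch (hypothetical; the Space's real front end may differ).
# Run with: streamlit run translation.py
if __name__ == "__main__":
    st.title("MarianMT Translator")
    codes = list(LANGUAGES)
    source = st.selectbox("From", codes, format_func=lambda c: LANGUAGES[c][1])
    target = st.selectbox("To", codes, format_func=lambda c: LANGUAGES[c][1], index=codes.index("hi"))
    text = st.text_area("Text to translate")
    if st.button("Translate") and text:
        st.write(translate(text, source, target))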