Krishna086 committed on
Commit
a6fc19e
·
verified ·
1 Parent(s): ff2efaa

Update translation.py

Browse files
Files changed (1) hide show
  1. translation.py +10 -11
translation.py CHANGED
@@ -2,9 +2,15 @@ import streamlit as st
2
  from transformers import MarianTokenizer, MarianMTModel
3
  import torch
4
 
 
 
 
 
 
 
5
  @st.cache_resource
6
  def _load_default_model():
7
- model_name = "Helsinki-NLP/opus-mt-en-fr"
8
  tokenizer = MarianTokenizer.from_pretrained(model_name)
9
  model = MarianMTModel.from_pretrained(model_name)
10
  return tokenizer, model
@@ -19,21 +25,17 @@ def load_model(source_lang, target_lang):
19
  model = MarianMTModel.from_pretrained(model_name)
20
  return tokenizer, model
21
  except Exception as e:
22
- st.warning(f"No direct model for {source_lang} to {target_lang}. Falling back to English buffer.")
23
  return _load_default_model()
24
 
25
  @st.cache_data(ttl=3600)
26
  def translate_cached(text, source_lang, target_lang):
27
- src_code = {"English": "en", "French": "fr", "Spanish": "es", "German": "de",
28
- "Hindi": "hi", "Chinese": "zh", "Arabic": "ar", "Russian": "ru", "Japanese": "ja"}.get(source_lang, "en")
29
- tgt_code = {"English": "en", "French": "fr", "Spanish": "es", "German": "de",
30
- "Hindi": "hi", "Chinese": "zh", "Arabic": "ar", "Russian": "ru", "Japanese": "ja"}.get(target_lang, "fr")
31
- tokenizer, model = load_model(src_code, tgt_code)
32
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
33
  with torch.no_grad():
34
  translated = model.generate(**inputs, max_length=500, num_beams=2, early_stopping=True)
35
  translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
36
- return translated_text if translated_text.strip() else text # Fallback to input if empty
37
 
38
  def translate(text, source_lang, target_lang):
39
  if not text:
@@ -43,6 +45,3 @@ def translate(text, source_lang, target_lang):
43
  except Exception as e:
44
  st.error(f"Translation error: {str(e)}. Using input as fallback.")
45
  return text
46
-
47
- LANGUAGES = {"English": "en", "French": "fr", "Spanish": "es", "German": "de",
48
- "Hindi": "hi", "Chinese": "zh", "Arabic": "ar", "Russian": "ru", "Japanese": "ja"}
 
2
  from transformers import MarianTokenizer, MarianMTModel
3
  import torch
4
 
5
# Supported languages: ISO 639-1 code -> (native name, English name).
LANGUAGES = {
    "en": ("English", "English"), "fr": ("Français", "French"), "es": ("Español", "Spanish"),
    "de": ("Deutsch", "German"), "hi": ("हिन्दी", "Hindi"), "zh": ("中文", "Chinese"),
    "ar": ("العربية", "Arabic"), "ru": ("Русский", "Russian"), "ja": ("日本語", "Japanese"),
}
11
@st.cache_resource
def _load_default_model(model_name: str = "Helsinki-NLP/opus-mt-en-hi"):
    """Load and cache the fallback MarianMT translation model.

    Args:
        model_name: Hugging Face model id to load. Defaults to the
            English->Hindi model used as the project-wide fallback when no
            direct language-pair model exists.

    Returns:
        A ``(tokenizer, model)`` tuple ready for inference.
    """
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model
 
25
  model = MarianMTModel.from_pretrained(model_name)
26
  return tokenizer, model
27
  except Exception as e:
28
+ st.warning(f"No direct model for {source_lang} to {target_lang}. Using en-hi fallback.")
29
  return _load_default_model()
30
 
31
@st.cache_data(ttl=3600)
def translate_cached(text, source_lang, target_lang):
    """Translate ``text`` using the model for (source_lang, target_lang).

    Results are cached for one hour. Input is truncated to 500 tokens and
    decoded greedily with a small beam (num_beams=2).

    Args:
        text: Source text to translate.
        source_lang: Source language code (e.g. ``"en"``).
        target_lang: Target language code (e.g. ``"fr"``).

    Returns:
        The translated text, or the original ``text`` if the model produced
        an empty translation.
    """
    tokenizer, model = load_model(source_lang, target_lang)
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
    with torch.no_grad():
        translated = model.generate(**inputs, max_length=500, num_beams=2, early_stopping=True)
    translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
    # BUG FIX: the previous `len(translated_text.split()) >= 2` guard rejected
    # every valid single-word translation (e.g. "Hello" -> "Bonjour"), silently
    # returning the untranslated input. Fall back only when output is empty.
    return translated_text if translated_text.strip() else text
39
 
40
  def translate(text, source_lang, target_lang):
41
  if not text:
 
45
  except Exception as e:
46
  st.error(f"Translation error: {str(e)}. Using input as fallback.")
47
  return text