Krishna086 committed on
Commit
122790b
·
verified ·
1 Parent(s): d04ec9a

Update translation.py

Browse files
Files changed (1) hide show
  1. translation.py +19 -15
translation.py CHANGED
@@ -1,37 +1,41 @@
1
  import streamlit as st
2
  from transformers import MarianTokenizer, MarianMTModel
 
3
 
4
  @st.cache_resource
5
  def _load_default_model():
6
  model_name = "Helsinki-NLP/opus-mt-en-fr"
7
- return MarianTokenizer.from_pretrained(model_name), MarianMTModel.from_pretrained(model_name)
 
 
8
 
9
  @st.cache_resource
10
  def load_model(src_lang, tgt_lang):
11
  try:
12
  model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
13
- return MarianTokenizer.from_pretrained(model_name), MarianMTModel.from_pretrained(model_name)
 
 
14
  except Exception as e:
15
- st.warning(f"No direct model for {src_lang} to {tgt_lang}. Falling back to en-fr or using input. Error: {str(e)}")
16
  return _load_default_model()
17
 
18
- DEFAULT_TOKENIZER, DEFAULT_MODEL = _load_default_model()
19
-
20
- def translate(text, source_lang, target_lang):
21
- if not text:
22
- return "No text provided."
23
  src_code = {"English": "en", "French": "fr", "Spanish": "es", "German": "de",
24
  "Hindi": "hi", "Chinese": "zh", "Arabic": "ar", "Russian": "ru", "Japanese": "ja"}.get(source_lang, "en")
25
  tgt_code = {"English": "en", "French": "fr", "Spanish": "es", "German": "de",
26
  "Hindi": "hi", "Chinese": "zh", "Arabic": "ar", "Russian": "ru", "Japanese": "ja"}.get(target_lang, "fr")
27
  tokenizer, model = load_model(src_code, tgt_code)
28
- try:
29
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=400)
30
- translated = model.generate(**inputs)
31
- return tokenizer.decode(translated[0], skip_special_tokens=True)
32
- except Exception as e:
33
- st.error(f"Translation generation failed: {str(e)}. Returning input as fallback.")
34
- return text
 
 
35
 
36
  LANGUAGES = {"English": "en", "French": "fr", "Spanish": "es", "German": "de",
37
  "Hindi": "hi", "Chinese": "zh", "Arabic": "ar", "Russian": "ru", "Japanese": "ja"}
 
1
  import streamlit as st
2
  from transformers import MarianTokenizer, MarianMTModel
3
+ import torch
4
 
@st.cache_resource
def _load_default_model():
    """Load the fallback English-to-French MarianMT tokenizer and model.

    Wrapped in ``st.cache_resource`` so the pair is meant to be created once
    and reused rather than reloaded on every call.

    Returns:
        tuple: ``(tokenizer, model)`` loaded from
        ``Helsinki-NLP/opus-mt-en-fr``.
    """
    model_name = "Helsinki-NLP/opus-mt-en-fr"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return tokenizer, model
11
 
@st.cache_resource
def load_model(src_lang, tgt_lang):
    """Load the MarianMT tokenizer and model for a language pair.

    The checkpoint name is built as ``Helsinki-NLP/opus-mt-{src}-{tgt}``.
    If loading it fails for any reason, a Streamlit warning is shown and the
    default en-fr pair is returned instead.

    Args:
        src_lang: Source language code used in the checkpoint name (e.g. "en").
        tgt_lang: Target language code used in the checkpoint name (e.g. "fr").

    Returns:
        tuple: ``(tokenizer, model)`` for the requested pair, or the en-fr
        fallback pair when the requested one could not be loaded.
    """
    model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
    try:
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        return tokenizer, model
    except Exception as e:
        # Deliberate best-effort: any load failure degrades to the en-fr pair
        # instead of crashing the app.
        # NOTE(review): this function is cached, so the fallback pair is
        # presumably what later calls with the same arguments get too.
        st.warning(f"No direct model for {src_lang} to {tgt_lang}. Using cached en-fr. Error: {str(e)}")
        return _load_default_model()
22
 
@st.cache_data
def translate_cached(text, source_lang, target_lang):
    """Translate ``text`` between two languages given by display name.

    Display names (e.g. "English") are mapped to language codes through the
    module-level ``LANGUAGES`` table. An unknown source name falls back to
    "en" and an unknown target name to "fr".

    Args:
        text: The text to translate.
        source_lang: Display name of the source language.
        target_lang: Display name of the target language.

    Returns:
        str: The decoded translation, or ``text`` unchanged when source and
        target resolve to the same language code.
    """
    src_code = LANGUAGES.get(source_lang, "en")
    tgt_code = LANGUAGES.get(target_lang, "fr")

    # Nothing to translate when both sides resolve to the same language.
    # Without this guard, load_model() would fail to find a same-language
    # checkpoint, fall back to en-fr, and return French text.
    if src_code == tgt_code:
        return text

    tokenizer, model = load_model(src_code, tgt_code)
    # Input is truncated to 500 tokens; generation is capped at the same length.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
    with torch.no_grad():  # inference only, no gradients needed
        translated = model.generate(**inputs, max_length=500)
    # A single string yields a batch of one, so decode the first sequence.
    return tokenizer.decode(translated[0], skip_special_tokens=True)
34
+
def translate(text, source_lang, target_lang):
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    This is the uncached entry point: it validates the input and delegates
    the actual work to ``translate_cached``.

    Args:
        text: The text to translate.
        source_lang: Display name of the source language (e.g. "English").
        target_lang: Display name of the target language (e.g. "French").

    Returns:
        str: ``"No text provided."`` when ``text`` is empty, ``None`` or only
        whitespace; otherwise the translation. If translation raises, an
        error is shown in the Streamlit UI and ``text`` is returned unchanged.
    """
    # Whitespace-only input carries nothing to translate either.
    if not text or not text.strip():
        return "No text provided."
    try:
        return translate_cached(text, source_lang, target_lang)
    except Exception as e:
        # Handled here rather than inside the cached function so that a
        # failure result is never stored in the cache.
        st.error(f"Translation generation failed: {str(e)}. Returning input as fallback.")
        return text
39
 
# Supported languages: display name -> two-letter language code.
LANGUAGES = {
    "English": "en",
    "French": "fr",
    "Spanish": "es",
    "German": "de",
    "Hindi": "hi",
    "Chinese": "zh",
    "Arabic": "ar",
    "Russian": "ru",
    "Japanese": "ja",
}