Krishna086 committed
Commit 17b4050 · verified · 1 Parent(s): 7708c36

Update translation.py

Files changed (1):
  1. translation.py (+4 -2)
translation.py CHANGED
@@ -56,7 +56,7 @@ class CombinedModel:
         input_ids = kwargs.get('input_ids')
         if not input_ids:
             return torch.tensor([])
-        return torch.tensor([combined_translate(tokenizer.decode(x, skip_special_tokens=True), self.source_lang, self.target_lang, self.default_tokenizer, self.default_model) for x in input_ids])
+        return torch.tensor([combined_translate(self.default_tokenizer.decode(x, skip_special_tokens=True), self.source_lang, self.target_lang, self.default_tokenizer, self.default_model) for x in input_ids])
 
 # Function to load appropriate translation model with optimized caching
 @st.cache_resource
@@ -93,8 +93,10 @@ def translate(text, source_lang, target_lang):
     try:
         tokenizer, model = load_model(source_lang, target_lang)
         inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
+        if inputs['input_ids'].size(0) > 1:  # Ensure single sequence
+            inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
         with torch.no_grad():
-            translated = model.generate(**inputs, max_length=1000 if target_lang == "hi" else 500, num_beams=4, early_stopping=True)  # Reduced to 4 beams for speed
+            translated = model.generate(**inputs, max_length=1000 if target_lang == "ja" else 500, num_beams=4, early_stopping=True)
         result = tokenizer.decode(translated[0], skip_special_tokens=True)
         return result if result.strip() else text
     except Exception as e:
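
For context, the first hunk fixes a NameError: there is no local `tokenizer` in scope inside CombinedModel.generate, so the old list comprehension crashed as soon as it tried to decode; the patch routes decoding through self.default_tokenizer instead. A toy reproduction of the patched pattern follows, with a hypothetical FakeTokenizer standing in for the real Hugging Face tokenizer; only the self.default_tokenizer.decode(...) access pattern comes from the patch itself.

# Toy sketch of the first hunk's fix. FakeTokenizer is hypothetical;
# the real code wraps the decoded strings back into a tensor via
# combined_translate, which is omitted here.
class FakeTokenizer:
    def decode(self, ids, skip_special_tokens=True):
        return " ".join(str(i) for i in ids)

class CombinedModel:
    def __init__(self, default_tokenizer):
        self.default_tokenizer = default_tokenizer

    def generate(self, **kwargs):
        input_ids = kwargs.get('input_ids')
        if not input_ids:
            return []
        # Patched: decode through the instance attribute,
        # not an undefined free name.
        return [self.default_tokenizer.decode(x, skip_special_tokens=True)
                for x in input_ids]

print(CombinedModel(FakeTokenizer()).generate(input_ids=[[1, 2, 3]]))  # ['1 2 3']

The second hunk adds a guard so model.generate() always sees exactly one sequence, and retargets the longer 1000-token output budget from Hindi ("hi") to Japanese ("ja"). A minimal, runnable sketch of just that guard, with made-up tensors standing in for a tokenizer's batched output:

# Sketch of the batch-clamping guard added in the second hunk; the tensors
# are hypothetical stand-ins, only the dict-slicing pattern is from the patch.
import torch

inputs = {
    "input_ids": torch.tensor([[101, 7592, 102], [101, 2088, 102]]),  # 2 sequences
    "attention_mask": torch.ones(2, 3, dtype=torch.long),
}

if inputs["input_ids"].size(0) > 1:
    # Keep only the first sequence and restore a batch dimension of 1.
    inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}

print(inputs["input_ids"].shape)  # torch.Size([1, 3])

Note that the guard drops any sequences after the first rather than translating them; that matches the commit's "Ensure single sequence" comment, but it is a truncation, not a batching strategy.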