Update translation.py
translation.py (+4 -2)
@@ -56,7 +56,7 @@ class CombinedModel:
         input_ids = kwargs.get('input_ids')
         if not input_ids:
             return torch.tensor([])
-        return torch.tensor([combined_translate(
+        return torch.tensor([combined_translate(self.default_tokenizer.decode(x, skip_special_tokens=True), self.source_lang, self.target_lang, self.default_tokenizer, self.default_model) for x in input_ids])
 
 # Function to load appropriate translation model with optimized caching
 @st.cache_resource
@@ -93,8 +93,10 @@ def translate(text, source_lang, target_lang):
     try:
         tokenizer, model = load_model(source_lang, target_lang)
         inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=500)
+        if inputs['input_ids'].size(0) > 1:  # Ensure single sequence
+            inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
         with torch.no_grad():
-            translated = model.generate(**inputs, max_length=1000 if target_lang == "
+            translated = model.generate(**inputs, max_length=1000 if target_lang == "ja" else 500, num_beams=4, early_stopping=True)
         result = tokenizer.decode(translated[0], skip_special_tokens=True)
         return result if result.strip() else text
     except Exception as e:
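For context on the first hunk: the rewritten line 59 decodes each row of `input_ids` back to text, runs it through `combined_translate`, and re-wraps the results in a tensor. Neither the rest of `CombinedModel` nor `combined_translate` appears in this diff, so the sketch below is only an assumption about how the pieces fit together; the constructor and the helper stub are hypothetical.

```python
import torch

def combined_translate(text, source_lang, target_lang, tokenizer, model):
    # Hypothetical stand-in: the real helper is not shown in the diff. Assumed
    # to translate `text` and return token ids, since torch.tensor() on the
    # caller's side needs numbers, not strings.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=500)
    with torch.no_grad():
        output = model.generate(**inputs, max_length=500)
    return output[0].tolist()

class CombinedModel:
    def __init__(self, source_lang, target_lang, default_tokenizer, default_model):
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.default_tokenizer = default_tokenizer
        self.default_model = default_model

    def generate(self, **kwargs):
        input_ids = kwargs.get('input_ids')
        # Safer than the diff's `if not input_ids:`, whose truth value is
        # ambiguous for a multi-row tensor.
        if input_ids is None or len(input_ids) == 0:
            return torch.tensor([])
        # Mirrors the new line 59; note torch.tensor() only stacks the per-row
        # results if every translation yields the same number of token ids.
        return torch.tensor([
            combined_translate(
                self.default_tokenizer.decode(x, skip_special_tokens=True),
                self.source_lang, self.target_lang,
                self.default_tokenizer, self.default_model)
            for x in input_ids])
```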
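The second hunk completes the previously truncated `generate` call (beam search with `num_beams=4`, `early_stopping=True`, and a larger `max_length` budget when the target is Japanese) and squeezes a multi-sequence batch down to its first sequence before generation. Below is a runnable sketch of the patched path, assuming a MarianMT-style `load_model`; the Helsinki-NLP checkpoint naming and the `except` fallback are guesses, since neither appears in the diff.

```python
import torch
import streamlit as st
from transformers import MarianMTModel, MarianTokenizer

@st.cache_resource
def load_model(source_lang, target_lang):
    # Assumed implementation: the real load_model sits outside this diff.
    name = f"Helsinki-NLP/opus-mt-{source_lang}-{target_lang}"
    return MarianTokenizer.from_pretrained(name), MarianMTModel.from_pretrained(name)

def translate(text, source_lang, target_lang):
    try:
        tokenizer, model = load_model(source_lang, target_lang)
        inputs = tokenizer(text, return_tensors="pt", padding=True,
                           truncation=True, max_length=500)
        if inputs['input_ids'].size(0) > 1:  # Ensure single sequence
            inputs = {k: v[0].unsqueeze(0) for k, v in inputs.items()}
        with torch.no_grad():
            translated = model.generate(
                **inputs,
                max_length=1000 if target_lang == "ja" else 500,
                num_beams=4,
                early_stopping=True)
        result = tokenizer.decode(translated[0], skip_special_tokens=True)
        return result if result.strip() else text
    except Exception as e:
        return text  # assumed fallback: the except body is cut off in the diff
```

Beam search trades latency for output quality here; `early_stopping=True` stops the search as soon as enough complete beam candidates are found, and the doubled `max_length` for `"ja"` presumably allows for Japanese translations tokenizing into more pieces than the source text.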