# FastAPIMT/src/Translate.py
import nltk
from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, pipeline
import torch

import src.exception.Exception as ExceptionCustom

METHOD = "TRANSLATE"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def paraphraseTranslateMethod(requestValue: str, model: str):
    """Translate requestValue sentence by sentence between Romanian and English."""
    nltk.download('punkt')
    nltk.download('punkt_tab')

    exception = ExceptionCustom.checkForException(requestValue, METHOD)
    if exception:
        return "", exception

    # Pick the translation direction once and reuse the same model for every sentence,
    # instead of reloading it inside the loop.
    model_name = "BlackKakapo/opus-mt-ro-en" if model == 'roen' else "BlackKakapo/opus-mt-en-ro"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    translation_model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

    tokenized_sent_list = sent_tokenize(requestValue)
    result_value = []

    for sentence in tokenized_sent_list:
        input_ids = tokenizer(sentence, return_tensors='pt').to(device)
        output = translation_model.generate(
            input_ids=input_ids.input_ids,
            do_sample=True,
            max_length=512,
            top_k=90,
            top_p=0.97,
            early_stopping=False
        )
        result = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
        result_value.append(result)

    return " ".join(result_value).strip(), model
def gemma(requestValue: str, model: str = 'Gargaz/gemma-2b-romanian-better'):
    """Translate requestValue to Romanian with a causal LM via the text-generation pipeline."""
    requestValue = requestValue.replace('\n', ' ')
    messages = [{"role": "user", "content": f"Translate this text to Romanian: {requestValue}"}]
    if '/' not in model:
        model = 'Gargaz/gemma-2b-romanian-better'
    # Limit max_new_tokens to roughly 150% of the request length (in characters).
    max_new_tokens = int(len(requestValue) + len(requestValue) * 0.5)
    try:
        pipe = pipeline(
            "text-generation",
            model=model,
            device=-1,  # run on CPU
            max_new_tokens=max_new_tokens,  # keep short to reduce verbosity
            do_sample=False  # use greedy decoding for determinism
        )
        output = pipe(messages, num_return_sequences=1, return_full_text=False)
        generated_text = output[0]["generated_text"]
        # Keep only the first line of the reply.
        result = generated_text.split('\n', 1)[0]
        return result.strip()
    except Exception as error:
        return str(error)
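
# Example usage (a sketch; the sentence is illustrative and the model is downloaded on first use):
#   gemma("Good morning, how are you?")
#   Expected to return the Romanian translation as a string, or an error message on failure.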
def gemma_direct(requestValue: str, model: str = 'Gargaz/gemma-2b-romanian-better'):
    """Translate requestValue to Romanian by loading the causal LM directly."""
    model_name = model if '/' in model else 'Gargaz/gemma-2b-romanian-better'
    prompt = f"Translate this text to Romanian: {requestValue}"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    causal_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    # Estimate the output length as roughly 150% of the input token count.
    input_ids = tokenizer.encode(requestValue, add_special_tokens=True)
    num_tokens = len(input_ids)
    max_new_tokens = int(num_tokens * 1.5)
    max_new_tokens += max_new_tokens % 2  # ensure it's even
    messages = [{"role": "user", "content": prompt}]
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device)
        outputs = causal_model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Decode only the newly generated tokens, then keep the first line of the reply.
        response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
        result = response.split('\n', 1)[0]
        return result.strip()
    except Exception as error:
        return str(error)
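

# Minimal smoke test, intended for local runs only: a sketch in which the sample sentence
# is illustrative, and each call downloads its model from the Hugging Face Hub on first use.
if __name__ == "__main__":
    sample = "The weather is nice today."
    translated, used_model = paraphraseTranslateMethod(sample, "enro")
    print(f"opus-mt ({used_model}): {translated}")
    print(f"gemma pipeline: {gemma(sample)}")
    print(f"gemma direct: {gemma_direct(sample)}")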