# FastAPIMT/src/translate/Translate.py
from nltk.tokenize import sent_tokenize
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

import src.exception.Exception.Exception as ExceptionCustom

METHOD = "TRANSLATE"
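# sent_tokenize relies on NLTK's "punkt" tokenizer data. A minimal guard, assuming
# network access is available the first time the module is imported; nltk.download
# is a no-op once the data is already present.
import nltk

try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt", quiet=True)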
def paraphraseTranslateMethod(requestValue: str, model: str):
    """Translate requestValue sentence by sentence.

    model == 'roen' selects Romanian -> English; any other value selects
    English -> Romanian. Returns (translated_text, model) on success, or
    ("", exception_message) when the input fails validation.
    """
    exception = ExceptionCustom.checkForException(requestValue, METHOD)
    if exception:
        return "", exception

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the checkpoint once, before the loop, instead of re-fetching the
    # tokenizer and model from the hub for every sentence.
    checkpoint = "BlackKakapo/opus-mt-ro-en" if model == 'roen' else "BlackKakapo/opus-mt-en-ro"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)

    result_value = []
    for sentence in sent_tokenize(requestValue):
        input_ids = tokenizer(sentence, return_tensors='pt').to(device)
        output = seq2seq_model.generate(
            input_ids=input_ids.input_ids,
            do_sample=True,
            max_length=512,
            top_k=90,
            top_p=0.97,
            early_stopping=False
        )
        result_value.append(tokenizer.batch_decode(output, skip_special_tokens=True)[0])

    return " ".join(result_value).strip(), model
def gemma(requestValue: str, model: str = 'Gargaz/gemma-2b-romanian-better'):
    """Translate requestValue into formal Romanian with a chat-style text-generation model."""
    messages = [{"role": "user", "content": f"Translate this text to Romanian using a formal tone. Only return the translated text: {requestValue}"}]

    # Fall back to the default checkpoint if the caller passes a bare name
    # without a hub namespace.
    if '/' not in model:
        model = 'Gargaz/gemma-2b-romanian-better'

    pipe = pipeline(
        "text-generation",
        model=model,
        device=-1,           # CPU; set device=0 to run on the first GPU
        max_new_tokens=256,  # keep short to reduce verbosity
        do_sample=False      # greedy decoding for determinism
    )
    output = pipe(messages, num_return_sequences=1, return_full_text=False)

    # Return just the generated string so the signature matches paraphraseTranslateMethod.
    return output[0]["generated_text"].strip(), model
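# Minimal smoke test; the sample sentences are illustrative and the run assumes
# the checkpoints above can be downloaded. Guarded so it never executes on import.
if __name__ == "__main__":
    en_text, used = paraphraseTranslateMethod("Acesta este un exemplu simplu.", "roen")
    print(used, "->", en_text)

    ro_text, used = gemma("The report is due on Friday.")
    print(used, "->", ro_text)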