File size: 4,603 Bytes
2e463ee
d120873
275593f
d120873
b9f4a14
37ded96
 
d120873
 
ad9a50c
d120873
d22cb09
2e463ee
 
d22cb09
 
 
d120873
d22cb09
 
95b5309
d22cb09
b8db721
95b5309
 
 
49ae858
d22cb09
 
 
 
 
 
 
 
 
49ae858
95b5309
 
 
d22cb09
 
 
 
 
 
 
 
 
 
 
 
37ded96
 
2e5046d
fb22ed5
39ababd
 
088720e
 
2afdfd3
 
fb5c2d4
 
2e5046d
 
 
dce1a40
 
fb5c2d4
 
 
 
 
 
 
275593f
 
 
c55d8e2
275593f
d550028
b41b21e
c55d8e2
 
73d63a7
b41b21e
 
 
 
b48c77c
 
275593f
c55d8e2
 
 
 
 
 
 
 
 
 
 
fb5c2d4
 
c55d8e2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from functools import lru_cache

import nltk
from nltk.tokenize import sent_tokenize
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
# Use a pipeline as a high-level helper
from transformers import pipeline

import src.exception.Exception as ExceptionCustom

# Tag handed to ExceptionCustom.checkForException to identify this module's requests.
METHOD = "TRANSLATE"

# Torch device that paraphraseTranslateMethod and gemma_direct move their
# models and inputs to: the default CUDA device when torch reports one as
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hugging Face repos for the two translation directions.
_ROEN_REPO = "BlackKakapo/opus-mt-ro-en"
_ENRO_REPO = "BlackKakapo/opus-mt-en-ro"

# Becomes True once both NLTK tokenizer packages have downloaded successfully,
# so later calls skip repeating nltk.download.
_PUNKT_READY = False


def _ensure_punkt():
    """Download the NLTK 'punkt' and 'punkt_tab' data once per process.

    The flag is set only when both downloads report success, so a failed
    download is retried on the next call instead of being remembered.
    """
    global _PUNKT_READY
    if not _PUNKT_READY:
        punkt_ok = nltk.download('punkt')
        punkt_tab_ok = nltk.download('punkt_tab')
        _PUNKT_READY = bool(punkt_ok and punkt_tab_ok)


@lru_cache(maxsize=2)
def _load_translation_model(repo_id: str):
    """Load (once per repo id) and return ``(tokenizer, model)`` with the model on ``device``."""
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(repo_id)
    seq2seq.to(device)
    return tokenizer, seq2seq


def paraphraseTranslateMethod(requestValue: str, model: str):
    """Translate *requestValue* sentence by sentence with an OPUS-MT model.

    Args:
        requestValue: Text to translate; it is split into sentences with
            NLTK's ``sent_tokenize`` and each sentence is translated on its own.
        model: ``'roen'`` selects the ``opus-mt-ro-en`` model; any other value
            selects ``opus-mt-en-ro``.

    Returns:
        ``(translated_text, model)`` on success, where the translated sentences
        are joined by single spaces; or ``("", exception)`` when
        ``ExceptionCustom.checkForException`` flags the input.
    """
    _ensure_punkt()
    exception = ExceptionCustom.checkForException(requestValue, METHOD)
    if exception:
        return "", exception

    sentences = sent_tokenize(requestValue)
    if not sentences:
        # Nothing to translate, so there is no reason to load a model.
        return "", model

    repo_id = _ROEN_REPO if model == 'roen' else _ENRO_REPO
    # Load once (cached across calls) instead of once per sentence.
    tokenizer, seq2seq = _load_translation_model(repo_id)

    translated = []
    for sentence in sentences:
        encoded = tokenizer(sentence, return_tensors='pt').to(device)
        output = seq2seq.generate(
            input_ids=encoded.input_ids,
            attention_mask=encoded.get("attention_mask"),
            do_sample=True,  # sampling: output varies between calls ("paraphrase")
            max_length=512,
            top_k=90,
            top_p=0.97,
            early_stopping=False,
        )
        translated.append(tokenizer.batch_decode(output, skip_special_tokens=True)[0])

    return " ".join(translated).strip(), model

@lru_cache(maxsize=2)
def _load_text_generation_pipeline(model_id: str):
    """Build (once per model id) a text-generation pipeline pinned to the CPU (device=-1)."""
    return pipeline("text-generation", model=model_id, device=-1)


def gemma(requestValue: str, model: str = 'Gargaz/gemma-2b-romanian-better'):
    """Translate *requestValue* to Romanian with a chat-style text-generation pipeline.

    Args:
        requestValue: Source text. Newlines are replaced with spaces before
            prompting, since only the first line of the reply is kept.
        model: Hugging Face model id; a value without ``'/'`` falls back to
            ``'Gargaz/gemma-2b-romanian-better'``.

    Returns:
        The first line of the generated reply, stripped, or ``""`` for blank
        input. On any failure the exception object is *returned* rather than
        raised, so callers must check the result's type.
    """
    requestValue = requestValue.replace('\n', ' ')
    if not requestValue.strip():
        # Nothing to translate (empty input would also give max_new_tokens == 0).
        return ""
    if '/' not in model:
        model = 'Gargaz/gemma-2b-romanian-better'
    messages = [{"role": "user", "content": f"Translate this text to Romanian: {requestValue}"}]
    # Output budget: 150% of the input length measured in characters, not tokens.
    max_new_tokens = int(len(requestValue) + len(requestValue) * 0.5)
    try:
        # The pipeline is cached per model id instead of being rebuilt on every call.
        pipe = _load_text_generation_pipeline(model)
        output = pipe(
            messages,
            num_return_sequences=1,
            return_full_text=False,
            max_new_tokens=max_new_tokens,  # keep short to reduce verbosity
            do_sample=False,  # greedy decoding for determinism
        )
        generated_text = output[0]["generated_text"]
        # Keep the first line only, presumably to drop any extra text the model
        # appends after the translation.
        return generated_text.split('\n', 1)[0].strip()
    except Exception as error:
        return error

@lru_cache(maxsize=2)
def _load_causal_lm(model_name: str):
    """Load (once per model name) and return ``(tokenizer, model)`` with the model on ``device``."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    causal_lm = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    return tokenizer, causal_lm


def gemma_direct(requestValue: str, model: str = 'Gargaz/gemma-2b-romanian-better'):
    """Translate *requestValue* to Romanian by calling a causal LM directly.

    Args:
        requestValue: Source text, sent inside a single user chat message.
        model: Hugging Face model id; a value without ``'/'`` falls back to
            ``'Gargaz/gemma-2b-romanian-better'``.

    Returns:
        The first line of the decoded reply, stripped. If chat templating or
        generation fails, the exception object is *returned* rather than
        raised. Errors while loading the tokenizer/model still propagate.
    """
    model_name = model if '/' in model else 'Gargaz/gemma-2b-romanian-better'
    prompt = f"Translate this text to Romanian: {requestValue}"

    # Cached across calls; this sits outside the try, so loading errors are raised.
    tokenizer, causal_lm = _load_causal_lm(model_name)

    # Output budget: 150% of the raw input's token count, rounded up to an even number.
    num_tokens = len(tokenizer.encode(requestValue, add_special_tokens=True))
    max_new_tokens = int(num_tokens * 1.5)
    max_new_tokens += max_new_tokens % 2

    messages = [{"role": "user", "content": prompt}]

    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device)

        outputs = causal_lm.generate(**inputs, max_new_tokens=max_new_tokens)
        # Decode only the newly generated tokens, skipping the echoed prompt.
        prompt_len = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return response.split('\n', 1)[0].strip()
    except Exception as error:
        return error