File size: 1,680 Bytes
5216f5c
 
 
53decb2
5216f5c
 
09a3166
53decb2
09a3166
 
 
 
39b15a2
09a3166
5216f5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26b5a1a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
from peft import PeftModel

# Load model
# model_path = "./banglat5_bn_sy"  # path inside Space
# tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
# model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

# Pull the fine-tuned seq2seq checkpoint from the Hugging Face Hub.
# use_fast=False forces the slow (SentencePiece) tokenizer — presumably
# required because the checkpoint ships no fast-tokenizer files; confirm.
# NOTE(review): `PeftModel` is imported above but never used here — the
# LoRA weights appear to be merged into this checkpoint already; verify.
model_path = "Sabbir772/BnT5Sy"
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
model.eval()  # inference only: disable dropout etc.
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Translation function
def translate(text, source_lang):
    """Translate *text* between Bangla and Sylheti.

    The model is steered by a language tag prepended to the input:
    "<BN>" marks Bangla source text, "<SY>" marks Sylheti.

    Parameters
    ----------
    text : str
        The sentence to translate.
    source_lang : str
        Either "Bangla" or "Sylheti" (as offered by the UI radio buttons).

    Returns
    -------
    str
        The decoded translation, or an error message for an unknown
        source language.
    """
    # Map the UI choice to the model's source-language control token.
    lang_tags = {"Bangla": "<BN>", "Sylheti": "<SY>"}
    tag = lang_tags.get(source_lang)
    if tag is None:
        return "Invalid language selected."

    # Tokenize the tagged input and move the tensors to the model's device.
    encoded = tokenizer(
        f"{tag} {text}", return_tensors="pt", truncation=True, padding=True
    ).to(device)

    # Beam-search decode without tracking gradients (inference only).
    with torch.no_grad():
        generated = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=128,
            num_beams=4,
            early_stopping=True,
        )

    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Gradio interface
# Wires `translate` into a simple web UI: a free-text box plus a radio
# selector whose two choices match the keys `translate` checks.
iface = gr.Interface(
    fn=translate,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Radio(["Bangla", "Sylheti"], label="Source Language")
    ],
    outputs=gr.Textbox(label="Translated Text"),
    title="Bangla ↔ Sylheti Dialect Translator (Fine-tuned T5)",
    description="Translate between Bangla and Sylheti using a LoRA-finetuned Flan-T5 model."
)

# Start the Gradio server (blocks; default host/port).
iface.launch()