|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
import gradio as gr |
|
from peft import PeftModel |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Identifier handed to from_pretrained() for both the tokenizer and the model.
model_path = "Sabbir772/BnT5Sy"

# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# use_fast=False asks for the slow (pure-Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# Load the seq2seq weights, place them on the chosen device, and switch the
# module to evaluation mode, since this script only runs inference.
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
model.to(device)
model.eval()
|
|
|
|
|
def translate(text, source_lang):
    """Translate ``text`` with the seq2seq model, tagged with its source language.

    Args:
        text: The sentence to translate.
        source_lang: ``"Bangla"`` or ``"Sylheti"``; selects the ``<BN>`` or
            ``<SY>`` tag that is prepended to the model input.

    Returns:
        The decoded first generated sequence with special tokens removed.
        ``"Invalid language selected."`` when ``source_lang`` is not one of
        the two supported values. ``""`` when ``text`` is ``None``, empty, or
        whitespace only.
    """
    language_tags = {"Bangla": "<BN>", "Sylheti": "<SY>"}
    prefix = language_tags.get(source_lang)
    if prefix is None:
        return "Invalid language selected."

    # Without this guard the model would be asked to "translate" the bare
    # language tag and the UI would show made-up output for an empty box.
    if text is None or not str(text).strip():
        return ""

    input_text = f"{prefix} {text}"
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        padding=True,
    ).to(device)

    # Inference only: no gradients are needed for generation.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=128,
            num_beams=4,
            early_stopping=True,
        )

    # Only one input sequence is submitted, so the first output row is the result.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
# Input widgets, listed in the positional order of translate(text, source_lang).
_text_box = gr.Textbox(label="Input Text")
_language_radio = gr.Radio(["Bangla", "Sylheti"], label="Source Language")

# Output widget that displays whatever string translate() returns.
_result_box = gr.Textbox(label="Translated Text")

# Wire the widgets to translate() in a single-function Gradio interface.
iface = gr.Interface(
    fn=translate,
    inputs=[_text_box, _language_radio],
    outputs=_result_box,
    title="Bangla ↔ Sylheti Dialect Translator (Fine-tuned T5)",
    description="Translate between Bangla and Sylheti using a LoRA-finetuned Flan-T5 model.",
)

# Start serving the interface as soon as the script runs.
iface.launch()