# app.py — Gradio demo: Bangla ↔ Sylheti translation with a fine-tuned T5 model.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
from peft import PeftModel
# --- Model loading -------------------------------------------------------
# Fine-tuned checkpoint pulled from the Hugging Face Hub.
model_path = "Sabbir772/BnT5Sy"
# use_fast=False forces the slow (SentencePiece-based) tokenizer —
# presumably to match how the model was tokenized during fine-tuning; confirm.
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
model.eval()  # inference only: disable dropout / training-mode layers
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Translation function
def translate(text, source_lang):
    """Translate *text* between Bangla and Sylheti with the fine-tuned model.

    The model was trained with a language tag prepended to the input, so the
    UI's language choice is mapped to that tag before encoding. Any
    *source_lang* other than "Bangla" or "Sylheti" yields an error message
    string instead of a translation.
    """
    prefix_by_lang = {"Bangla": "<BN>", "Sylheti": "<SY>"}
    prefix = prefix_by_lang.get(source_lang)
    if prefix is None:
        return "Invalid language selected."

    encoded = tokenizer(
        f"{prefix} {text}", return_tensors="pt", truncation=True, padding=True
    ).to(device)
    with torch.no_grad():
        generated = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=128,
            num_beams=4,
            early_stopping=True,
        )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# --- Gradio UI -----------------------------------------------------------
# Text box + language radio in, translated text out. The stray trailing "|"
# after launch() in the original was a copy/paste artifact and a SyntaxError.
iface = gr.Interface(
    fn=translate,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Radio(["Bangla", "Sylheti"], label="Source Language"),
    ],
    outputs=gr.Textbox(label="Translated Text"),
    title="Bangla ↔ Sylheti Dialect Translator (Fine-tuned T5)",
    description="Translate between Bangla and Sylheti using a LoRA-finetuned Flan-T5 model."
)
iface.launch()