# BnT5Sy / app.py — Hugging Face Space by Sabbir772 (commit 26b5a1a, verified)
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
from peft import PeftModel
# --- Model setup ---------------------------------------------------------
# Checkpoint hosted on the Hugging Face Hub (a local path such as
# "./banglat5_bn_sy" would work the same way inside the Space).
MODEL_ID = "Sabbir772/BnT5Sy"

# Prefer the GPU when one is available, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Slow tokenizer is required for this checkpoint (use_fast=False).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
model.to(device)
model.eval()  # inference only: disable dropout etc.
# Translation function
def translate(text, source_lang):
    """Translate *text* between Bangla and Sylheti.

    The fine-tuned model expects a language tag prepended to the input
    ("<BN>" when the source is Bangla, "<SY>" when it is Sylheti).
    Returns the decoded translation, or an error message when
    ``source_lang`` is not one of the two supported languages.
    """
    tags = {"Bangla": "<BN>", "Sylheti": "<SY>"}
    tag = tags.get(source_lang)
    if tag is None:
        return "Invalid language selected."

    encoded = tokenizer(
        f"{tag} {text}",
        return_tensors="pt",
        truncation=True,
        padding=True,
    ).to(device)

    # Beam-search decoding; no gradients needed for inference.
    with torch.no_grad():
        generated = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=128,
            num_beams=4,
            early_stopping=True,
        )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Gradio interface
# --- Gradio UI -----------------------------------------------------------
# Two inputs (free text + a source-language selector) feed `translate`;
# the result is shown in a single text output.
demo_inputs = [
    gr.Textbox(label="Input Text"),
    gr.Radio(["Bangla", "Sylheti"], label="Source Language"),
]

iface = gr.Interface(
    fn=translate,
    inputs=demo_inputs,
    outputs=gr.Textbox(label="Translated Text"),
    title="Bangla ↔ Sylheti Dialect Translator (Fine-tuned T5)",
    description="Translate between Bangla and Sylheti using a LoRA-finetuned Flan-T5 model.",
)

iface.launch()