import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hugging Face Hub repository of the fine-tuned Polish <-> English ByT5 model.
MODEL_PATH = "Gregniuki/pl-en-pl-v2"

# Optional access token, read from the environment (on Hugging Face Spaces this
# is typically provided as a repository secret named HF_TOKEN).
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
if HF_AUTH_TOKEN is None:
    print("Warning: HF_TOKEN secret not found. Loading model without authentication.")

print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
    # Load the tokenizer (ByT5 models use a byte-level tokenizer).
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False,
    )

    # Load the seq2seq model with the same credentials.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False,
    )

    # Run on GPU when available; otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Using PyTorch model on device: {device}")

    # Inference only: disable dropout and other training-time behaviour.
    model.eval()
    print("Model and tokenizer loaded successfully.")

except Exception as e:
    print(f"Error loading model/tokenizer: {e}")
    # Give a clearer message for the common authentication failure.
    if "401 Client Error" in str(e):
        error_message = f"Authentication failed. Ensure the HF_TOKEN secret has read access to {MODEL_PATH}."
    else:
        error_message = f"Failed to load model from {MODEL_PATH}. Error: {e}"
    raise gr.Error(error_message) from e


def translate_text(text_input):
    """Translate Polish or English input text with the loaded seq2seq model."""
    if not text_input or text_input.strip() == "":
        return "[Error] Please enter some text to translate."

    print(f"Received input: '{text_input}'")

    try:
        # Tokenize and move the input tensors to the model's device.
        inputs = tokenizer(
            text_input, return_tensors="pt", padding=True, truncation=True, max_length=512
        ).to(device)
    except Exception as e:
        print(f"Error during tokenization: {e}")
        return f"[Error] Tokenization failed: {e}"

    try:
        # Beam-search generation; torch.no_grad() skips gradient bookkeeping.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=512,
                num_beams=4,
                early_stopping=True,
            )
        output_ids = outputs[0]

        translation = tokenizer.decode(output_ids, skip_special_tokens=True)
        print(f"Generated translation: '{translation}'")
        return translation
    except Exception as e:
        print(f"Error during generation/decoding: {e}")
        return f"[Error] Translation generation failed: {e}"

# Gradio UI: a single input textbox and a single output textbox.
input_textbox = gr.Textbox(lines=4, label="Input Text (Polish or English)", placeholder="Enter text here...")
output_textbox = gr.Textbox(label="Translation")
interface = gr.Interface(
    fn=translate_text,
    inputs=input_textbox,
    outputs=output_textbox,
    title="🇵🇱 <-> 🇬🇧 Auto-Detecting ByT5 Translator",
    description=f"Translate text between Polish and English.\nModel: {MODEL_PATH}",
    article="Enter text and click Submit.",
    allow_flagging="never",
)
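
# Note: newer Gradio releases deprecate allow_flagging in favour of
# flagging_mode; if the installed version warns about it, the equivalent
# setting is flagging_mode="never".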

if __name__ == "__main__":
    interface.launch()
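    # For container or LAN deployments (assumption: not needed on Spaces,
    # where the defaults above suffice), the server can be bound explicitly:
    # interface.launch(server_name="0.0.0.0", server_port=7860)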