import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # Or TF...
import torch # Or import tensorflow as tf
import os # <--- ADDED: To access environment variables
# --- Configuration ---
# Use the EXACT Hub ID of your PRIVATE model
MODEL_PATH = "Gregniuki/pl-en-pl-v2"
# --- Get Hugging Face Token from Secrets --- # <--- ADDED SECTION
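# Note (assumption): on a Hugging Face Space the HF_TOKEN secret is typically defined under
# Settings -> Variables and secrets and exposed as an environment variable; os.getenv returns
# None when it is not set.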
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
if HF_AUTH_TOKEN is None:
print("Warning: HF_TOKEN secret not found. Loading model without authentication.")
# Optionally, raise an error if the token is absolutely required:
# raise ValueError("HF_TOKEN secret is missing, cannot load private model.")
# --- END ADDED SECTION ---
# --- Load Model and Tokenizer (do this once on startup) ---
print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
    # --- MODIFIED: Pass the token ---
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,  # <--- ADDED
        trust_remote_code=False  # Set to True if model requires it
    )

    # --- MODIFIED: Pass the token ---
    # PyTorch
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,  # <--- ADDED
        trust_remote_code=False  # Set to True if model requires it
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Using PyTorch model on device: {device}")

    # # TensorFlow (uncomment if using TF)
    # from transformers import TFAutoModelForSeq2SeqLM
    # import tensorflow as tf
    # model = TFAutoModelForSeq2SeqLM.from_pretrained(
    #     MODEL_PATH,
    #     token=HF_AUTH_TOKEN,  # <--- ADDED
    #     trust_remote_code=False
    # )
    # print("Using TensorFlow model.")
    # device = "cpu"

    model.eval()
    print("Model and tokenizer loaded successfully.")
except Exception as e:
print(f"Error loading model/tokenizer: {e}")
# Add more specific error handling if needed (e.g., check for 401 Unauthorized)
if "401 Client Error" in str(e):
error_message = f"Authentication failed. Ensure the HF_TOKEN secret has read access to {MODEL_PATH}."
else:
error_message = f"Failed to load model from {MODEL_PATH}. Error: {e}"
raise gr.Error(error_message)
# --- Define the translation function (KEEP AS IS, depending on prefix/no-prefix) ---
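# Assumption: this is the no-prefix variant, so the model itself infers the translation
# direction (PL -> EN or EN -> PL) from the input text; a prefix-based variant would instead
# take an explicit direction argument, as hinted in the alternative signature below.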
def translate_text(text_input): # Or def translate_text(text_input, direction):
    # ... (your existing translation logic remains the same) ...
    if not text_input or text_input.strip() == "":
        return "[Error] Please enter some text to translate."
    print(f"Received input: '{text_input}'")

    # Tokenize
    try:
        # PyTorch
        inputs = tokenizer(text_input, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        # # TensorFlow
        # inputs = tokenizer(text_input, return_tensors="tf", padding=True, truncation=True, max_length=512)
    except Exception as e:
        print(f"Error during tokenization: {e}")
        return f"[Error] Tokenization failed: {e}"

    # Generate
    try:
        # PyTorch
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=512, num_beams=4, early_stopping=True
            )
        output_ids = outputs[0]
        # # TensorFlow
        # outputs = model.generate(
        #     inputs['input_ids'], attention_mask=inputs['attention_mask'],
        #     max_length=512, num_beams=4, early_stopping=True
        # )
        # output_ids = outputs[0]
        translation = tokenizer.decode(output_ids, skip_special_tokens=True)
        print(f"Generated translation: '{translation}'")
        return translation
    except Exception as e:
        print(f"Error during generation/decoding: {e}")
        return f"[Error] Translation generation failed: {e}"
# --- Create Gradio Interface (KEEP AS IS, depending on prefix/no-prefix) ---
# Example for no-prefix model:
input_textbox = gr.Textbox(lines=4, label="Input Text (Polish or English)", placeholder="Enter text here...")
output_textbox = gr.Textbox(label="Translation")
interface = gr.Interface(
    fn=translate_text,
    inputs=input_textbox,
    outputs=output_textbox,
    title="🇵🇱 <-> 🇬🇧 Auto-Detecting ByT5 Translator",
    description=f"Translate text between Polish and English.\nModel: {MODEL_PATH}",
    article="Enter text and click Submit.",
    allow_flagging="never"
)
# --- Launch the App ---
if __name__ == "__main__":
    interface.launch()