# --- START OF FILE app.py ---
try:
import spaces
print("'spaces' module imported successfully.")
except ImportError:
print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")
# Define a dummy decorator that does nothing if 'spaces' isn't available
    class DummySpaces:
        def GPU(self, *args, **kwargs):
            # Support both bare `@spaces.GPU` and parameterized `@spaces.GPU(...)` usage.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                func = args[0]
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            def decorator(func):
                # This dummy decorator just returns the original function
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator
    spaces = DummySpaces()  # Create an instance of the dummy class
import gradio as gr
import re # Import the regular expression module
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # Or TFAutoModelForSeq2SeqLM
import torch # Or import tensorflow as tf
import os
import math
from huggingface_hub import hf_hub_download
# --- Configuration ---
# IMPORTANT: REPLACE THIS with your model's Hugging Face Hub ID or local path
MODEL_PATH = "Gregniuki/pl-en-pl-v2" # Use your actual model path
MAX_WORDS_PER_CHUNK = 128 # Define the maximum words per chunk
BATCH_SIZE = 8 # Adjust based on GPU memory / desired throughput
# --- Device Setup (Zero GPU Support) ---
if torch.cuda.is_available():
device = torch.device("cuda")
print("GPU detected. Using CUDA.")
else:
device = torch.device("cpu")
print("No GPU detected. Using CPU.")
# --- Get Hugging Face Token from Secrets for Private Models ---
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
if MODEL_PATH and "/" in MODEL_PATH and not os.path.exists(MODEL_PATH): # Rough check if it's likely a Hub ID
if HF_AUTH_TOKEN is None:
print(f"Warning: HF_TOKEN secret not found. Trying to load {MODEL_PATH} without authentication.")
else:
print("HF_TOKEN found. Using token for model loading.")
else:
print(f"Loading model from local path: {MODEL_PATH}")
HF_AUTH_TOKEN = None # Don't use token for local paths
# --- Load Model and Tokenizer (once on startup) ---
print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
token=HF_AUTH_TOKEN,
trust_remote_code=False
)
# --- Choose the correct model class ---
# PyTorch (most common)
model = AutoModelForSeq2SeqLM.from_pretrained(
MODEL_PATH,
token=HF_AUTH_TOKEN,
trust_remote_code=False
)
model.to(device) # Move model to the determined device
model.eval() # Set model to evaluation mode
print(f"Using PyTorch model on device: {device}")
# # TensorFlow (uncomment if your model is TF)
# from transformers import TFAutoModelForSeq2SeqLM
# import tensorflow as tf
# model = TFAutoModelForSeq2SeqLM.from_pretrained(
# MODEL_PATH,
# token=HF_AUTH_TOKEN,
# trust_remote_code=False
# )
# # TF device placement is often automatic or managed via strategies
# print("Using TensorFlow model.")
print("Model and tokenizer loaded successfully.")
except Exception as e:
print(f"FATAL Error loading model/tokenizer: {e}")
if "401 Client Error" in str(e):
error_message = f"Authentication failed. Ensure the HF_TOKEN secret has read access to {MODEL_PATH}."
else:
error_message = f"Failed to load model from {MODEL_PATH}. Error: {e}"
# Raise error to prevent app launch if model loading fails
raise RuntimeError(error_message)
# --- Helper Functions for Chunking ---
def chunk_sentence(sentence, max_words):
"""
Splits a sentence (or line of text) into chunks ONLY if it exceeds max_words.
If splitting is needed, it prioritizes splitting *after* sentence-ending
punctuation (. ! ?) or commas (,) found within the first `max_words`.
It looks for the *last* such punctuation within that limit.
If no suitable punctuation is found, it splits strictly at `max_words`.
"""
if not sentence or sentence.isspace():
return []
sentence = sentence.strip() # Ensure no leading/trailing whitespace
words = sentence.split()
word_count = len(words)
# If the sentence is short enough, return it as a single chunk
if word_count <= max_words:
return [sentence]
# If the sentence is too long, proceed with chunking
chunks = []
current_word_index = 0
while current_word_index < word_count:
# Determine the end index for the current potential chunk (non-inclusive)
potential_end_word_index = min(current_word_index + max_words, word_count)
# Assume we split at the max_words limit initially
actual_end_word_index = potential_end_word_index
        # Only look for a punctuation break point if there is text remaining after
        # this potential chunk; if the rest fits within the limit, no split is needed here.
if potential_end_word_index < word_count:
# Search backwards from the word *before* the potential end index
# down to the start of the current segment for punctuation.
best_punctuation_split_index = -1
for i in range(potential_end_word_index - 1, current_word_index, -1):
                # Check if the word at index 'i' ends with sentence punctuation or a comma
                if words[i].endswith(('.', '!', '?', ',')):
                    best_punctuation_split_index = i + 1  # Split *after* this word
break # Found the last suitable punctuation in the range
# If we found a punctuation split point, use it
if best_punctuation_split_index > current_word_index: # Ensure it's a valid index within the current segment
actual_end_word_index = best_punctuation_split_index
# Else: No suitable punctuation found, stick with potential_end_word_index (split at max_words limit)
        # Safety check: never emit an empty chunk. If the computed split point did not
        # advance past the current position, force-consume at least one word.
if actual_end_word_index <= current_word_index and current_word_index < word_count:
actual_end_word_index = current_word_index + 1
print(f"Warning: Split point adjustment needed. Forced split after word index {current_word_index}.")
# Extract the chunk words and join them
chunk_words = words[current_word_index:actual_end_word_index]
if chunk_words: # Ensure we don't add empty strings
chunks.append(" ".join(chunk_words))
        # Update the starting index for the next chunk; remember the previous
        # position so we can detect a failure to advance.
        previous_word_index = current_word_index
        current_word_index = actual_end_word_index
        # Infinite-loop guard (should not trigger given the safety check above):
        # if the index failed to advance, append the remainder unsplit and stop.
        if current_word_index <= previous_word_index:
            print("ERROR: Chunking loop failed to advance. Appending the remainder unsplit.")
            remaining_words = words[previous_word_index:]
            if remaining_words:
                chunks.append(" ".join(remaining_words))
            break
    return [chunk for chunk in chunks if chunk]  # Final filter for empty strings
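# Illustrative sketch of chunk_sentence behaviour (max_words=5 chosen only for
# brevity; the app itself uses MAX_WORDS_PER_CHUNK = 128):
#   chunk_sentence("Ala ma kota.", 5)
#       -> ["Ala ma kota."]                      (short line: single chunk)
#   chunk_sentence("One two three. Four five six seven eight.", 5)
#       -> ["One two three.", "Four five six seven eight."]
#          (over the limit: split after the last . ! ? , within the first 5 words)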
# --- Define the BATCH translation function ---
# Add GPU decorator for Spaces (adjust duration if needed)
@spaces.GPU
def translate_batch(
text_input,
num_beams,
do_sample,
temperature,
top_k,
top_p
):
"""
Translates multi-line input text using batching and sentence chunking.
Assumes auto-detection of language direction (no prefixes).
Uses the updated chunking logic and supports generation controls.
"""
if not text_input or text_input.strip() == "":
return "[Error] Please enter some text to translate."
print(f"Received input block for batch translation.")
# 1. Split input into lines and clean
lines = [line.strip() for line in text_input.splitlines() if line.strip()]
if not lines:
return "[Info] No valid text lines found in input."
# 2. Chunk each line individually using the new logic
all_chunks = []
for line in lines:
# Apply the new chunking logic to each line
line_chunks = chunk_sentence(line, MAX_WORDS_PER_CHUNK)
all_chunks.extend(line_chunks)
if not all_chunks:
return "[Info] No text chunks generated after processing input."
print(f"Processing {len(all_chunks)} chunks in batches with parameters:")
print(f" num_beams: {num_beams}")
print(f" do_sample: {do_sample}")
print(f" temperature: {temperature}")
print(f" top_k: {top_k}")
print(f" top_p: {top_p}")
# 3. Process chunks in batches
all_translations = []
num_batches = math.ceil(len(all_chunks) / BATCH_SIZE)
for i in range(num_batches):
batch_start = i * BATCH_SIZE
batch_end = batch_start + BATCH_SIZE
batch_chunks = all_chunks[batch_start:batch_end]
print(f" Processing batch {i+1}/{num_batches} ({len(batch_chunks)} chunks)")
# Tokenize the batch
try:
inputs = tokenizer(
batch_chunks,
return_tensors="pt",
padding=True,
truncation=True,
max_length=1024 # Model's max input length
).to(device)
# Estimate appropriate max_new_tokens based on input length
max_input_length = inputs["input_ids"].shape[1]
            max_new_tokens = min(int(max_input_length * 1.2) + 20, 1024)  # Scale the output budget with input length, plus a small buffer, capped at 1024
print(f"Tokenized input (batch max length={max_input_length}), setting max_new_tokens={max_new_tokens}")
except Exception as e:
print(f"Error during batch tokenization: {e}")
all_translations.append(f"[Error tokenizing batch {i+1}]")
continue # Skip to next batch or break
# Prepare generation arguments
generation_kwargs = {
**inputs,
"max_new_tokens": max_new_tokens,
"early_stopping": True, # Usually good for translation
}
# Apply decoding strategy parameters
        if do_sample:
            generation_kwargs["do_sample"] = True
            generation_kwargs["temperature"] = float(temperature)
            generation_kwargs["top_k"] = int(top_k) if top_k > 0 else None  # None disables Top-K
            generation_kwargs["top_p"] = float(top_p) if 0 < top_p < 1.0 else None  # None disables Top-P
            # When sampling, num_beams is typically 1; force it here to avoid conflicts.
            generation_kwargs["num_beams"] = 1
            print(f" Using Sampling: temp={temperature}, top_k={top_k}, top_p={top_p}")
        else:
            generation_kwargs["do_sample"] = False  # Greedy or beam search
            generation_kwargs["num_beams"] = int(num_beams)  # Cast in case the slider delivers a float
            print(f" Using {'Greedy' if num_beams == 1 else 'Beam Search'}: num_beams={num_beams}")
# Generate translations for the batch
try:
with torch.no_grad():
outputs = model.generate(**generation_kwargs)
print(f" Generation completed for batch {i+1}")
batch_translations = tokenizer.batch_decode(outputs, skip_special_tokens=True)
all_translations.extend(batch_translations)
except Exception as e:
print(f"Error during batch generation/decoding: {e}")
error_msg = f"[Error translating batch {i+1}]"
all_translations.extend([error_msg] * len(batch_chunks))
# 4. Join translated chunks back together
final_output = "\n".join(all_translations)
print("Batch translation finished.")
return final_output
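# Example (sketch): translate_batch can also be called directly, e.g. from a
# Python shell, with illustrative parameter values and greedy decoding:
#   translate_batch("Ala ma kota.\nThe weather is nice today.",
#                   num_beams=1, do_sample=False,
#                   temperature=1.0, top_k=0, top_p=1.0)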
# --- Create Gradio Interface for Batch Translation ---
input_textbox = gr.Textbox(
lines=10, # Allow more lines for batch input
label="Input Text (Polish or English - Enter multiple lines/sentences)",
placeholder=f"Enter text here. Lines longer than {MAX_WORDS_PER_CHUNK} words will be split, prioritizing breaks after . ! ? , near the limit."
)
output_textbox = gr.Textbox(label="Translation Output", lines=10)
# New generation control components
num_beams_slider = gr.Slider(
minimum=1,
maximum=10,
step=1,
value=4, # Default value for beam search
label="Number of Beams",
info="Controls the breadth of the search during decoding (Beam Search). Set to 1 for Greedy Search. Ignored if 'Enable Sampling' is checked."
)
do_sample_checkbox = gr.Checkbox(
value=False, # Default to beam search
label="Enable Sampling",
info="If checked, enables probabilistic sampling using Temperature, Top-K, and Top-P. Overrides Beam Search (forces Number of Beams to 1)."
)
temperature_slider = gr.Slider(
minimum=0.0,
maximum=2.0,
step=0.01, # Finer control for temperature
value=1.0, # Default for sampling
label="Temperature",
info="Controls randomness. Lower=more deterministic, Higher=more random. Only active if 'Enable Sampling' is checked."
)
top_k_slider = gr.Slider(
minimum=0, # 0 usually means disable Top-K
maximum=200,
step=1,
value=50, # Common value
label="Top-K",
info="Number of highest probability vocabulary tokens to consider for sampling. 0 means no Top-K. Only active if 'Enable Sampling' is checked."
)
top_p_slider = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.01, # Finer control for Top-P
value=0.9, # Common value for nucleus sampling
label="Top-P (Nucleus Sampling)",
info="Fraction of highest probability tokens to consider whose cumulative probability exceeds P. Only active if 'Enable Sampling' is checked."
)
# Interface definition
interface = gr.Interface(
fn=translate_batch, # Use the batch function
inputs=[
input_textbox,
num_beams_slider,
do_sample_checkbox,
temperature_slider,
top_k_slider,
top_p_slider
],
outputs=output_textbox,
title="πŸ‡΅πŸ‡± <-> πŸ‡¬πŸ‡§ Batch ByT5 Translator (Auto-Detect, Smart Chunking with Generation Controls)",
description=f"Translate multiple lines of text between Polish and English.\nModel: {MODEL_PATH}\nText is processed line by line. Lines longer than {MAX_WORDS_PER_CHUNK} words are split into chunks. You can adjust generation parameters below.",
article=f"Enter text (you can paste multiple paragraphs or sentences). Click Submit to translate.\n\nChunking Logic:\n1. Each line you enter is processed independently.\n2. If a line contains {MAX_WORDS_PER_CHUNK} words or fewer, it is translated as a single unit.\n3. If a line contains more than {MAX_WORDS_PER_CHUNK} words, it is split into smaller chunks.\n4. When splitting, the algorithm looks for the last punctuation mark (. ! ? ,) within the first {MAX_WORDS_PER_CHUNK} words to use as a natural break point.\n5. If no suitable punctuation is found in that range, the line is split exactly at the {MAX_WORDS_PER_CHUNK}-word limit.\n6. This process repeats for the remainder of the line until all parts are below the word limit.\n7. These final chunks are then translated in batches.\n\nGeneration Controls:\n- **Number of Beams**: Higher values (e.g., 4-10) can produce more fluent but less diverse output. Set to 1 for greedy decoding.\n- **Enable Sampling**: Activates probabilistic generation. When checked, the model will use Temperature, Top-K, and Top-P, and 'Number of Beams' will be forced to 1.\n- **Temperature**: Controls randomness. A value of 1.0 means no change to probabilities. Lower values (e.g., 0.7) make the output more focused and less random. Higher values (e.g., 1.5) increase randomness and creativity.\n- **Top-K**: The model considers only the 'K' most likely next tokens. A value of 0 effectively disables it. Useful for limiting very unlikely predictions.\n- **Top-P (Nucleus Sampling)**: The model considers the smallest set of tokens whose cumulative probability exceeds 'P'. This dynamically adjusts the number of tokens considered based on the probability distribution, making it more flexible than Top-K.",
allow_flagging="never"
)
# --- Launch the App ---
if __name__ == "__main__":
# Set share=True for a public link if running locally, not needed on Spaces
interface.launch()
# --- END OF FILE app.py ---