try:
import spaces
print("'spaces' module imported successfully.")
except ImportError:
print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")
    # Define a dummy decorator that does nothing if 'spaces' isn't available.
    # It must accept both bare (@spaces.GPU) and parameterized (@spaces.GPU(duration=60))
    # usage, mirroring the real decorator's API.
    class DummySpaces:
        def GPU(self, *args, **kwargs):
            if len(args) == 1 and callable(args[0]) and not kwargs:
                # Bare usage: called directly with the function
                print(f"Note: Dummy @GPU decorator used for function '{args[0].__name__}'.")
                return args[0]
            def decorator(func):
                # Parameterized usage: this dummy decorator just returns the original function
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator
    spaces = DummySpaces() # Create an instance of the dummy class
import gradio as gr
import re # Import the regular expression module
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # Or TFAutoModelForSeq2SeqLM
import torch # Or import tensorflow as tf
import os
import math
# Requires a Gradio version that supports the spaces.GPU decorator when running on Spaces.
# If 'spaces' is not directly importable, try: import gradio.external as spaces
from huggingface_hub import hf_hub_download
# --- Configuration ---
# IMPORTANT: REPLACE THIS with your model's Hugging Face Hub ID or local path
MODEL_PATH = "Gregniuki/pl-en-pl" # Use your actual model path
MAX_WORDS_PER_CHUNK = 44 # Define the maximum words per chunk
BATCH_SIZE = 8 # Adjust based on GPU memory / desired throughput
# --- Device Setup (Zero GPU Support) ---
if torch.cuda.is_available():
device = torch.device("cuda")
print("GPU detected. Using CUDA.")
else:
device = torch.device("cpu")
print("No GPU detected. Using CPU.")
# --- Get Hugging Face Token from Secrets for Private Models ---
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
if MODEL_PATH and "/" in MODEL_PATH and not os.path.exists(MODEL_PATH): # Rough check if it's likely a Hub ID
if HF_AUTH_TOKEN is None:
print(f"Warning: HF_TOKEN secret not found. Trying to load {MODEL_PATH} without authentication.")
else:
print("HF_TOKEN found. Using token for model loading.")
else:
print(f"Loading model from local path: {MODEL_PATH}")
HF_AUTH_TOKEN = None # Don't use token for local paths
# --- Load Model and Tokenizer (once on startup) ---
print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
token=HF_AUTH_TOKEN,
trust_remote_code=False
)
# --- Choose the correct model class ---
# PyTorch (most common)
model = AutoModelForSeq2SeqLM.from_pretrained(
MODEL_PATH,
token=HF_AUTH_TOKEN,
trust_remote_code=False
)
model.to(device) # Move model to the determined device
model.eval() # Set model to evaluation mode
print(f"Using PyTorch model on device: {device}")
# # TensorFlow (uncomment if your model is TF)
# from transformers import TFAutoModelForSeq2SeqLM
# import tensorflow as tf
# model = TFAutoModelForSeq2SeqLM.from_pretrained(
# MODEL_PATH,
# token=HF_AUTH_TOKEN,
# trust_remote_code=False
# )
# # TF device placement is often automatic or managed via strategies
# print("Using TensorFlow model.")
print("Model and tokenizer loaded successfully.")
except Exception as e:
print(f"FATAL Error loading model/tokenizer: {e}")
if "401 Client Error" in str(e):
error_message = f"Authentication failed. Ensure the HF_TOKEN secret has read access to {MODEL_PATH}."
else:
error_message = f"Failed to load model from {MODEL_PATH}. Error: {e}"
# Raise error to prevent app launch if model loading fails
raise RuntimeError(error_message)
# --- Helper Functions for Chunking ---
def split_long_segment_by_comma_or_fallback(segment, max_words):
"""
Splits a long segment (already known > max_words) primarily by commas,
falling back to simple word splitting if needed.
"""
if not segment or segment.isspace():
return []
# 1. Attempt to split by commas, keeping the comma and trailing whitespace
# re.split splits *after* the pattern. (?<=,) looks behind for a comma. \s* matches trailing whitespace.
comma_parts = re.split(r'(?<=,)\s*', segment)
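    # Illustrative example (assumed input): re.split(r'(?<=,)\s*', "a, b, c")
    # returns ["a,", "b,", "c"] -- each part keeps its trailing comma.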
comma_parts = [p.strip() for p in comma_parts if p.strip()] # Trim and filter empty parts
# If no commas found or splitting yielded strange results, fall back to word splitting
if not comma_parts or (len(comma_parts) == 1 and len(comma_parts[0].split()) > max_words):
# print(f"Debug: Falling back to word split for segment: '{segment[:100]}...'") # Optional debug
# Fallback: Simple word-based chunking
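        # Illustrative (assumed input): with max_words=3, "a b c d e f g"
        # word-splits into ["a b c", "d e f", "g"].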
        words = segment.split()
        segment_chunks = []
        current_chunk_words = []
        for word in words:
            current_chunk_words.append(word)
            # Flush the chunk as soon as it reaches max_words
            if len(current_chunk_words) == max_words:
                segment_chunks.append(" ".join(current_chunk_words))
                current_chunk_words = []
        # Add any remaining words as a final, shorter chunk
        if current_chunk_words:
            segment_chunks.append(" ".join(current_chunk_words))
        return segment_chunks
# 2. Recombine comma-separated parts, respecting max_words
segment_chunks = []
current_chunk_parts = [] # List to hold comma-separated strings for the current chunk
current_chunk_word_count = 0
    for part in comma_parts:
        part_word_count = len(part.split())
        # A comma part that is itself longer than max_words has no further comma
        # breaks, so flush the current chunk and word-split just that part.
        if part_word_count > max_words:
            if current_chunk_parts:
                segment_chunks.append(" ".join(current_chunk_parts).strip())
                current_chunk_parts = []
                current_chunk_word_count = 0
            part_words = part.split()
            for j in range(0, len(part_words), max_words):
                segment_chunks.append(" ".join(part_words[j:j + max_words]))
            continue
        # Check if adding this part would make the current chunk exceed max_words.
        # `current_chunk_word_count > 0` ensures we don't emit an empty chunk
        # before the first part has been added.
        if current_chunk_word_count > 0 and (current_chunk_word_count + part_word_count > max_words):
            # Finalize the current chunk (join the collected parts)
            segment_chunks.append(" ".join(current_chunk_parts).strip())
            # Start a new chunk with the current part
            current_chunk_parts = [part]
            current_chunk_word_count = part_word_count
        else:
            # Add the part to the current chunk
            current_chunk_parts.append(part)
            current_chunk_word_count += part_word_count
# Add any remaining parts as the last chunk for this segment
if current_chunk_parts:
segment_chunks.append(" ".join(current_chunk_parts).strip())
return segment_chunks
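# Illustrative sketch (assumed input, not executed): with max_words=5,
#   split_long_segment_by_comma_or_fallback("one two, three four, five six seven eight", 5)
# splits at commas into parts of 2, 2, and 4 words, then regroups them as
#   ["one two, three four,", "five six seven eight"]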
def chunk_sentence(sentence, max_words):
    """
    Splits text into chunks of at most max_words words, breaking first at
    sentence-ending punctuation (. ! ?), then at commas (,) within over-long
    segments, and falling back to a plain word split as a last resort.
    Treats each input line as potentially containing multiple sentences.
    """
if not sentence or sentence.isspace():
return []
all_final_chunks = []
# 1. Split the input line into potential "sentence segments" at . ! ?
# Use regex split with lookbehind to split *after* the punctuation and space.
# This yields segments that end in . ! ? (except possibly the very last segment).
# Example: "Hello world. How are you? And you?" -> ["Hello world.", "How are you?", "And you?"]
# Example: "Part one, part two. Part three." -> ["Part one, part two.", "Part three."]
# Example: "No punctuation here" -> ["No punctuation here"]
sentence_segments = re.split(r'(?<=[.!?])\s*', sentence)
# Filter out empty strings that might result from splitting
sentence_segments = [s.strip() for s in sentence_segments if s.strip()]
# 2. Process each sentence segment
for segment in sentence_segments:
segment_word_count = len(segment.split())
if segment_word_count <= max_words:
# Segment is short enough, add directly
all_final_chunks.append(segment)
else:
# Segment is too long, apply comma splitting or fallback word splitting
comma_based_chunks = split_long_segment_by_comma_or_fallback(segment, max_words)
all_final_chunks.extend(comma_based_chunks)
# Ensure no empty strings sneak through at the end
return [chunk for chunk in all_final_chunks if chunk.strip()]
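# Illustrative sketch (assumed input, not executed): with max_words=5,
#   chunk_sentence("Hello world. One two, three four, five six seven eight.", 5)
# first splits at the period, then comma-splits the over-long second segment:
#   ["Hello world.", "One two, three four,", "five six seven eight."]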
# --- Define the BATCH translation function ---
# Add GPU decorator for Spaces (adjust duration if needed)
@spaces.GPU
def translate_batch(text_input):
"""
Translates multi-line input text using batching and sentence chunking.
Assumes auto-detection of language direction (no prefixes).
"""
if not text_input or text_input.strip() == "":
return "[Error] Please enter some text to translate."
print(f"Received input block for batch translation.")
# 1. Split input into potential sentences (lines) and clean
# Then chunk each line using the sophisticated chunk_sentence function
lines = [line.strip() for line in text_input.splitlines() if line.strip()]
if not lines:
return "[Info] No valid text lines found in input."
# 2. Chunk lines using the new logic
all_chunks = []
for line in lines:
# Process each line as a potential multi-sentence block for chunking
line_chunks = chunk_sentence(line, MAX_WORDS_PER_CHUNK)
all_chunks.extend(line_chunks)
if not all_chunks:
return "[Info] No text chunks generated after processing input."
print(f"Processing {len(all_chunks)} chunks in batches...")
# 3. Process chunks in batches
all_translations = []
num_batches = math.ceil(len(all_chunks) / BATCH_SIZE)
for i in range(num_batches):
batch_start = i * BATCH_SIZE
batch_end = batch_start + BATCH_SIZE
batch_chunks = all_chunks[batch_start:batch_end]
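        # e.g. with BATCH_SIZE=8 (illustration): batch 0 covers all_chunks[0:8], batch 1 covers [8:16], etc.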
print(f" Processing batch {i+1}/{num_batches} ({len(batch_chunks)} chunks)")
# Tokenize the batch
try:
inputs = tokenizer(
batch_chunks,
return_tensors="pt",
padding=True,
truncation=True,
max_length=1024
).to(device)
            max_length = 1024  # the model's maximum sequence length
            max_input_length = inputs["input_ids"].shape[1]
            # Allow outputs up to ~1.2x the input length, capped at max_length
            max_new_tokens = min(int(max_input_length * 1.2), max_length)
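            # Illustration (assumed batch width): a 50-token-wide batch gives
            # max_new_tokens = min(int(50 * 1.2), 1024) = 60.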
print(f"Tokenized input (max_length={max_length})")
            for idx, (text, input_ids) in enumerate(zip(batch_chunks, inputs["input_ids"])):
                print(f"    Input {idx + 1}: {len(input_ids)} tokens")
                print(f"    Chunk {idx + 1}: {repr(text[:100])}...")  # Print first 100 chars to keep output manageable
except Exception as e:
print(f"Error during batch tokenization: {e}")
return "[Error] Tokenization failed for a batch."
# Generate translations for the batch
try:
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
num_beams=4,
# no_repeat_ngram_size=3,
early_stopping=False,
return_dict_in_generate=True,
output_scores=True
)
print(f" Generation completed with max_new_tokens={max_new_tokens}")
sequences = outputs.sequences
for idx, seq in enumerate(sequences):
print(f" Output {idx+1}: {len(seq)} tokens")
batch_translations = tokenizer.batch_decode(sequences, skip_special_tokens=True)
all_translations.extend(batch_translations)
except Exception as e:
print(f"Error during batch generation/decoding: {e}")
return "[Error] Translation generation failed for a batch."
# 4. Join translated chunks back together
# Simple join with newline. The chunking logic aims to keep sentences/clauses together,
# so joining by newline should preserve the overall structure reasonably well,
# though it might not exactly match the original line breaks if chunking occurred within an original line.
final_output = "\n".join(all_translations)
print("Batch translation finished.")
return final_output
# --- Create Gradio Interface for Batch Translation ---
input_textbox = gr.Textbox(
lines=10, # Allow more lines for batch input
label="Input Text (Polish or English - Enter multiple lines/sentences)",
placeholder=f"Enter text here. Longer sentences/lines will be split into chunks (max {MAX_WORDS_PER_CHUNK} words) prioritizing . ! ? and , breaks."
)
output_textbox = gr.Textbox(label="Translation Output", lines=10)
# Interface definition
interface = gr.Interface(
fn=translate_batch, # Use the batch function
inputs=input_textbox,
outputs=output_textbox,
title="π΅π± <-> π¬π§ Batch ByT5 Translator (Auto-Detect, Smart Chunking)",
description=f"Translate multiple lines of text between Polish and English.\nModel: {MODEL_PATH}\nText is automatically split into chunks of max {MAX_WORDS_PER_CHUNK} words, prioritizing breaks at . ! ? and ,",
article="Enter text (you can paste multiple paragraphs or sentences). Click Submit to translate.\n\nChunking Logic:\n1. The entire input box content is split into potential 'sentence segments' using . ! ? as delimiters.\n2. Each segment is checked for word count.\n3. If a segment is <= {MAX_WORDS_PER_CHUNK} words, it's treated as a single chunk.\n4. If a segment is > {MAX_WORDS_PER_CHUNK} words, it's further split internally using commas (,) as preferred break points.\n5. If a long segment has no commas, or comma splitting isn't sufficient, it falls back to breaking purely by word count near {MAX_WORDS_PER_CHUNK} to avoid excessively long chunks.\n6. These final chunks are batched and translated.",
allow_flagging="never"
)
# --- Launch the App ---
if __name__ == "__main__":
# Set share=True for a public link if running locally, not needed on Spaces
interface.launch() |