try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")

    class DummySpaces:
        def GPU(self, *args, **kwargs):
            # Support both bare @spaces.GPU and parameterized @spaces.GPU(...) usage.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                func = args[0]
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func

            def decorator(func):
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator

    spaces = DummySpaces()
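
# On Hugging Face Spaces, the real 'spaces' package provides the @spaces.GPU decorator used
# below to request GPU allocation; the dummy above only mirrors that interface for local runs.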

import gradio as gr
import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import math

from huggingface_hub import hf_hub_download
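
# Configuration: the Hub model repo, the per-chunk word limit used by chunk_sentence,
# and the number of chunks passed to model.generate at a time.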
MODEL_PATH = "Gregniuki/pl-en-pl"
MAX_WORDS_PER_CHUNK = 44
BATCH_SIZE = 8

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU detected. Using CUDA.")
else:
    device = torch.device("cpu")
    print("No GPU detected. Using CPU.")

HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
if MODEL_PATH and "/" in MODEL_PATH and not os.path.exists(MODEL_PATH):
    if HF_AUTH_TOKEN is None:
        print(f"Warning: HF_TOKEN secret not found. Trying to load {MODEL_PATH} without authentication.")
    else:
        print("HF_TOKEN found. Using token for model loading.")
else:
    print(f"Loading model from local path: {MODEL_PATH}")
    HF_AUTH_TOKEN = None
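
# Note: a read token is only needed when MODEL_PATH points at a private or gated Hub repo;
# public models load without HF_TOKEN.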

print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False
    )

    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False
    )
    model.to(device)
    model.eval()
    print(f"Using PyTorch model on device: {device}")
    print("Model and tokenizer loaded successfully.")

except Exception as e:
    print(f"FATAL Error loading model/tokenizer: {e}")
    if "401 Client Error" in str(e):
        error_message = f"Authentication failed. Ensure the HF_TOKEN secret has read access to {MODEL_PATH}."
    else:
        error_message = f"Failed to load model from {MODEL_PATH}. Error: {e}"

    raise RuntimeError(error_message)
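
# Optional local sanity check (a minimal sketch using a hypothetical sample sentence;
# uncomment to confirm the model generates before the UI starts):
# with torch.no_grad():
#     _demo = tokenizer(["Dzień dobry, jak się masz?"], return_tensors="pt").to(device)
#     _out = model.generate(**_demo, max_new_tokens=64, num_beams=4)
#     print(tokenizer.batch_decode(_out, skip_special_tokens=True))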


def split_long_segment_by_comma_or_fallback(segment, max_words):
    """
    Splits a long segment (already known to exceed max_words) primarily at commas,
    falling back to a plain word-count split if needed.
    """
    if not segment or segment.isspace():
        return []

    # Split after each comma, keeping the comma attached to the preceding part.
    comma_parts = re.split(r'(?<=,)\s*', segment)
    comma_parts = [p.strip() for p in comma_parts if p.strip()]

    # Fallback: no usable comma parts, or a single part that is still too long.
    if not comma_parts or (len(comma_parts) == 1 and len(comma_parts[0].split()) > max_words):
        words = segment.split()
        segment_chunks = []
        current_chunk_words = []
        for word in words:
            current_chunk_words.append(word)
            # Flush the chunk as soon as it reaches the word limit.
            if len(current_chunk_words) >= max_words:
                segment_chunks.append(" ".join(current_chunk_words))
                current_chunk_words = []

        if current_chunk_words:
            segment_chunks.append(" ".join(current_chunk_words))

        return segment_chunks

    # Greedily pack comma-delimited parts into chunks of at most max_words words.
    segment_chunks = []
    current_chunk_parts = []
    current_chunk_word_count = 0

    for part in comma_parts:
        part_word_count = len(part.split())

        if current_chunk_word_count > 0 and (current_chunk_word_count + part_word_count > max_words):
            # Adding this part would overflow the current chunk; flush it and start a new one.
            segment_chunks.append(" ".join(current_chunk_parts).strip())
            current_chunk_parts = [part]
            current_chunk_word_count = part_word_count
        else:
            current_chunk_parts.append(part)
            current_chunk_word_count += part_word_count

    if current_chunk_parts:
        segment_chunks.append(" ".join(current_chunk_parts).strip())

    return segment_chunks
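
# Illustrative (hypothetical) example of the comma-based split with max_words=5:
#   split_long_segment_by_comma_or_fallback("aaa bbb ccc, ddd eee fff, ggg", 5)
#   -> ["aaa bbb ccc,", "ddd eee fff, ggg"]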


def chunk_sentence(sentence, max_words):
    """
    Splits text into chunks based on max words, prioritizing sentence-ending punctuation (. ! ?),
    then commas (,) when a segment exceeds max_words, falling back to a plain word split.
    Processes the input line as potentially containing multiple sentences.
    """
    if not sentence or sentence.isspace():
        return []

    all_final_chunks = []

    # First split into sentence segments after ., ! or ?, keeping the punctuation attached.
    sentence_segments = re.split(r'(?<=[.!?])\s*', sentence)
    sentence_segments = [s.strip() for s in sentence_segments if s.strip()]

    for segment in sentence_segments:
        segment_word_count = len(segment.split())

        if segment_word_count <= max_words:
            # Short enough to be translated as a single chunk.
            all_final_chunks.append(segment)
        else:
            # Too long: split further at commas (or by word count as a last resort).
            comma_based_chunks = split_long_segment_by_comma_or_fallback(segment, max_words)
            all_final_chunks.extend(comma_based_chunks)

    return [chunk for chunk in all_final_chunks if chunk.strip()]
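
# Illustrative (hypothetical) example with max_words=4:
#   chunk_sentence("Hello there. How are you today, my friend?", 4)
#   -> ["Hello there.", "How are you today,", "my friend?"]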


@spaces.GPU
def translate_batch(text_input):
    """
    Translates multi-line input text using batching and sentence chunking.
    Assumes auto-detection of language direction (no prefixes).
    """
    if not text_input or text_input.strip() == "":
        return "[Error] Please enter some text to translate."

    print("Received input block for batch translation.")

    # Keep only non-empty lines; each line is chunked independently.
    lines = [line.strip() for line in text_input.splitlines() if line.strip()]
    if not lines:
        return "[Info] No valid text lines found in input."

    all_chunks = []
    for line in lines:
        line_chunks = chunk_sentence(line, MAX_WORDS_PER_CHUNK)
        all_chunks.extend(line_chunks)

    if not all_chunks:
        return "[Info] No text chunks generated after processing input."

    print(f"Processing {len(all_chunks)} chunks in batches...")
    all_translations = []
    num_batches = math.ceil(len(all_chunks) / BATCH_SIZE)

    for i in range(num_batches):
        batch_start = i * BATCH_SIZE
        batch_end = batch_start + BATCH_SIZE
        batch_chunks = all_chunks[batch_start:batch_end]
        print(f" Processing batch {i+1}/{num_batches} ({len(batch_chunks)} chunks)")

        try:
            inputs = tokenizer(
                batch_chunks,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=1024
            ).to(device)
            max_length = 1024

            # Allow outputs somewhat longer than the longest input, capped at max_length.
            max_input_length = inputs["input_ids"].shape[1]
            max_new_tokens = min(int(max_input_length * 1.2), max_length)

            print(f"Tokenized input (max_length={max_length})")
            for idx, (text, input_ids) in enumerate(zip(batch_chunks, inputs["input_ids"])):
                print(f" Input {idx + 1}: {len(input_ids)} tokens")
                print(f" Chunk {idx + 1}: {repr(text)}...")

        except Exception as e:
            print(f"Error during batch tokenization: {e}")
            return "[Error] Tokenization failed for a batch."
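
        # Beam-search generation (num_beams=4); scores are requested but only the sequences are used.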
        try:
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    num_beams=4,
                    early_stopping=False,
                    return_dict_in_generate=True,
                    output_scores=True
                )

            print(f" Generation completed with max_new_tokens={max_new_tokens}")

            sequences = outputs.sequences
            for idx, seq in enumerate(sequences):
                print(f" Output {idx+1}: {len(seq)} tokens")

            batch_translations = tokenizer.batch_decode(sequences, skip_special_tokens=True)
            all_translations.extend(batch_translations)

        except Exception as e:
            print(f"Error during batch generation/decoding: {e}")
            return "[Error] Translation generation failed for a batch."

    # Each chunk becomes its own output line, so a long input line may span several output lines.
    final_output = "\n".join(all_translations)
    print("Batch translation finished.")
    return final_output
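
# Local smoke test (hypothetical sample text; uncomment to call the function without the UI):
# print(translate_batch("Dzień dobry.\nHow are you today?"))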


input_textbox = gr.Textbox(
    lines=10,
    label="Input Text (Polish or English - Enter multiple lines/sentences)",
    placeholder=f"Enter text here. Longer sentences/lines will be split into chunks (max {MAX_WORDS_PER_CHUNK} words) prioritizing . ! ? and , breaks."
)
output_textbox = gr.Textbox(label="Translation Output", lines=10)

interface = gr.Interface(
    fn=translate_batch,
    inputs=input_textbox,
    outputs=output_textbox,
    title="🇵🇱 <-> 🇬🇧 Batch ByT5 Translator (Auto-Detect, Smart Chunking)",
    description=f"Translate multiple lines of text between Polish and English.\nModel: {MODEL_PATH}\nText is automatically split into chunks of max {MAX_WORDS_PER_CHUNK} words, prioritizing breaks at . ! ? and ,",
    article=f"Enter text (you can paste multiple paragraphs or sentences). Click Submit to translate.\n\nChunking Logic:\n1. The entire input box content is split into potential 'sentence segments' using . ! ? as delimiters.\n2. Each segment is checked for word count.\n3. If a segment is <= {MAX_WORDS_PER_CHUNK} words, it's treated as a single chunk.\n4. If a segment is > {MAX_WORDS_PER_CHUNK} words, it's further split internally using commas (,) as preferred break points.\n5. If a long segment has no commas, or comma splitting isn't sufficient, it falls back to breaking purely by word count near {MAX_WORDS_PER_CHUNK} to avoid excessively long chunks.\n6. These final chunks are batched and translated.",
    allow_flagging="never"
)


if __name__ == "__main__":
    interface.launch()
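
# Note: for local experiments, interface.launch(share=True) also creates a temporary public
# link (standard Gradio option; not needed on Hugging Face Spaces).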