|
|
|
|
|
try: |
|
import spaces |
|
print("'spaces' module imported successfully.") |
|
except ImportError: |
|
print("Warning: 'spaces' module not found. Using dummy decorator for local execution.") |
|
|
|
    class DummySpaces:
        """Stand-in decorator provider so the app also runs outside a HF Space."""
        def GPU(self, *args, **kwargs):
            # Support both bare @spaces.GPU and parameterized @spaces.GPU(...) usage.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                print(f"Note: Dummy @GPU decorator used for function '{args[0].__name__}'.")
                return args[0]
            def decorator(func):
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator

    spaces = DummySpaces()
|
|
|
import gradio as gr |
|
import re |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
import torch |
|
import os |
|
import math |
|
from huggingface_hub import hf_hub_download |
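
# --- Configuration ---
# MODEL_PATH can be a Hugging Face Hub repo id ("user/repo") or a local directory.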
|
|
|
|
|
|
|
MODEL_PATH = "Gregniuki/pl-en-pl-v2" |
|
MAX_WORDS_PER_CHUNK = 128 |
|
BATCH_SIZE = 8 |
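
# --- Device selection ---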
|
|
|
|
|
if torch.cuda.is_available(): |
|
device = torch.device("cuda") |
|
print("GPU detected. Using CUDA.") |
|
else: |
|
device = torch.device("cpu") |
|
print("No GPU detected. Using CPU.") |
|
|
|
|
|
HF_AUTH_TOKEN = os.getenv("HF_TOKEN") |
|
if MODEL_PATH and "/" in MODEL_PATH and not os.path.exists(MODEL_PATH): |
|
if HF_AUTH_TOKEN is None: |
|
print(f"Warning: HF_TOKEN secret not found. Trying to load {MODEL_PATH} without authentication.") |
|
else: |
|
print("HF_TOKEN found. Using token for model loading.") |
|
else: |
|
print(f"Loading model from local path: {MODEL_PATH}") |
|
HF_AUTH_TOKEN = None |
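
# --- Load tokenizer and model ---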
|
|
|
|
|
|
|
print(f"Loading model and tokenizer from: {MODEL_PATH}") |
|
try: |
|
tokenizer = AutoTokenizer.from_pretrained( |
|
MODEL_PATH, |
|
token=HF_AUTH_TOKEN, |
|
trust_remote_code=False |
|
) |
|
|
|
|
|
|
|
model = AutoModelForSeq2SeqLM.from_pretrained( |
|
MODEL_PATH, |
|
token=HF_AUTH_TOKEN, |
|
trust_remote_code=False |
|
) |
|
model.to(device) |
|
model.eval() |
|
print(f"Using PyTorch model on device: {device}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("Model and tokenizer loaded successfully.") |
|
|
|
except Exception as e: |
|
print(f"FATAL Error loading model/tokenizer: {e}") |
|
if "401 Client Error" in str(e): |
|
error_message = f"Authentication failed. Ensure the HF_TOKEN secret has read access to {MODEL_PATH}." |
|
else: |
|
error_message = f"Failed to load model from {MODEL_PATH}. Error: {e}" |
|
|
|
raise RuntimeError(error_message) |
|
|
|
|
|
|
|
|
|
def chunk_sentence(sentence, max_words): |
|
""" |
|
Splits a sentence (or line of text) into chunks ONLY if it exceeds max_words. |
|
If splitting is needed, it prioritizes splitting *after* sentence-ending |
|
punctuation (. ! ?) or commas (,) found within the first `max_words`. |
|
It looks for the *last* such punctuation within that limit. |
|
If no suitable punctuation is found, it splits strictly at `max_words`. |
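
    Example (max_words = 4):
        "One two three. Four five six seven"
        -> ["One two three.", "Four five six seven"]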
|
""" |
|
if not sentence or sentence.isspace(): |
|
return [] |
|
|
|
sentence = sentence.strip() |
|
words = sentence.split() |
|
word_count = len(words) |
|
|
|
|
|
if word_count <= max_words: |
|
return [sentence] |
|
|
|
|
|
chunks = [] |
|
current_word_index = 0 |
|
while current_word_index < word_count: |
|
|
|
potential_end_word_index = min(current_word_index + max_words, word_count) |
|
|
|
|
|
actual_end_word_index = potential_end_word_index |
|
|
|
|
|
|
|
|
|
if potential_end_word_index < word_count: |
|
|
|
|
|
best_punctuation_split_index = -1 |
|
for i in range(potential_end_word_index - 1, current_word_index, -1): |
|
|
|
                if words[i].endswith(('.', '!', '?', ',')):
|
best_punctuation_split_index = i + 1 |
|
break |
|
|
|
|
|
if best_punctuation_split_index > current_word_index: |
|
actual_end_word_index = best_punctuation_split_index |
|
|
|
|
|
|
|
|
|
|
|
if actual_end_word_index <= current_word_index and current_word_index < word_count: |
|
actual_end_word_index = current_word_index + 1 |
|
print(f"Warning: Split point adjustment needed. Forced split after word index {current_word_index}.") |
|
|
|
|
|
|
|
chunk_words = words[current_word_index:actual_end_word_index] |
|
if chunk_words: |
|
chunks.append(" ".join(chunk_words)) |
|
|
|
|
|
        # Advance to the next chunk, remembering where this one started so a
        # failure to make progress can be detected.
        previous_word_index = current_word_index
        current_word_index = actual_end_word_index

        if current_word_index >= word_count:
            break

        # Safety guard: if the split point did not advance, emit the rest of
        # the line as a single chunk and stop instead of looping forever.
        if current_word_index <= previous_word_index:
            print("ERROR: Chunking loop failed to advance. Aborting chunking for this sentence.")
            remaining_words = words[current_word_index:]
            if remaining_words:
                chunks.append(" ".join(remaining_words))
            break
|
|
|
return [chunk for chunk in chunks if chunk] |
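
# --- Batched translation ---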
|
|
|
|
|
|
|
@spaces.GPU |
|
def translate_batch( |
|
text_input, |
|
num_beams, |
|
do_sample, |
|
temperature, |
|
top_k, |
|
top_p |
|
): |
|
""" |
|
Translates multi-line input text using batching and sentence chunking. |
|
Assumes auto-detection of language direction (no prefixes). |
|
Uses the updated chunking logic and supports generation controls. |
|
""" |
|
if not text_input or text_input.strip() == "": |
|
return "[Error] Please enter some text to translate." |
|
|
|
print(f"Received input block for batch translation.") |
|
|
|
|
|
lines = [line.strip() for line in text_input.splitlines() if line.strip()] |
|
if not lines: |
|
return "[Info] No valid text lines found in input." |
|
|
|
|
|
all_chunks = [] |
|
for line in lines: |
|
|
|
line_chunks = chunk_sentence(line, MAX_WORDS_PER_CHUNK) |
|
all_chunks.extend(line_chunks) |
|
|
|
if not all_chunks: |
|
return "[Info] No text chunks generated after processing input." |
|
|
|
print(f"Processing {len(all_chunks)} chunks in batches with parameters:") |
|
print(f" num_beams: {num_beams}") |
|
print(f" do_sample: {do_sample}") |
|
print(f" temperature: {temperature}") |
|
print(f" top_k: {top_k}") |
|
print(f" top_p: {top_p}") |
|
|
|
|
|
all_translations = [] |
|
num_batches = math.ceil(len(all_chunks) / BATCH_SIZE) |
|
|
|
for i in range(num_batches): |
|
batch_start = i * BATCH_SIZE |
|
batch_end = batch_start + BATCH_SIZE |
|
batch_chunks = all_chunks[batch_start:batch_end] |
|
print(f" Processing batch {i+1}/{num_batches} ({len(batch_chunks)} chunks)") |
|
|
|
|
|
try: |
|
inputs = tokenizer( |
|
batch_chunks, |
|
return_tensors="pt", |
|
padding=True, |
|
truncation=True, |
|
max_length=1024 |
|
).to(device) |
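
            # Heuristic output budget: roughly 1.2x the tokenized input length
            # plus a small margin, capped at 1024 new tokens.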
|
|
|
|
|
max_input_length = inputs["input_ids"].shape[1] |
|
max_new_tokens = min(int(max_input_length * 1.2) + 20, 1024) |
|
|
|
print(f"Tokenized input (batch max length={max_input_length}), setting max_new_tokens={max_new_tokens}") |
|
|
|
        except Exception as e:
            print(f"Error during batch tokenization: {e}")
            # Keep the output aligned with the input chunks for this batch.
            all_translations.extend([f"[Error tokenizing batch {i+1}]"] * len(batch_chunks))
            continue
|
|
|
|
|
        # Generation arguments shared by both decoding modes.
        generation_kwargs = {
            **inputs,
            "max_new_tokens": max_new_tokens,
        }
|
|
|
|
|
        if do_sample:
            generation_kwargs["do_sample"] = True
            # A temperature of exactly 0 cannot be used for sampling; clamp to a small positive value.
            generation_kwargs["temperature"] = max(float(temperature), 1e-3)
            # top_k=0 disables top-k filtering (matching the UI description); cast to int
            # because sliders may return floats.
            generation_kwargs["top_k"] = int(top_k)
            if 0.0 < top_p < 1.0:
                generation_kwargs["top_p"] = float(top_p)
            # Sampling replaces beam search, so force a single beam.
            generation_kwargs["num_beams"] = 1
            print(f" Using Sampling: temp={temperature}, top_k={top_k}, top_p={top_p}")
|
        else:
            generation_kwargs["do_sample"] = False
            # Sliders may return floats; generate() expects an integer beam count.
            generation_kwargs["num_beams"] = int(num_beams)
            # early_stopping only applies to beam search.
            generation_kwargs["early_stopping"] = int(num_beams) > 1
            print(f" Using {'Greedy' if num_beams == 1 else 'Beam Search'}: num_beams={num_beams}")
|
|
|
|
|
try: |
|
with torch.no_grad(): |
|
outputs = model.generate(**generation_kwargs) |
|
|
|
print(f" Generation completed for batch {i+1}") |
|
|
|
batch_translations = tokenizer.batch_decode(outputs, skip_special_tokens=True) |
|
all_translations.extend(batch_translations) |
|
|
|
except Exception as e: |
|
print(f"Error during batch generation/decoding: {e}") |
|
error_msg = f"[Error translating batch {i+1}]" |
|
all_translations.extend([error_msg] * len(batch_chunks)) |
|
|
|
|
|
final_output = "\n".join(all_translations) |
|
print("Batch translation finished.") |
|
return final_output |
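
# --- Gradio UI components ---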
|
|
|
|
|
|
|
input_textbox = gr.Textbox( |
|
lines=10, |
|
label="Input Text (Polish or English - Enter multiple lines/sentences)", |
|
placeholder=f"Enter text here. Lines longer than {MAX_WORDS_PER_CHUNK} words will be split, prioritizing breaks after . ! ? , near the limit." |
|
) |
|
output_textbox = gr.Textbox(label="Translation Output", lines=10) |
|
|
|
|
|
num_beams_slider = gr.Slider( |
|
minimum=1, |
|
maximum=10, |
|
step=1, |
|
value=4, |
|
label="Number of Beams", |
|
info="Controls the breadth of the search during decoding (Beam Search). Set to 1 for Greedy Search. Ignored if 'Enable Sampling' is checked." |
|
) |
|
|
|
do_sample_checkbox = gr.Checkbox( |
|
value=False, |
|
label="Enable Sampling", |
|
info="If checked, enables probabilistic sampling using Temperature, Top-K, and Top-P. Overrides Beam Search (forces Number of Beams to 1)." |
|
) |
|
|
|
temperature_slider = gr.Slider( |
|
minimum=0.0, |
|
maximum=2.0, |
|
step=0.01, |
|
value=1.0, |
|
label="Temperature", |
|
info="Controls randomness. Lower=more deterministic, Higher=more random. Only active if 'Enable Sampling' is checked." |
|
) |
|
|
|
top_k_slider = gr.Slider( |
|
minimum=0, |
|
maximum=200, |
|
step=1, |
|
value=50, |
|
label="Top-K", |
|
info="Number of highest probability vocabulary tokens to consider for sampling. 0 means no Top-K. Only active if 'Enable Sampling' is checked." |
|
) |
|
|
|
top_p_slider = gr.Slider( |
|
minimum=0.0, |
|
maximum=1.0, |
|
step=0.01, |
|
value=0.9, |
|
label="Top-P (Nucleus Sampling)", |
|
info="Fraction of highest probability tokens to consider whose cumulative probability exceeds P. Only active if 'Enable Sampling' is checked." |
|
) |
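
# --- Assemble the Gradio interface ---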
|
|
|
|
|
|
|
interface = gr.Interface( |
|
fn=translate_batch, |
|
inputs=[ |
|
input_textbox, |
|
num_beams_slider, |
|
do_sample_checkbox, |
|
temperature_slider, |
|
top_k_slider, |
|
top_p_slider |
|
], |
|
outputs=output_textbox, |
|
title="π΅π± <-> π¬π§ Batch ByT5 Translator (Auto-Detect, Smart Chunking with Generation Controls)", |
|
description=f"Translate multiple lines of text between Polish and English.\nModel: {MODEL_PATH}\nText is processed line by line. Lines longer than {MAX_WORDS_PER_CHUNK} words are split into chunks. You can adjust generation parameters below.", |
|
article=f"Enter text (you can paste multiple paragraphs or sentences). Click Submit to translate.\n\nChunking Logic:\n1. Each line you enter is processed independently.\n2. If a line contains {MAX_WORDS_PER_CHUNK} words or fewer, it is translated as a single unit.\n3. If a line contains more than {MAX_WORDS_PER_CHUNK} words, it is split into smaller chunks.\n4. When splitting, the algorithm looks for the last punctuation mark (. ! ? ,) within the first {MAX_WORDS_PER_CHUNK} words to use as a natural break point.\n5. If no suitable punctuation is found in that range, the line is split exactly at the {MAX_WORDS_PER_CHUNK}-word limit.\n6. This process repeats for the remainder of the line until all parts are below the word limit.\n7. These final chunks are then translated in batches.\n\nGeneration Controls:\n- **Number of Beams**: Higher values (e.g., 4-10) can produce more fluent but less diverse output. Set to 1 for greedy decoding.\n- **Enable Sampling**: Activates probabilistic generation. When checked, the model will use Temperature, Top-K, and Top-P, and 'Number of Beams' will be forced to 1.\n- **Temperature**: Controls randomness. A value of 1.0 means no change to probabilities. Lower values (e.g., 0.7) make the output more focused and less random. Higher values (e.g., 1.5) increase randomness and creativity.\n- **Top-K**: The model considers only the 'K' most likely next tokens. A value of 0 effectively disables it. Useful for limiting very unlikely predictions.\n- **Top-P (Nucleus Sampling)**: The model considers the smallest set of tokens whose cumulative probability exceeds 'P'. This dynamically adjusts the number of tokens considered based on the probability distribution, making it more flexible than Top-K.", |
|
allow_flagging="never" |
|
) |
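
# Launch the app when this script is run directly (e.g. `python app.py`).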
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
interface.launch() |
|
|