# --- START OF FILE app.py ---

try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")
    # Define a dummy decorator that does nothing if 'spaces' isn't available
    class DummySpaces:
        def GPU(self, *args, **kwargs):
            # Support both bare usage (@spaces.GPU) and parameterized usage (@spaces.GPU(duration=...)).
            if len(args) == 1 and callable(args[0]) and not kwargs:
                print(f"Note: Dummy @GPU decorator used for function '{args[0].__name__}'.")
                return args[0]
            def decorator(func):
                # This dummy decorator just returns the original function
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator
    spaces = DummySpaces() # Create an instance of the dummy class

import gradio as gr
import re # Import the regular expression module
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # Or TFAutoModelForSeq2SeqLM
import torch # Or import tensorflow as tf
import os
import math
from huggingface_hub import hf_hub_download

# --- Configuration ---
# IMPORTANT: REPLACE THIS with your model's Hugging Face Hub ID or local path
MODEL_PATH = "Gregniuki/pl-en-pl-v2" # Use your actual model path
MAX_WORDS_PER_CHUNK = 128 # Define the maximum words per chunk
BATCH_SIZE = 8 # Adjust based on GPU memory / desired throughput

# --- Device Setup (Zero GPU Support) ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU detected. Using CUDA.")
else:
    device = torch.device("cpu")
    print("No GPU detected. Using CPU.")

# --- Get Hugging Face Token from Secrets for Private Models ---
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
if MODEL_PATH and "/" in MODEL_PATH and not os.path.exists(MODEL_PATH): # Rough check if it's likely a Hub ID
    if HF_AUTH_TOKEN is None:
        print(f"Warning: HF_TOKEN secret not found. Trying to load {MODEL_PATH} without authentication.")
    else:
        print("HF_TOKEN found. Using token for model loading.")
else:
    print(f"Loading model from local path: {MODEL_PATH}")
    HF_AUTH_TOKEN = None # Don't use token for local paths


# --- Load Model and Tokenizer (once on startup) ---
print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False
    )

    # --- Choose the correct model class ---
    # PyTorch (most common)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False
    )
    model.to(device) # Move model to the determined device
    model.eval() # Set model to evaluation mode
    print(f"Using PyTorch model on device: {device}")

    # # TensorFlow (uncomment if your model is TF)
    # from transformers import TFAutoModelForSeq2SeqLM
    # import tensorflow as tf
    # model = TFAutoModelForSeq2SeqLM.from_pretrained(
    #     MODEL_PATH,
    #     token=HF_AUTH_TOKEN,
    #     trust_remote_code=False
    # )
    # # TF device placement is often automatic or managed via strategies
    # print("Using TensorFlow model.")

    print("Model and tokenizer loaded successfully.")

except Exception as e:
    print(f"FATAL Error loading model/tokenizer: {e}")
    if "401 Client Error" in str(e):
         error_message = f"Authentication failed. Ensure the HF_TOKEN secret has read access to {MODEL_PATH}."
    else:
         error_message = f"Failed to load model from {MODEL_PATH}. Error: {e}"
    # Raise error to prevent app launch if model loading fails
    raise RuntimeError(error_message)


# --- Helper Functions for Chunking ---

def chunk_sentence(sentence, max_words):
    """
    Splits a sentence (or line of text) into chunks ONLY if it exceeds max_words.
    If splitting is needed, it prioritizes splitting *after* sentence-ending
    punctuation (. ! ?) or commas (,) found within the first `max_words`.
    It looks for the *last* such punctuation within that limit.
    If no suitable punctuation is found, it splits strictly at `max_words`.
    """
    if not sentence or sentence.isspace():
        return []

    sentence = sentence.strip() # Ensure no leading/trailing whitespace
    words = sentence.split()
    word_count = len(words)

    # If the sentence is short enough, return it as a single chunk
    if word_count <= max_words:
        return [sentence]

    # If the sentence is too long, proceed with chunking
    chunks = []
    current_word_index = 0
    while current_word_index < word_count:
        # Determine the end index for the current potential chunk (non-inclusive)
        potential_end_word_index = min(current_word_index + max_words, word_count)

        # Assume we split at the max_words limit initially
        actual_end_word_index = potential_end_word_index

        # Check if we need to look for punctuation (i.e., if this chunk would be exactly max_words
        # and there's more text remaining, or if the remaining text itself is longer than max_words)
        # This check ensures we don't unnecessarily truncate if the remaining part is short.
        if potential_end_word_index < word_count:
             # Search backwards from the word *before* the potential end index
             # down to the start of the current segment for punctuation.
             best_punctuation_split_index = -1
             for i in range(potential_end_word_index - 1, current_word_index, -1):
                 # Check if the word at index 'i' ends with the desired punctuation
                 if words[i].endswith(('.', '!', '?', ',')):
                     best_punctuation_split_index = i + 1 # Split *after* this word
                     break # Found the last suitable punctuation in the range

             # If we found a punctuation split point, use it
             if best_punctuation_split_index > current_word_index: # Ensure it's a valid index within the current segment
                 actual_end_word_index = best_punctuation_split_index
             # Else: No suitable punctuation found, stick with potential_end_word_index (split at max_words limit)

        # Safety check: Prevent creating an empty chunk if the split point is the same as the start
        # This can happen if the first word itself is very long or under unusual circumstances.
        # Force consuming at least one word if we are not at the end.
        if actual_end_word_index <= current_word_index and current_word_index < word_count:
             actual_end_word_index = current_word_index + 1
             print(f"Warning: Split point adjustment needed. Forced split after word index {current_word_index}.")


        # Extract the chunk words and join them
        chunk_words = words[current_word_index:actual_end_word_index]
        if chunk_words: # Ensure we don't add empty strings
            chunks.append(" ".join(chunk_words))

        # Update the starting index for the next chunk
        current_word_index = actual_end_word_index

        # Basic infinite loop prevention (should not be necessary with correct logic but safe)
        if current_word_index == word_count and len(chunks) > 0: # Normal exit condition
             break
        if current_word_index < word_count and actual_end_word_index <= current_word_index :
            print(f"ERROR: Chunking loop failed to advance. Aborting chunking for this sentence.")
            # Return partially chunked sentence or handle error appropriately
            # For simplicity, we might return the chunks found so far plus the rest unsplit
            remaining_words = words[current_word_index:]
            if remaining_words:
                chunks.append(" ".join(remaining_words))
            break # Exit loop

    return [chunk for chunk in chunks if chunk] # Final filter for empty strings
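
# Illustrative sketch of chunk_sentence's behaviour (comments only, nothing runs on import;
# the inputs below are made-up examples, not data used by the app):
#
#   chunk_sentence("Ala ma kota.", 128)
#       -> ["Ala ma kota."]                        # short line: returned as a single chunk
#   chunk_sentence(<200-word line with a '.' after word 90>, 128)
#       -> [words 1-90, words 91-200]              # split after the last . ! ? , inside the 128-word window
#   chunk_sentence(<150-word line with no punctuation>, 128)
#       -> [first 128 words, remaining 22 words]   # hard split exactly at the word limit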

# --- Define the BATCH translation function ---
# Add GPU decorator for Spaces (adjust duration if needed)
@spaces.GPU
def translate_batch(
    text_input,
    num_beams,
    do_sample,
    temperature,
    top_k,
    top_p
):
    """
    Translates multi-line input text using batching and sentence chunking.
    Assumes auto-detection of language direction (no prefixes).
    Uses the updated chunking logic and supports generation controls.
    """
    if not text_input or text_input.strip() == "":
        return "[Error] Please enter some text to translate."

    print(f"Received input block for batch translation.")

    # 1. Split input into lines and clean
    lines = [line.strip() for line in text_input.splitlines() if line.strip()]
    if not lines:
        return "[Info] No valid text lines found in input."

    # 2. Chunk each line individually using the new logic
    all_chunks = []
    for line in lines:
        # Apply the new chunking logic to each line
        line_chunks = chunk_sentence(line, MAX_WORDS_PER_CHUNK)
        all_chunks.extend(line_chunks)

    if not all_chunks:
        return "[Info] No text chunks generated after processing input."

    print(f"Processing {len(all_chunks)} chunks in batches with parameters:")
    print(f"  num_beams: {num_beams}")
    print(f"  do_sample: {do_sample}")
    print(f"  temperature: {temperature}")
    print(f"  top_k: {top_k}")
    print(f"  top_p: {top_p}")

    # 3. Process chunks in batches
    all_translations = []
    num_batches = math.ceil(len(all_chunks) / BATCH_SIZE)

    for i in range(num_batches):
        batch_start = i * BATCH_SIZE
        batch_end = batch_start + BATCH_SIZE
        batch_chunks = all_chunks[batch_start:batch_end]
        print(f"  Processing batch {i+1}/{num_batches} ({len(batch_chunks)} chunks)")

        # Tokenize the batch
        try:
            inputs = tokenizer(
                batch_chunks,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=1024 # Model's max input length
            ).to(device)

            # Estimate max_new_tokens from the input length: roughly 1.2x the longest
            # tokenized input plus a 20-token buffer, capped at the model's 1024-token limit.
            max_input_length = inputs["input_ids"].shape[1]
            max_new_tokens = min(int(max_input_length * 1.2) + 20, 1024)
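            # Worked example (illustrative): if the longest tokenized input in the batch is
            # 100 tokens, max_new_tokens = min(int(100 * 1.2) + 20, 1024) = 140.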

            print(f"Tokenized input (batch max length={max_input_length}), setting max_new_tokens={max_new_tokens}")

        except Exception as e:
            print(f"Error during batch tokenization: {e}")
            all_translations.append(f"[Error tokenizing batch {i+1}]")
            continue # Skip to next batch or break

        # Prepare generation arguments
        generation_kwargs = {
            **inputs,
            "max_new_tokens": max_new_tokens,
            "early_stopping": True, # Usually good for translation
        }

        # Apply decoding strategy parameters
        if do_sample:
            generation_kwargs["do_sample"] = True
            generation_kwargs["temperature"] = temperature
            # Cast to int: Gradio sliders can return floats, and generate() expects an integer top_k.
            generation_kwargs["top_k"] = int(top_k) if top_k > 0 else None # 0 means disable top_k
            generation_kwargs["top_p"] = top_p if 0 < top_p < 1.0 else None # 0 or 1 means disable top_p
            # When sampling, num_beams is typically 1. We force it here to avoid conflicts.
            generation_kwargs["num_beams"] = 1
            print(f"    Using Sampling: temp={temperature}, top_k={top_k}, top_p={top_p}")
        else:
            generation_kwargs["do_sample"] = False # Ensure greedy/beam search
            generation_kwargs["num_beams"] = int(num_beams) # Cast in case the slider returns a float
            print(f"    Using {'Greedy' if num_beams == 1 else 'Beam Search'}: num_beams={int(num_beams)}")

        # Generate translations for the batch
        try:
            with torch.no_grad():
                outputs = model.generate(**generation_kwargs)

            print(f"    Generation completed for batch {i+1}")

            batch_translations = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            all_translations.extend(batch_translations)

        except Exception as e:
            print(f"Error during batch generation/decoding: {e}")
            error_msg = f"[Error translating batch {i+1}]"
            all_translations.extend([error_msg] * len(batch_chunks))

    # 4. Join translated chunks back together
    final_output = "\n".join(all_translations)
    print("Batch translation finished.")
    return final_output
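
# Minimal local smoke test (illustrative sketch, kept commented out; the sample text,
# the RUN_SMOKE_TEST guard, and the parameter values are arbitrary assumptions):
#
# if __name__ == "__main__" and os.getenv("RUN_SMOKE_TEST"):
#     sample = "To jest przykładowe zdanie do przetłumaczenia.\nAnd a second line in English."
#     print(translate_batch(sample, num_beams=4, do_sample=False,
#                           temperature=1.0, top_k=50, top_p=0.9))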


# --- Create Gradio Interface for Batch Translation ---
input_textbox = gr.Textbox(
    lines=10, # Allow more lines for batch input
    label="Input Text (Polish or English - Enter multiple lines/sentences)",
    placeholder=f"Enter text here. Lines longer than {MAX_WORDS_PER_CHUNK} words will be split, prioritizing breaks after . ! ? , near the limit."
)
output_textbox = gr.Textbox(label="Translation Output", lines=10)

# New generation control components
num_beams_slider = gr.Slider(
    minimum=1,
    maximum=10,
    step=1,
    value=4, # Default value for beam search
    label="Number of Beams",
    info="Controls the breadth of the search during decoding (Beam Search). Set to 1 for Greedy Search. Ignored if 'Enable Sampling' is checked."
)

do_sample_checkbox = gr.Checkbox(
    value=False, # Default to beam search
    label="Enable Sampling",
    info="If checked, enables probabilistic sampling using Temperature, Top-K, and Top-P. Overrides Beam Search (forces Number of Beams to 1)."
)

temperature_slider = gr.Slider(
    minimum=0.1, # generate() requires a strictly positive temperature when sampling
    maximum=2.0,
    step=0.01, # Finer control for temperature
    value=1.0, # Default for sampling
    label="Temperature",
    info="Controls randomness. Lower=more deterministic, Higher=more random. Only active if 'Enable Sampling' is checked."
)

top_k_slider = gr.Slider(
    minimum=0, # 0 usually means disable Top-K
    maximum=200,
    step=1,
    value=50, # Common value
    label="Top-K",
    info="Number of highest probability vocabulary tokens to consider for sampling. 0 means no Top-K. Only active if 'Enable Sampling' is checked."
)

top_p_slider = gr.Slider(
    minimum=0.0,
    maximum=1.0,
    step=0.01, # Finer control for Top-P
    value=0.9, # Common value for nucleus sampling
    label="Top-P (Nucleus Sampling)",
    info="Fraction of highest probability tokens to consider whose cumulative probability exceeds P. Only active if 'Enable Sampling' is checked."
)


# Interface definition
interface = gr.Interface(
    fn=translate_batch,          # Use the batch function
    inputs=[
        input_textbox,
        num_beams_slider,
        do_sample_checkbox,
        temperature_slider,
        top_k_slider,
        top_p_slider
    ],
    outputs=output_textbox,
    title="πŸ‡΅πŸ‡± <-> πŸ‡¬πŸ‡§ Batch ByT5 Translator (Auto-Detect, Smart Chunking with Generation Controls)",
    description=f"Translate multiple lines of text between Polish and English.\nModel: {MODEL_PATH}\nText is processed line by line. Lines longer than {MAX_WORDS_PER_CHUNK} words are split into chunks. You can adjust generation parameters below.",
    article=f"Enter text (you can paste multiple paragraphs or sentences). Click Submit to translate.\n\nChunking Logic:\n1. Each line you enter is processed independently.\n2. If a line contains {MAX_WORDS_PER_CHUNK} words or fewer, it is translated as a single unit.\n3. If a line contains more than {MAX_WORDS_PER_CHUNK} words, it is split into smaller chunks.\n4. When splitting, the algorithm looks for the last punctuation mark (. ! ? ,) within the first {MAX_WORDS_PER_CHUNK} words to use as a natural break point.\n5. If no suitable punctuation is found in that range, the line is split exactly at the {MAX_WORDS_PER_CHUNK}-word limit.\n6. This process repeats for the remainder of the line until all parts are below the word limit.\n7. These final chunks are then translated in batches.\n\nGeneration Controls:\n- **Number of Beams**: Higher values (e.g., 4-10) can produce more fluent but less diverse output. Set to 1 for greedy decoding.\n- **Enable Sampling**: Activates probabilistic generation. When checked, the model will use Temperature, Top-K, and Top-P, and 'Number of Beams' will be forced to 1.\n- **Temperature**: Controls randomness. A value of 1.0 means no change to probabilities. Lower values (e.g., 0.7) make the output more focused and less random. Higher values (e.g., 1.5) increase randomness and creativity.\n- **Top-K**: The model considers only the 'K' most likely next tokens. A value of 0 effectively disables it. Useful for limiting very unlikely predictions.\n- **Top-P (Nucleus Sampling)**: The model considers the smallest set of tokens whose cumulative probability exceeds 'P'. This dynamically adjusts the number of tokens considered based on the probability distribution, making it more flexible than Top-K.",
    allow_flagging="never"
)

# --- Launch the App ---
if __name__ == "__main__":
    # Set share=True for a public link if running locally, not needed on Spaces
    interface.launch()
# --- END OF FILE app.py ---