# --- START OF FILE app (5).py ---

try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")
    # Define a dummy decorator that does nothing if 'spaces' isn't available
    class DummySpaces:
        def GPU(self, *args, **kwargs):
            # Support both bare usage (@spaces.GPU) and parameterized usage (@spaces.GPU(...)).
            if len(args) == 1 and callable(args[0]) and not kwargs:
                # Bare form: the decorated function is passed in directly; return it unchanged.
                print(f"Note: Dummy @GPU decorator used for function '{args[0].__name__}'.")
                return args[0]
            def decorator(func):
                # Parameterized form: return the original function unchanged.
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator
    spaces = DummySpaces() # Create an instance of the dummy class

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # Or TFAutoModelForSeq2SeqLM
import torch # Or import tensorflow as tf
import os
import math
# Requires a Gradio/Spaces version that supports the spaces.GPU decorator when running on Spaces.

# --- Configuration ---
# IMPORTANT: REPLACE THIS with your model's Hugging Face Hub ID or local path
MODEL_PATH = "Gregniuki/pl-en-pl" # Use your actual model path
MAX_WORDS_PER_CHUNK = 55 # Define the maximum words per chunk
BATCH_SIZE = 8 # Adjust based on GPU memory / desired throughput

# --- Device Setup (Zero GPU Support) ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU detected. Using CUDA.")
else:
    device = torch.device("cpu")
    print("No GPU detected. Using CPU.")

# --- Get Hugging Face Token from Secrets for Private Models ---
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
if MODEL_PATH and "/" in MODEL_PATH and not os.path.exists(MODEL_PATH): # Rough check if it's likely a Hub ID
    if HF_AUTH_TOKEN is None:
        print(f"Warning: HF_TOKEN secret not found. Trying to load {MODEL_PATH} without authentication.")
    else:
        print("HF_TOKEN found. Using token for model loading.")
else:
    print(f"Loading model from local path: {MODEL_PATH}")
    HF_AUTH_TOKEN = None # Don't use token for local paths


# --- Load Model and Tokenizer (once on startup) ---
print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False
    )

    # --- Choose the correct model class ---
    # PyTorch (most common)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False
    )
    model.to(device) # Move model to the determined device
    model.eval() # Set model to evaluation mode
    print(f"Using PyTorch model on device: {device}")

    # # TensorFlow (uncomment if your model is TF)
    # from transformers import TFAutoModelForSeq2SeqLM
    # import tensorflow as tf
    # model = TFAutoModelForSeq2SeqLM.from_pretrained(
    #     MODEL_PATH,
    #     token=HF_AUTH_TOKEN,
    #     trust_remote_code=False
    # )
    # # TF device placement is often automatic or managed via strategies
    # print("Using TensorFlow model.")

    print("Model and tokenizer loaded successfully.")

except Exception as e:
    print(f"FATAL Error loading model/tokenizer: {e}")
    if "401 Client Error" in str(e):
         error_message = f"Authentication failed. Ensure the HF_TOKEN secret has read access to {MODEL_PATH}."
    else:
         error_message = f"Failed to load model from {MODEL_PATH}. Error: {e}"
    # Raise error to prevent app launch if model loading fails
    raise RuntimeError(error_message)


# --- Helper Functions for Chunking ---

def chunk_sentence(sentence, max_words):
    """
    Splits a sentence (or line of text) into chunks ONLY if it exceeds max_words.
    If splitting is needed, it prioritizes splitting *after* sentence-ending
    punctuation (. ! ?) or commas (,) found within the first `max_words`.
    It looks for the *last* such punctuation within that limit.
    If no suitable punctuation is found, it splits strictly at `max_words`.
    """
    if not sentence or sentence.isspace():
        return []

    sentence = sentence.strip() # Ensure no leading/trailing whitespace
    words = sentence.split()
    word_count = len(words)

    # If the sentence is short enough, return it as a single chunk
    if word_count <= max_words:
        return [sentence]

    # If the sentence is too long, proceed with chunking
    chunks = []
    current_word_index = 0
    while current_word_index < word_count:
        # Determine the end index for the current potential chunk (non-inclusive)
        potential_end_word_index = min(current_word_index + max_words, word_count)

        # Assume we split at the max_words limit initially
        actual_end_word_index = potential_end_word_index

        # Only look for a punctuation split if there is more text after this potential chunk.
        # If the remaining words already fit within max_words, take them all as-is and avoid
        # truncating the tail unnecessarily.
        if potential_end_word_index < word_count:
             # Search backwards from the word *before* the potential end index
             # down to the start of the current segment for punctuation.
             best_punctuation_split_index = -1
             for i in range(potential_end_word_index - 1, current_word_index, -1):
                 # Check if the word at index 'i' ends with the desired punctuation
                 if words[i].endswith(('.', '!', '?', ',')):
                     best_punctuation_split_index = i + 1 # Split *after* this word
                     break # Found the last suitable punctuation in the range

             # If we found a punctuation split point, use it
             if best_punctuation_split_index > current_word_index: # Ensure it's a valid index within the current segment
                 actual_end_word_index = best_punctuation_split_index
             # Else: No suitable punctuation found, stick with potential_end_word_index (split at max_words limit)

        # Safety check: Prevent creating an empty chunk if the split point is the same as the start
        # This can happen if the first word itself is very long or under unusual circumstances.
        # Force consuming at least one word if we are not at the end.
        if actual_end_word_index <= current_word_index and current_word_index < word_count:
             actual_end_word_index = current_word_index + 1
             print(f"Warning: Split point adjustment needed. Forced split after word index {current_word_index}.")


        # Extract the chunk words and join them
        chunk_words = words[current_word_index:actual_end_word_index]
        if chunk_words: # Ensure we don't add empty strings
            chunks.append(" ".join(chunk_words))

        # Update the starting index for the next chunk
        current_word_index = actual_end_word_index

        # Basic infinite loop prevention (should not be necessary with correct logic but safe)
        if current_word_index == word_count and len(chunks) > 0: # Normal exit condition
             break
        if current_word_index < word_count and actual_end_word_index <= current_word_index :
            print(f"ERROR: Chunking loop failed to advance. Aborting chunking for this sentence.")
            # Return partially chunked sentence or handle error appropriately
            # For simplicity, we might return the chunks found so far plus the rest unsplit
            remaining_words = words[current_word_index:]
            if remaining_words:
                chunks.append(" ".join(remaining_words))
            break # Exit loop

    return [chunk for chunk in chunks if chunk] # Final filter for empty strings
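
# Illustrative example of chunk_sentence (hypothetical input, not executed at import time):
#   chunk_sentence("One two three, four five six seven.", max_words=5)
# The line has 7 words, so the first 5-word window is scanned backwards for punctuation;
# the last match is the word "three,", so the split happens right after it, yielding:
#   ["One two three,", "four five six seven."]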

# --- Define the BATCH translation function ---
# Add GPU decorator for Spaces (adjust duration if needed)
@spaces.GPU
def translate_batch(text_input):
    """
    Translates multi-line input text using batching and sentence chunking.
    Assumes auto-detection of language direction (no prefixes).
    Uses the updated chunking logic.
    """
    if not text_input or text_input.strip() == "":
        return "[Error] Please enter some text to translate."

    print(f"Received input block for batch translation.")

    # 1. Split input into lines and clean
    lines = [line.strip() for line in text_input.splitlines() if line.strip()]
    if not lines:
        return "[Info] No valid text lines found in input."

    # 2. Chunk each line individually using the new logic
    all_chunks = []
    for line in lines:
        # Apply the new chunking logic to each line
        line_chunks = chunk_sentence(line, MAX_WORDS_PER_CHUNK)
        all_chunks.extend(line_chunks)

    if not all_chunks:
        return "[Info] No text chunks generated after processing input."

    print(f"Processing {len(all_chunks)} chunks in batches...")

    # 3. Process chunks in batches
    all_translations = []
    num_batches = math.ceil(len(all_chunks) / BATCH_SIZE)
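    # e.g. 13 chunks with BATCH_SIZE = 8 -> math.ceil(13 / 8) = 2 batches (8 chunks, then 5).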

    for i in range(num_batches):
        batch_start = i * BATCH_SIZE
        batch_end = batch_start + BATCH_SIZE
        batch_chunks = all_chunks[batch_start:batch_end]
        print(f"  Processing batch {i+1}/{num_batches} ({len(batch_chunks)} chunks)")

        # Tokenize the batch
        try:
            inputs = tokenizer(
                batch_chunks,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=1024 # Model's max input length
            ).to(device)

            # Estimate an appropriate max_new_tokens from the tokenized input length.
            # Simple heuristic for seq2seq: allow roughly 20% expansion plus a small constant,
            # capped at 1024 generated tokens (the cap applies to the output only, not input + output).
            max_input_length = inputs["input_ids"].shape[1]
            max_new_tokens = min(int(max_input_length * 1.2) + 10, 1024)
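            # For example, if the longest tokenized input in the batch is 200 tokens,
            # max_new_tokens = min(int(200 * 1.2) + 10, 1024) = 250.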

            print(f"Tokenized input (batch max length={max_input_length}), setting max_new_tokens={max_new_tokens}")
            # Optional: print token counts per input for debugging
            # for idx, ids in enumerate(inputs["input_ids"]):
            #     print(f"    Input {idx+1}: {len(ids)} tokens for chunk: '{batch_chunks[idx][:50]}...'")

        except Exception as e:
            print(f"Error during batch tokenization: {e}")
            # Append one error marker per chunk so the output stays aligned with the input chunks.
            all_translations.extend([f"[Error tokenizing batch {i+1}]"] * len(batch_chunks))
            continue # Skip to the next batch

        # Generate translations for the batch
        try:
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    num_beams=4,
                    # no_repeat_ngram_size=3, # Consider if needed for the model
                    early_stopping=True, # Usually good for translation
                    # Remove output_scores unless needed for specific analysis
                    # return_dict_in_generate=True, # Keep if you use outputs.sequences
                    # output_scores=True
                )

            print(f"    Generation completed for batch {i+1}")

            # Use default output which is usually the sequences tensor
            batch_translations = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            all_translations.extend(batch_translations)

        except Exception as e:
            print(f"Error during batch generation/decoding: {e}")
            # Append error messages for the failed chunks in this batch
            error_msg = f"[Error translating batch {i+1}]"
            all_translations.extend([error_msg] * len(batch_chunks))
            # Consider if you want to stop processing or continue with next batches

    # 4. Join translated chunks back together
    # Simple join with newline. This respects that each chunk was processed independently.
    final_output = "\n".join(all_translations)
    print("Batch translation finished.")
    return final_output
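
# Hypothetical programmatic use (bypassing the Gradio UI), assuming the model loaded successfully:
#   print(translate_batch("Dzień dobry, jak się masz?\nThe weather is nice today."))
# Each input line is chunked by chunk_sentence, translated in batches, and joined with newlines.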


# --- Create Gradio Interface for Batch Translation ---
input_textbox = gr.Textbox(
    lines=10, # Allow more lines for batch input
    label="Input Text (Polish or English - Enter multiple lines/sentences)",
    placeholder=f"Enter text here. Lines longer than {MAX_WORDS_PER_CHUNK} words will be split, prioritizing breaks after . ! ? , near the limit."
)
output_textbox = gr.Textbox(label="Translation Output", lines=10)

# Interface definition
interface = gr.Interface(
    fn=translate_batch,          # Use the batch function
    inputs=input_textbox,
    outputs=output_textbox,
    title="πŸ‡΅πŸ‡± <-> πŸ‡¬πŸ‡§ Batch ByT5 Translator (Auto-Detect, Smart Chunking)",
    description=f"Translate multiple lines of text between Polish and English.\nModel: {MODEL_PATH}\nText is processed line by line. Lines longer than {MAX_WORDS_PER_CHUNK} words are split into chunks.",
    # Updated Article explaining the new logic
    article=f"Enter text (you can paste multiple paragraphs or sentences). Click Submit to translate.\n\nChunking Logic:\n1. Each line you enter is processed independently.\n2. If a line contains {MAX_WORDS_PER_CHUNK} words or fewer, it is translated as a single unit.\n3. If a line contains more than {MAX_WORDS_PER_CHUNK} words, it is split into smaller chunks.\n4. When splitting, the algorithm looks for the last punctuation mark (. ! ? ,) within the first {MAX_WORDS_PER_CHUNK} words to use as a natural break point.\n5. If no suitable punctuation is found in that range, the line is split exactly at the {MAX_WORDS_PER_CHUNK}-word limit.\n6. This process repeats for the remainder of the line until all parts are below the word limit.\n7. These final chunks are then translated in batches.",
    allow_flagging="never"
)

# --- Launch the App ---
if __name__ == "__main__":
    # Set share=True for a public link if running locally, not needed on Spaces
    interface.launch()
# --- END OF FILE app (5).py ---