try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")
    # Define a dummy decorator that does nothing if 'spaces' isn't available
    class DummySpaces:
        def GPU(self, *args, **kwargs):
            # Support both bare @GPU and parameterized @GPU(...) usage,
            # mirroring the real spaces.GPU decorator (the bare form is
            # what this script uses below).
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]
            def decorator(func):
                # This dummy decorator just returns the original function
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator
    spaces = DummySpaces() # Create an instance of the dummy class

import gradio as gr
import re  # Regular expressions for sentence/comma splitting
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  # Or TFAutoModelForSeq2SeqLM
import torch  # Or: import tensorflow as tf
import os
import math
# Requires a Gradio version that supports the spaces.GPU decorator when running on Spaces

# --- Configuration ---
# IMPORTANT: REPLACE THIS with your model's Hugging Face Hub ID or local path
MODEL_PATH = "Gregniuki/pl-en-pl" # Use your actual model path
MAX_WORDS_PER_CHUNK = 44 # Define the maximum words per chunk
BATCH_SIZE = 8 # Adjust based on GPU memory / desired throughput

# --- Device Setup (Zero GPU Support) ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU detected. Using CUDA.")
else:
    device = torch.device("cpu")
    print("No GPU detected. Using CPU.")

# --- Get Hugging Face Token from Secrets for Private Models ---
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
if MODEL_PATH and "/" in MODEL_PATH and not os.path.exists(MODEL_PATH): # Rough check if it's likely a Hub ID
    if HF_AUTH_TOKEN is None:
        print(f"Warning: HF_TOKEN secret not found. Trying to load {MODEL_PATH} without authentication.")
    else:
        print("HF_TOKEN found. Using token for model loading.")
else:
    print(f"Loading model from local path: {MODEL_PATH}")
    HF_AUTH_TOKEN = None # Don't use token for local paths
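
# Hedged local-run note: for a private Hub model, export the token before
# launching so the os.getenv("HF_TOKEN") lookup above finds it, e.g.:
#
#     HF_TOKEN=hf_xxxxxxxx python app.py   # script filename assumed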


# --- Load Model and Tokenizer (once on startup) ---
print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False
    )

    # --- Choose the correct model class ---
    # PyTorch (most common)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False
    )
    model.to(device) # Move model to the determined device
    model.eval() # Set model to evaluation mode
    print(f"Using PyTorch model on device: {device}")

    # # TensorFlow (uncomment if your model is TF)
    # from transformers import TFAutoModelForSeq2SeqLM
    # import tensorflow as tf
    # model = TFAutoModelForSeq2SeqLM.from_pretrained(
    #     MODEL_PATH,
    #     token=HF_AUTH_TOKEN,
    #     trust_remote_code=False
    # )
    # # TF device placement is often automatic or managed via strategies
    # print("Using TensorFlow model.")

    print("Model and tokenizer loaded successfully.")

except Exception as e:
    print(f"FATAL Error loading model/tokenizer: {e}")
    if "401 Client Error" in str(e):
         error_message = f"Authentication failed. Ensure the HF_TOKEN secret has read access to {MODEL_PATH}."
    else:
         error_message = f"Failed to load model from {MODEL_PATH}. Error: {e}"
    # Raise error to prevent app launch if model loading fails
    raise RuntimeError(error_message)


# --- Helper Functions for Chunking ---

def split_long_segment_by_comma_or_fallback(segment, max_words):
    """
    Splits a long segment (already known > max_words) primarily by commas,
    falling back to simple word splitting if needed.
    """
    if not segment or segment.isspace():
        return []

    # 1. Attempt to split by commas, keeping the comma and trailing whitespace
    # re.split splits *after* the pattern. (?<=,) looks behind for a comma. \s* matches trailing whitespace.
    comma_parts = re.split(r'(?<=,)\s*', segment)
    comma_parts = [p.strip() for p in comma_parts if p.strip()] # Trim and filter empty parts

    # If no commas were found (or the single "part" is still longer than
    # max_words), fall back to simple word-based chunking.
    if not comma_parts or (len(comma_parts) == 1 and len(comma_parts[0].split()) > max_words):
        words = segment.split()
        segment_chunks = []
        current_chunk_words = []
        for word in words:
            current_chunk_words.append(word)
            # Finalize the chunk as soon as it reaches max_words. The count
            # grows one word at a time, so it can never exceed max_words.
            if len(current_chunk_words) == max_words:
                segment_chunks.append(" ".join(current_chunk_words))
                current_chunk_words = []

        # Add any remaining words as the final (shorter) chunk
        if current_chunk_words:
            segment_chunks.append(" ".join(current_chunk_words))

        return segment_chunks

    # 2. Recombine comma-separated parts, respecting max_words
    segment_chunks = []
    current_chunk_parts = [] # List to hold comma-separated strings for the current chunk
    current_chunk_word_count = 0

    for i, part in enumerate(comma_parts):
        part_word_count = len(part.split())

        # Check if adding this part makes the current chunk exceed max_words.
        # Condition `current_chunk_word_count > 0` ensures we don't break before adding the first part.
        # If the first part itself is > max_words, the fallback above handles it.
        if current_chunk_word_count > 0 and (current_chunk_word_count + part_word_count > max_words):
            # Finalize the current chunk (join the collected parts)
            segment_chunks.append(" ".join(current_chunk_parts).strip()) # Join with space, trim result
            # Start a new chunk with the current part
            current_chunk_parts = [part]
            current_chunk_word_count = part_word_count
        else:
            # Add the part to the current chunk
            current_chunk_parts.append(part)
            current_chunk_word_count += part_word_count

    # Add any remaining parts as the last chunk for this segment
    if current_chunk_parts:
        segment_chunks.append(" ".join(current_chunk_parts).strip())

    return segment_chunks
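

# A minimal, hand-traced usage sketch for the helper above. The CHUNK_DEMO
# env flag is an assumption added here so the demo never runs on a deployed
# Space; expected outputs were traced by hand from the logic above.
if os.getenv("CHUNK_DEMO"):
    print(split_long_segment_by_comma_or_fallback("a b c, d e f g h", 5))
    # -> ['a b c,', 'd e f g h']   (comma split: a 3-word and a 5-word part)
    print(split_long_segment_by_comma_or_fallback("one two three four five six seven", 3))
    # -> ['one two three', 'four five six', 'seven']   (no commas: word fallback)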


def chunk_sentence(sentence, max_words):
    """
    Splits text into chunks based on max words, prioritizing sentence-ending punctuation (. ! ?),
    then commas (,) if the chunk is already >= max_words, falling back to word split.
    Processes the input line as potentially containing multiple sentences.
    """
    if not sentence or sentence.isspace():
        return []

    all_final_chunks = []

    # 1. Split the input line into potential "sentence segments" at . ! ?
    # Use regex split with lookbehind to split *after* the punctuation and space.
    # This yields segments that end in . ! ? (except possibly the very last segment).
    # Example: "Hello world. How are you? And you?" -> ["Hello world.", "How are you?", "And you?"]
    # Example: "Part one, part two. Part three." -> ["Part one, part two.", "Part three."]
    # Example: "No punctuation here" -> ["No punctuation here"]
    sentence_segments = re.split(r'(?<=[.!?])\s*', sentence)

    # Filter out empty strings that might result from splitting
    sentence_segments = [s.strip() for s in sentence_segments if s.strip()]

    # 2. Process each sentence segment
    for segment in sentence_segments:
        segment_word_count = len(segment.split())

        if segment_word_count <= max_words:
            # Segment is short enough, add directly
            all_final_chunks.append(segment)
        else:
            # Segment is too long, apply comma splitting or fallback word splitting
            comma_based_chunks = split_long_segment_by_comma_or_fallback(segment, max_words)
            all_final_chunks.extend(comma_based_chunks)

    # Ensure no empty strings sneak through at the end
    return [chunk for chunk in all_final_chunks if chunk.strip()]
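

# Hand-traced sketch of chunk_sentence, using the same assumed CHUNK_DEMO flag:
if os.getenv("CHUNK_DEMO"):
    print(chunk_sentence("One two three. Four five six seven, eight nine ten.", 6))
    # -> ['One two three.', 'Four five six seven,', 'eight nine ten.']
    # The first segment fits within 6 words; the second (7 words) splits at its comma.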


# --- Define the BATCH translation function ---
# Add GPU decorator for Spaces (adjust duration if needed)
@spaces.GPU
def translate_batch(text_input):
    """
    Translates multi-line input text using batching and sentence chunking.
    Assumes auto-detection of language direction (no prefixes).
    """
    if not text_input or text_input.strip() == "":
        return "[Error] Please enter some text to translate."

    print(f"Received input block for batch translation.")

    # 1. Split input into potential sentences (lines) and clean
    # Then chunk each line using the sophisticated chunk_sentence function
    lines = [line.strip() for line in text_input.splitlines() if line.strip()]
    if not lines:
        return "[Info] No valid text lines found in input."

    # 2. Chunk lines using the new logic
    all_chunks = []
    for line in lines:
        # Process each line as a potential multi-sentence block for chunking
        line_chunks = chunk_sentence(line, MAX_WORDS_PER_CHUNK)
        all_chunks.extend(line_chunks)

    if not all_chunks:
        return "[Info] No text chunks generated after processing input."

    print(f"Processing {len(all_chunks)} chunks in batches...")

    # 3. Process chunks in batches
    all_translations = []
    num_batches = math.ceil(len(all_chunks) / BATCH_SIZE)

    for i in range(num_batches):
        batch_start = i * BATCH_SIZE
        batch_end = batch_start + BATCH_SIZE
        batch_chunks = all_chunks[batch_start:batch_end]
        print(f"  Processing batch {i+1}/{num_batches} ({len(batch_chunks)} chunks)")

        # Tokenize the batch
        try:
            inputs = tokenizer(
                batch_chunks,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=1024
            ).to(device)
            max_length = 1024  # model's maximum sequence length

            max_input_length = inputs["input_ids"].shape[1]
            max_new_tokens = min(int(max_input_length * 1.2), max_length)

            print(f"    Tokenized input (max_length={max_length})")
            # Use a distinct loop variable so the outer batch index `i` is not shadowed.
            for idx, (text, input_ids) in enumerate(zip(batch_chunks, inputs["input_ids"])):
                print(f"      Input {idx + 1}: {len(input_ids)} tokens")
                print(f"      Chunk {idx + 1}: {repr(text[:100])}")  # First 100 chars to keep output manageable

        except Exception as e:
            print(f"Error during batch tokenization: {e}")
            return "[Error] Tokenization failed for a batch."

        # Generate translations for the batch
        try:
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    num_beams=4,
                    # no_repeat_ngram_size=3,
                    early_stopping=False,
                    return_dict_in_generate=True,
                    output_scores=True
                )

            print(f"    Generation completed with max_new_tokens={max_new_tokens}")

            sequences = outputs.sequences
            for idx, seq in enumerate(sequences):
                print(f"    Output {idx+1}: {len(seq)} tokens")

            batch_translations = tokenizer.batch_decode(sequences, skip_special_tokens=True)
            all_translations.extend(batch_translations)

        except Exception as e:
            print(f"Error during batch generation/decoding: {e}")
            return "[Error] Translation generation failed for a batch."
    # 4. Join translated chunks back together
    # Simple join with newline. The chunking logic aims to keep sentences/clauses together,
    # so joining by newline should preserve the overall structure reasonably well,
    # though it might not exactly match the original line breaks if chunking occurred within an original line.
    final_output = "\n".join(all_translations)
    print("Batch translation finished.")
    return final_output
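

# Hedged usage sketch: translate_batch can also be called directly, without
# the Gradio UI, once the model has loaded, e.g. from a Python REPL:
#
#     result = translate_batch("Dzień dobry!\nHow are you today?")
#     print(result)
#
# Each line is chunked, batched through the model, and the translated chunks
# are joined with newlines.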


# --- Create Gradio Interface for Batch Translation ---
input_textbox = gr.Textbox(
    lines=10, # Allow more lines for batch input
    label="Input Text (Polish or English - Enter multiple lines/sentences)",
    placeholder=f"Enter text here. Longer sentences/lines will be split into chunks (max {MAX_WORDS_PER_CHUNK} words) prioritizing . ! ? and , breaks."
)
output_textbox = gr.Textbox(label="Translation Output", lines=10)

# Interface definition
interface = gr.Interface(
    fn=translate_batch,          # Use the batch function
    inputs=input_textbox,
    outputs=output_textbox,
    title="πŸ‡΅πŸ‡± <-> πŸ‡¬πŸ‡§ Batch ByT5 Translator (Auto-Detect, Smart Chunking)",
    description=f"Translate multiple lines of text between Polish and English.\nModel: {MODEL_PATH}\nText is automatically split into chunks of max {MAX_WORDS_PER_CHUNK} words, prioritizing breaks at . ! ? and ,",
    article="Enter text (you can paste multiple paragraphs or sentences). Click Submit to translate.\n\nChunking Logic:\n1. The entire input box content is split into potential 'sentence segments' using . ! ? as delimiters.\n2. Each segment is checked for word count.\n3. If a segment is <= {MAX_WORDS_PER_CHUNK} words, it's treated as a single chunk.\n4. If a segment is > {MAX_WORDS_PER_CHUNK} words, it's further split internally using commas (,) as preferred break points.\n5. If a long segment has no commas, or comma splitting isn't sufficient, it falls back to breaking purely by word count near {MAX_WORDS_PER_CHUNK} to avoid excessively long chunks.\n6. These final chunks are batched and translated.",
    allow_flagging="never"
)

# --- Launch the App ---
if __name__ == "__main__":
    # Set share=True for a public link if running locally, not needed on Spaces
    interface.launch()