# File size: 7,596 Bytes
# 4db7c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# --- START OF FILE app.py ---

import sys
import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
from dotenv import load_dotenv

# --- FIX: Add project root to Python's path ---
# Lets modules that sit next to this file be imported regardless of the CWD.
project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, project_root)

# --- Updated Spaces import for Zero-GPU compatibility ---
try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")

    class DummySpaces:
        """Local stand-in for the `spaces` package whose ``GPU`` decorator is a no-op.

        Supports both the bare ``@spaces.GPU`` form (used in this file) and the
        called ``@spaces.GPU(...)`` form; any options are ignored.
        """

        def GPU(self, func=None, *args, **kwargs):
            """Return ``func`` unchanged, or a pass-through decorator when called with options."""
            def decorator(fn):
                print(f"Note: Dummy @GPU decorator used for function '{fn.__name__}'.")
                return fn

            # Bare form: `@spaces.GPU` hands us the decorated function directly.
            if callable(func):
                return decorator(func)
            # Called form: `@spaces.GPU(...)` must return the actual decorator.
            return decorator

    spaces = DummySpaces()

# --- Step 1: Hugging Face Authentication ---
# Populate os.environ from a .env file, if one is found, then read the token.
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")

# A missing or empty token aborts startup; otherwise authenticate with the Hub.
if HF_TOKEN:
    print("--- Logging in to Hugging Face Hub ---")
    login(token=HF_TOKEN)
else:
    raise ValueError("FATAL: Hugging Face token not found. Please set the HF_TOKEN environment variable.")


# --- Step 2: Initialize Model and Tokenizer (Load Once on Startup) ---

MODEL_NAME = "Gregniuki/ERNIE-4.5-0.3B-PT-Translator-EN-PL-EN"

print(f"--- Loading model from Hugging Face Hub: {MODEL_NAME} ---")

# --- Device Setup (Zero GPU Support) ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU detected. Using CUDA.")
else:
    device = torch.device("cpu")
    print("No GPU detected. Using CPU.")

# bfloat16 weights on GPU; full float32 on CPU.
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
print(f"--- Using dtype: {dtype} ---")

# NOTE: trust_remote_code=True executes Python code shipped in the model repo;
# only keep it for repositories you trust.
print(f"--- Loading tokenizer from Hub: {MODEL_NAME} ---")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True
    )
    print("--- Tokenizer Loaded Successfully ---")
except Exception as e:
    # Chain the original error so its traceback is preserved as the cause.
    raise RuntimeError(f"FATAL: Could not load tokenizer from the Hub. Error: {e}") from e

print(f"--- Loading Model with PyTorch from Hub: {MODEL_NAME} ---")
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        trust_remote_code=True
    ).to(device)
    model.eval()  # inference mode for layers such as dropout
    print("--- Model Loaded Successfully ---")
except Exception as e:
    raise RuntimeError(f"FATAL: Could not load model from the Hub. Error: {e}") from e


# --- Helper function for chunking text ---
def chunk_text(text: str, max_size: int) -> list[str]:
    """Split ``text`` into chunks of at most ``max_size`` characters.

    Each chunk is cut just after the last '.' inside its window when there is
    one; otherwise it is cut hard at ``max_size`` characters. Every chunk is
    whitespace-stripped and empty chunks are dropped.

    Args:
        text: The text to split. Empty input yields an empty list.
        max_size: Maximum chunk length in characters; must be at least 1.

    Returns:
        The non-empty, stripped chunks in their original order.

    Raises:
        ValueError: If ``max_size`` is less than 1 (the loop could never advance).
    """
    if not text:
        return []
    if max_size < 1:
        raise ValueError(f"max_size must be at least 1, got {max_size}")

    chunks = []
    start_index = 0
    while start_index < len(text):
        end_index = start_index + max_size
        if end_index >= len(text):
            # Final piece: strip it like every other chunk so trailing
            # whitespace does not turn into a chunk of its own.
            chunks.append(text[start_index:].strip())
            break
        split_pos = text.rfind('.', start_index, end_index)
        if split_pos != -1:
            # Keep the period with the chunk it terminates.
            chunks.append(text[start_index:split_pos + 1].strip())
            start_index = split_pos + 1
        else:
            chunks.append(text[start_index:end_index].strip())
            start_index = end_index
    return [chunk for chunk in chunks if chunk]


# --- Simplified translation helper for internal use ---
def do_translation(text_to_translate: str) -> str:
    """Translate one piece of text with a single chat-formatted generation call.

    Blank (empty or whitespace-only) input short-circuits to "". Decoding uses
    sampling (temperature 0.7, top-p 0.95, top-k 50, up to 2048 new tokens),
    so repeated calls on the same input may return different text.
    """
    if not text_to_translate.strip():
        return ""

    # Wrap the text as a single user turn and render it with the model's chat template.
    chat = [{"role": "user", "content": text_to_translate}]
    rendered_prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # The rendered template already carries its special tokens.
    encoded = tokenizer([rendered_prompt], add_special_tokens=False, return_tensors="pt").to(device)

    sampling_kwargs = dict(max_new_tokens=2048, do_sample=True, temperature=0.7, top_p=0.95, top_k=50)
    generated_ids = model.generate(**encoded, **sampling_kwargs)

    # generate() returns prompt + continuation; keep only the newly generated tokens.
    prompt_length = encoded.input_ids.shape[1]
    new_token_ids = generated_ids[0][prompt_length:].tolist()
    decoded = tokenizer.decode(new_token_ids, skip_special_tokens=True)
    return decoded.strip()


# --- Step 3: Core Translation Function (USING ROBUST 'DIFF' ALGORITHM) ---
def _strip_translated_overlap(translated_context: str, full_translation: str) -> str:
    """Remove the leading words of ``full_translation`` that repeat ``translated_context``.

    Words are compared pairwise from the start, case-insensitively and ignoring
    the characters ``.,!?;:`` at either end of each word. The comparison stops
    at the first mismatch; the remaining words are re-joined with single spaces.
    """
    def normalize(word: str) -> str:
        return word.strip('.,!?;:').lower()

    full_words = full_translation.split()
    overlap_len_in_words = 0
    # zip() stops at the shorter list, like comparing up to min(len(a), len(b)).
    for context_word, full_word in zip(translated_context.split(), full_words):
        if normalize(context_word) != normalize(full_word):
            break
        overlap_len_in_words += 1
    return " ".join(full_words[overlap_len_in_words:])


@spaces.GPU
@torch.no_grad()
def translate_with_chunks(input_text: str, chunk_size: int, context_words: int, progress=gr.Progress()) -> str:
    """Translate ``input_text`` chunk by chunk, carrying source-side context between chunks.

    For every chunk after the first, the last ``context_words`` words of the
    previous source chunk are prepended to the prompt. That context is also
    translated on its own. The leading words of the combined translation that
    match the standalone translation are removed by ``_strip_translated_overlap``,
    so the overlap is not emitted twice.

    Args:
        input_text: Text to translate.
        chunk_size: Maximum characters per chunk, passed to ``chunk_text``.
        context_words: Trailing source words carried into the next prompt; 0 disables context.
        progress: Gradio progress tracker.

    Returns:
        The chunk translations joined with single spaces, or a message if the input is empty.
    """
    progress(0, desc="Starting...")
    print("--- Inference with robust 'diff' context method started ---")
    if not input_text or not input_text.strip():
        return "Input text is empty. Please enter some text to translate."

    # Gradio documents slider values as floats; slicing and list indexing need ints.
    chunk_size = int(chunk_size)
    context_words = int(context_words)

    progress(0.1, desc="Chunking Text...")
    text_chunks = chunk_text(input_text, chunk_size) if len(input_text) > chunk_size else [input_text]
    num_chunks = len(text_chunks)
    print(f"Processing {num_chunks} chunk(s).")

    all_results = []
    # Tail of the previous *source* chunk (the model translates EN<->PL, so
    # this is not necessarily English).
    source_context = ""

    for chunk_index, chunk in enumerate(text_chunks):
        progress(0.2 + (chunk_index / num_chunks) * 0.7, desc=f"Translating chunk {chunk_index+1}/{num_chunks}")

        prompt_with_context = (source_context + " " + chunk).strip()
        full_translation = do_translation(prompt_with_context)

        if source_context:
            # NOTE(review): do_translation samples (do_sample=True), so this standalone
            # context translation may not reproduce the prefix of full_translation;
            # the word diff then trims less and part of the overlap can reappear.
            translated_context = do_translation(source_context)
            final_translation_for_chunk = _strip_translated_overlap(translated_context, full_translation)
        else:
            final_translation_for_chunk = full_translation

        all_results.append(final_translation_for_chunk)
        print(f"Chunk {chunk_index+1} processed successfully.")

        if context_words > 0:
            source_context = " ".join(chunk.split()[-context_words:])

    progress(0.95, desc="Reassembling Results...")
    # Skip chunks whose translation ended up empty so they leave no double spaces.
    full_output = " ".join(result for result in all_results if result)

    progress(1.0, desc="Done!")
    return full_output

# --- Step 4: Create and Launch the Gradio App ---
print("\n--- Initializing Gradio Interface ---")

# One text input plus two sliders that map, in order, onto
# translate_with_chunks(input_text, chunk_size, context_words).
app = gr.Interface(
    fn=translate_with_chunks,
    inputs=[
        gr.Textbox(
            lines=15,
            label="Input Text",
            placeholder="Enter long text to process here...",
        ),
        gr.Slider(
            minimum=128, maximum=1536, value=1024, step=64,
            label="Character Chunk Size",
        ),
        gr.Slider(
            minimum=0, maximum=50, value=15, step=5,
            label="Context Overlap (Source Words)",
            info="Number of English words from the end of the previous chunk to provide as context for the next one. Ensures consistency.",
        ),
    ],
    outputs=gr.Textbox(lines=15, label="Model Output", interactive=False),
    allow_flagging="never",
    title="ERNIE 4.5 Context-Aware Translator",
    description="Processes long text using a robust 'diff' algorithm to ensure high-quality, consistent translations without duplication.",
)

if __name__ == "__main__":
    # Enable request queuing, then start the web server.
    queued_app = app.queue()
    queued_app.launch()