# --- START OF FILE app.py ---

import sys
import os
import re
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
from dotenv import load_dotenv

# --- FIX: Add project root to Python's path ---
project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, project_root)

# --- Updated Spaces import for Zero-GPU compatibility ---
try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")
    class DummySpaces:
        def GPU(self, *args, **kwargs):
            # Support both bare @spaces.GPU and parameterized @spaces.GPU(...) usage,
            # mirroring how the real spaces.GPU decorator behaves.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                print(f"Note: Dummy @GPU decorator used for function '{args[0].__name__}'.")
                return args[0]
            def decorator(func):
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator
    spaces = DummySpaces()

# --- Step 1: Hugging Face Authentication ---
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("FATAL: Hugging Face token not found. Please set the HF_TOKEN environment variable.")
print("--- Logging in to Hugging Face Hub ---")
login(token=HF_TOKEN)

# --- Step 2: Initialize Model and Tokenizer ---
MODEL_NAME = "Gregniuki/ERNIE-4.5-0.3B-PT-Translator-EN-PL-EN"
print(f"--- Loading model from Hugging Face Hub: {MODEL_NAME} ---")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
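# bfloat16 halves weight memory on a GPU; full-precision float32 stays the safe default on CPU.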
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
print(f"--- Using device: {device}, dtype: {dtype} ---")

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=dtype, trust_remote_code=True).to(device)
    model.eval()
    print("--- Model and Tokenizer Loaded Successfully ---")
except Exception as e:
    raise RuntimeError(f"FATAL: Could not load model or tokenizer. Error: {e}")

# --- Helper Functions ---
def chunk_text(text: str, max_size: int) -> list[str]:
    """Splits text into chunks of at most max_size characters, preferring to break after a period."""
    if not text:
        return []
    chunks, start_index = [], 0
    while start_index < len(text):
        end_index = start_index + max_size
        if end_index >= len(text):
            chunks.append(text[start_index:])
            break
        # Prefer to split after the last period inside the window; otherwise cut at the hard limit.
        split_pos = text.rfind('.', start_index, end_index)
        if split_pos != -1:
            chunk, start_index = text[start_index : split_pos + 1], split_pos + 1
        else:
            chunk, start_index = text[start_index:end_index], end_index
        chunks.append(chunk.strip())
    return [c for c in chunks if c]

def do_translation_get_ids(text_to_translate: str) -> list[int]:
    """Runs a single translation and returns ONLY the raw output token IDs."""
    if not text_to_translate.strip(): return []
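    # The user turn carries only the raw source text; the fine-tuned checkpoint is assumed to reply
    # with the translation directly, so no explicit instruction is added to the prompt.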
    messages = [{"role": "user", "content": text_to_translate}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([prompt], add_special_tokens=False, return_tensors="pt").to(device)
    
    generated_ids_tensor = model.generate(**model_inputs, max_new_tokens=2048, do_sample=True, temperature=0.7, top_p=0.95, top_k=50)
    
    input_token_len = model_inputs.input_ids.shape[1]
    return generated_ids_tensor[0][input_token_len:].tolist()

def preprocess_text(text: str) -> str:
    """Intelligently cleans text by handling newlines."""
    if not text: return ""
    text = re.sub(r'\n{2,}', ' ', text)
    text = text.replace('\n', ' ')
    text = re.sub(r'\s{2,}', ' ', text)
    return text.strip()

# --- Step 3: Core Translation Function (DEFINITIVE TOKEN-LEVEL DIFF) ---
@spaces.GPU
@torch.no_grad()
def translate_with_chunks(input_text: str, chunk_size: int, context_words: int, progress=gr.Progress()) -> str:
    """
    Processes chunks using a precise token-level diff to remove overlap.
    This is the most robust method for this model.
    """
    progress(0, desc="Starting...")
    processed_text = preprocess_text(input_text)
    if not processed_text: return "Input text is empty. Please enter some text to translate."

    text_chunks = chunk_text(processed_text, chunk_size) if len(processed_text) > chunk_size else [processed_text]
    num_chunks = len(text_chunks)
    print(f"Processing {num_chunks} chunk(s).")

    all_results = []
    english_context = ""

    for i, chunk in enumerate(text_chunks):
        progress(0.2 + (i / num_chunks) * 0.7, desc=f"Translating chunk {i+1}/{num_chunks}")

        if not english_context or context_words == 0:
            # First chunk or context disabled: Translate directly and decode
            output_ids = do_translation_get_ids(chunk)
            final_translation_for_chunk = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
        else:
            # --- The Token-Level Diff Logic ---
            prompt_with_context = (english_context + " " + chunk).strip()
            
            # 1. Get token IDs for the context translation
            context_ids = do_translation_get_ids(english_context)
            
            # 2. Get token IDs for the full translation
            full_ids = do_translation_get_ids(prompt_with_context)

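            # Note: generation samples (do_sample=True), so the context-only pass and the full pass are
            # not guaranteed to share an identical token prefix; early divergence just shortens diff_index
            # and keeps more of the full translation, possibly repeating part of the context's translation.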
            # 3. Find the first point of difference at the token level
            diff_index = 0
            for j in range(min(len(context_ids), len(full_ids))):
                if context_ids[j] != full_ids[j]:
                    break
                diff_index += 1
            
            # 4. The clean translation starts from the point of difference
            clean_ids = full_ids[diff_index:]
            final_translation_for_chunk = tokenizer.decode(clean_ids, skip_special_tokens=True).strip()
            
        all_results.append(final_translation_for_chunk)
        print(f"Chunk {i+1} processed successfully.")

        if context_words > 0:
            # Update context with words from the *source* English chunk
            words = chunk.split()
            english_context = " ".join(words[-context_words:])

    full_output = " ".join(all_results)
    progress(1.0, desc="Done!")
    return full_output

# --- Step 4: Create and Launch the Gradio App ---
print("\n--- Initializing Gradio Interface ---")
app = gr.Interface(
    fn=translate_with_chunks,
    inputs=[
        gr.Textbox(lines=15, label="Input Text", placeholder="Enter long text to process here..."),
        gr.Slider(minimum=256, maximum=2048, value=512, step=64, label="Character Chunk Size"),
        gr.Slider(
            minimum=0,
            maximum=50,
            value=20,
            step=5,
            label="Context Overlap (English Words)",
            info="Number of English words from the previous chunk to use as context. A token-level comparison is used to reliably remove the overlap."
        )
    ],
    outputs=gr.Textbox(lines=15, label="Model Output", interactive=False),
    title="ERNIE 4.5 Context-Aware Translator",
    description="Processes long text using a precise, token-level diffing algorithm to ensure high-quality, consistent translations.",
    allow_flagging="never"
)

if __name__ == "__main__":
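    # queue() makes long-running translation requests wait their turn instead of running concurrently.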
    app.queue().launch()